diff --git a/backend/entityservice/cache/encodings.py b/backend/entityservice/cache/encodings.py
index 99cb73bd..0f23a4b5 100644
--- a/backend/entityservice/cache/encodings.py
+++ b/backend/entityservice/cache/encodings.py
@@ -35,15 +35,14 @@ def get_deserialized_filter(dp_id):
     else:
         logger.debug("Looking up popcounts and filename from database")
         with DBConn() as db:
-            serialized_filters_file = get_filter_metadata(db, dp_id)
-
+            serialized_filters_file, encoding_size = get_filter_metadata(db, dp_id)
         mc = connect_to_object_store()
         logger.debug("Getting filters from object store")
 
         # Note this uses already calculated popcounts unlike
         # serialization.deserialize_filters()
         raw_data_response = mc.get_object(config.MINIO_BUCKET, serialized_filters_file)
-        python_filters = binary_unpack_filters(raw_data_response)
+        python_filters = binary_unpack_filters(raw_data_response.stream(encoding_size + 4), encoding_size=encoding_size)
 
         set_deserialized_filter(dp_id, python_filters)
         return python_filters
diff --git a/backend/entityservice/database/selections.py b/backend/entityservice/database/selections.py
index 8141ce12..bcb3e3e0 100644
--- a/backend/entityservice/database/selections.py
+++ b/backend/entityservice/database/selections.py
@@ -260,32 +260,33 @@ def get_dataprovider_id(db, update_token):
     return query_db(db, sql_query, [update_token], one=True)['id']
 
 
-def get_bloomingdata_column(db, dp_id, column):
-    assert column in {'ts', 'token', 'file', 'state', 'count'}
+def get_bloomingdata_columns(db, dp_id, columns):
+    for column in columns:
+        assert column in {'ts', 'token', 'file', 'state', 'count', 'encoding_size'}
     sql_query = """
         SELECT {}
         FROM bloomingdata
         WHERE dp = %s
-        """.format(column)
+        """.format(', '.join(columns))
     result = query_db(db, sql_query, [dp_id], one=True)
     if result is None:
         raise DataProviderDeleted(dp_id)
-    else:
-        return result[column]
+    return [result[column] for column in columns]
 
 
 def get_filter_metadata(db, dp_id):
     """
-    :return: The filename of the raw clks.
+    :return: The filename and the encoding size of the raw clks.
     """
-    return get_bloomingdata_column(db, dp_id, 'file').strip()
+    filename, encoding_size = get_bloomingdata_columns(db, dp_id, ['file', 'encoding_size'])
+    return filename.strip(), encoding_size
 
 
 def get_number_of_hashes(db, dp_id):
     """
     :return: The count of the uploaded encodings.
     """
-    return get_bloomingdata_column(db, dp_id, 'count')
+    return get_bloomingdata_columns(db, dp_id, ['count'])[0]
 
 
 def get_project_schema_encoding_size(db, project_id):
diff --git a/backend/entityservice/error_checking.py b/backend/entityservice/error_checking.py
index 6dcdc7b1..6a120419 100644
--- a/backend/entityservice/error_checking.py
+++ b/backend/entityservice/error_checking.py
@@ -30,6 +30,6 @@ def check_dataproviders_encoding(project_id, encoding_size):
 
 def handle_invalid_encoding_data(project_id, dp_id):
     with DBConn() as conn:
-        filename = get_filter_metadata(conn, dp_id)
+        filename, _ = get_filter_metadata(conn, dp_id)
         update_encoding_metadata(conn, 'DELETED', dp_id, state='error')
         delete_minio_objects.delay([filename], project_id)
diff --git a/backend/entityservice/serialization.py b/backend/entityservice/serialization.py
index 571490e7..c5a558e7 100644
--- a/backend/entityservice/serialization.py
+++ b/backend/entityservice/serialization.py
@@ -41,11 +41,16 @@ def binary_format(encoding_size):
 
     The binary format string can be understood as:
     - "!" Use network byte order (big-endian).
+    - "I" Store the entity ID as an unsigned int
     - "s" Store the n (e.g. 128) raw bytes of the bitarray
 
     https://docs.python.org/3/library/struct.html
+
+    :param encoding_size: the size of one encoding in bytes, excluding the entity ID
+    :return:
+        A Struct object which can read and write the binary format.
     """
-    bit_packing_fmt = f"!{encoding_size}s"
+    bit_packing_fmt = f"!I{encoding_size}s"
     bit_packing_struct = struct.Struct(bit_packing_fmt)
     return bit_packing_struct
 
@@ -54,31 +59,41 @@ def binary_pack_filters(filters, encoding_size):
     """Efficient packing of bloomfilters.
 
     :param filters:
-        An iterable of bytes as produced by deserialize_bytes.
-
+        An iterable of tuples, where
+        - the first element is the entity ID as an unsigned int, and
+        - the second element is 'encoding_size' bytes as produced by deserialize_bytes.
+    :param encoding_size: the size of one encoding in bytes, excluding the entity ID
     :return:
         An iterable of bytes.
     """
     bit_packing_struct = binary_format(encoding_size)
 
     for hash_bytes in filters:
-        yield bit_packing_struct.pack(hash_bytes)
+        yield bit_packing_struct.pack(*hash_bytes)
 
 
 def binary_unpack_one(data, bit_packing_struct):
-    clk_bytes, = bit_packing_struct.unpack(data)
-    return clk_bytes
+    entity_id, clk_bytes = bit_packing_struct.unpack(data)
+    return entity_id, clk_bytes
 
 
-def binary_unpack_filters(streamable_data, max_bytes=None, encoding_size=None):
+def binary_unpack_filters(data_iterable, max_bytes=None, encoding_size=None):
+    """
+    Unpack filters that were packed with the 'binary_pack_filters' method.
+
+    :param data_iterable: an iterable of binary packed filters.
+    :param max_bytes: if present, only read up to 'max_bytes' bytes.
+    :param encoding_size: the size of one encoding in bytes, excluding the entity ID
+    :return: a list of (entity ID, filter) tuples.
+    """
     assert encoding_size is not None
     bit_packed_element = binary_format(encoding_size)
     bit_packed_element_size = bit_packed_element.size
 
     filters = []
     bytes_consumed = 0
 
-    logger.info(f"Unpacking stream of encodings with size {encoding_size} - packed as {bit_packed_element_size}")
-    for raw_bytes in streamable_data.stream(bit_packed_element_size):
+    logger.info(f"Iterating over encodings of size {encoding_size} - packed as {bit_packed_element_size}")
+    for raw_bytes in data_iterable:
         filters.append(binary_unpack_one(raw_bytes, bit_packed_element))
         bytes_consumed += bit_packed_element_size
 
@@ -185,6 +200,6 @@ def get_chunk_from_object_store(chunk_info, encoding_size=128):
                                      bit_packed_element_size * chunk_range_start,
                                      chunk_bytes)
 
-    chunk_data = binary_unpack_filters(chunk_stream, chunk_bytes, encoding_size)
+    chunk_data = binary_unpack_filters(chunk_stream.stream(bit_packed_element_size), chunk_bytes, encoding_size)
 
     return chunk_data, chunk_length
diff --git a/backend/entityservice/tasks/comparing.py b/backend/entityservice/tasks/comparing.py
index 7a3b5c81..f13eba6a 100644
--- a/backend/entityservice/tasks/comparing.py
+++ b/backend/entityservice/tasks/comparing.py
@@ -61,9 +61,8 @@ def create_comparison_jobs(project_id, run_id, parent_span=None):
         current_span.log_kv({"event": 'get-dataset-sizes'})
 
         filters_object_filenames = tuple(
-            get_filter_metadata(conn, dp_id) for dp_id in dp_ids)
+            get_filter_metadata(conn, dp_id)[0] for dp_id in dp_ids)
         current_span.log_kv({"event": 'get-metadata'})
-
         log.debug("Chunking computation task")
 
         chunk_infos = tuple(anonlink.concurrency.split_to_chunks(
@@ -125,10 +124,13 @@ def compute_filter_similarity(chunk_info, project_id, run_id, threshold, encodin
     t0 = time.time()
     log.debug("Fetching and deserializing chunk of filters for dataprovider 1")
     chunk_dp1, chunk_dp1_size = get_chunk_from_object_store(chunk_info_dp1, encoding_size)
-
+    # TODO: use the entity ids!
+    entity_ids_dp1, chunk_dp1 = zip(*chunk_dp1)
     t1 = time.time()
     log.debug("Fetching and deserializing chunk of filters for dataprovider 2")
     chunk_dp2, chunk_dp2_size = get_chunk_from_object_store(chunk_info_dp2, encoding_size)
+    # TODO: use the entity ids!
+    entity_ids_dp2, chunk_dp2 = zip(*chunk_dp2)
     t2 = time.time()
     span.log_kv({'event': 'chunks are fetched and deserialized'})
     log.debug("Calculating filter similarity")
diff --git a/backend/entityservice/tasks/encoding_uploading.py b/backend/entityservice/tasks/encoding_uploading.py
index 56469109..0820c304 100644
--- a/backend/entityservice/tasks/encoding_uploading.py
+++ b/backend/entityservice/tasks/encoding_uploading.py
@@ -41,14 +41,13 @@ def handle_raw_upload(project_id, dp_id, receipt_token, parent_span=None):
     uploaded_encoding_size = len(first_hash_bytes)
 
     def filter_generator():
-        log.debug("Deserializing json filters")
         i = 0
-        yield first_hash_bytes
+        yield i, first_hash_bytes
         for i, line in enumerate(text_stream, start=1):
            hash_bytes = deserialize_bytes(line)
            if len(hash_bytes) != uploaded_encoding_size:
                raise ValueError("Encodings were not all the same size")
-            yield hash_bytes
+            yield i, hash_bytes
 
        log.info(f"Processed {i + 1} hashes")
 
diff --git a/backend/entityservice/tests/test_project_uploads.py b/backend/entityservice/tests/test_project_uploads.py
index 68b3107d..99a45acb 100644
--- a/backend/entityservice/tests/test_project_uploads.py
+++ b/backend/entityservice/tests/test_project_uploads.py
@@ -2,13 +2,12 @@ import os
 
 import pytest
 
-from entityservice.serialization import binary_pack_filters
 from entityservice.tests.config import url
 from entityservice.tests.util import (
     create_project_upload_data, create_project_upload_fake_data, generate_clks,
     generate_json_serialized_clks, get_expected_number_parties, get_run_result,
     post_run,
-    upload_binary_data, upload_binary_data_from_file)
+    upload_binary_data, upload_binary_data_from_file, binary_pack_for_upload)
 
 
 def test_project_single_party_data_uploaded(requests, valid_project_params):
@@ -73,13 +72,13 @@ def test_project_binary_data_upload_with_different_encoded_size(
         **valid_project_params
     }).json()
 
-    common = next(binary_pack_filters(generate_clks(1, encoding_size),
+    common = next(binary_pack_for_upload(generate_clks(1, encoding_size),
                                       encoding_size))
 
     data = []
     for i in range(expected_number_parties):
         generated_clks = generate_clks(499, encoding_size)
-        packed_clks = binary_pack_filters(generated_clks, encoding_size)
+        packed_clks = binary_pack_for_upload(generated_clks, encoding_size)
         packed_joined = b''.join(packed_clks)
         packed_with_common = (
             packed_joined + common if i == 0 else common + packed_joined)
diff --git a/backend/entityservice/tests/test_results_correctness_multiparty.py b/backend/entityservice/tests/test_results_correctness_multiparty.py
index 856ae187..6a69303b 100644
--- a/backend/entityservice/tests/test_results_correctness_multiparty.py
+++ b/backend/entityservice/tests/test_results_correctness_multiparty.py
@@ -3,9 +3,8 @@
 
 import anonlink
 
-from entityservice.serialization import binary_pack_filters
 from entityservice.tests.util import (
-    create_project_upload_data, delete_project, get_run_result, post_run)
+    create_project_upload_data, delete_project, get_run_result, post_run, binary_pack_for_upload)
 
 DATA_FILENAME = 'test-multiparty-results-correctness-data.pkl'
 DATA_PATH = pathlib.Path(__file__).parent / 'testdata' / DATA_FILENAME
@@ -29,7 +28,7 @@ def test_groups_correctness(requests):
     filter_size = len(filters[0][0])
     assert all(len(filter_) == filter_size for dataset in filters for filter_ in dataset)
 
-    packed_filters = [b''.join(binary_pack_filters(f, filter_size))
+    packed_filters = [b''.join(binary_pack_for_upload(f, filter_size))
                       for f in filters]
     project_data, _ = create_project_upload_data(
         requests, packed_filters, result_type='groups',
diff --git a/backend/entityservice/tests/test_serialization.py b/backend/entityservice/tests/test_serialization.py
index 7d37b904..f08c9505 100644
--- a/backend/entityservice/tests/test_serialization.py
+++ b/backend/entityservice/tests/test_serialization.py
@@ -1,13 +1,13 @@
 import io
-import unittest
-import random
-
 import json
+import random
+import unittest
 from array import array
 
 import anonlink
 
-from entityservice.serialization import deserialize_bytes, generate_scores
+from entityservice.serialization import deserialize_bytes, generate_scores, binary_pack_filters, \
+    binary_unpack_filters, binary_unpack_one, binary_format
 from entityservice.tests.util import serialize_bytes, generate_bytes
 
 
@@ -69,6 +69,23 @@ def test_sims_to_json_empty(self):
         self.assertIn('similarity_scores', json_obj)
         assert len(json_obj["similarity_scores"]) == 0
 
+    def test_binary_pack_filters(self):
+        encoding_size = 128
+        filters = [(random.randint(0, 2 ** 32 - 1), generate_bytes(encoding_size)) for _ in range(10)]
+        packed_filters = binary_pack_filters(filters, encoding_size)
+        bin_format = binary_format(encoding_size)
+        for filter, packed_filter in zip(filters, packed_filters):
+            assert len(packed_filter) == encoding_size + 4
+            unpacked = binary_unpack_one(packed_filter, bin_format)
+            assert filter == unpacked
+
+    def test_binary_unpack_filters(self):
+        encoding_size = 128
+        filters = [(random.randint(0, 2 ** 32 - 1), generate_bytes(encoding_size)) for _ in range(10)]
+        laundered_filters = binary_unpack_filters(binary_pack_filters(filters, encoding_size),
+                                                  encoding_size=encoding_size)
+        assert filters == laundered_filters
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/backend/entityservice/tests/util.py b/backend/entityservice/tests/util.py
index 8d9a3e52..85836bbb 100644
--- a/backend/entityservice/tests/util.py
+++ b/backend/entityservice/tests/util.py
@@ -4,6 +4,7 @@
 import math
 import os
 import random
+import struct
 import time
 import tempfile
 from contextlib import contextmanager
@@ -54,6 +55,10 @@ def generate_clks(count, size):
     return res
 
 
+def generate_clks_with_id(count, size):
+    return zip(range(count), generate_clks(count, size))
+
+
 def generate_json_serialized_clks(count, size=128):
     clks = generate_clks(count, size)
     return [serialize_bytes(hash_bytes) for hash_bytes in clks]
@@ -436,3 +441,15 @@ def upload_binary_data_from_file(requests, file_path, project_id, token, count,
 
 def get_expected_number_parties(project_params):
     return project_params.get('number_parties', 2)
+
+
+def binary_upload_format(encoding_size):
+    bit_packing_fmt = f"!{encoding_size}s"
+    bit_packing_struct = struct.Struct(bit_packing_fmt)
+    return bit_packing_struct
+
+
+def binary_pack_for_upload(filters, encoding_size):
+    bit_packing_struct = binary_upload_format(encoding_size)
+    for hash_bytes in filters:
+        yield bit_packing_struct.pack(hash_bytes)
diff --git a/backend/entityservice/views/project.py b/backend/entityservice/views/project.py
index 86d6d509..407611ce 100644
--- a/backend/entityservice/views/project.py
+++ b/backend/entityservice/views/project.py
@@ -10,7 +10,7 @@
 from entityservice.tasks import handle_raw_upload, check_for_executable_runs, remove_project
 from entityservice.tracing import serialize_span
 from entityservice.utils import safe_fail_request, get_json, generate_code, get_stream, \
-    clks_uploaded_to_project, fmt_bytes
+    clks_uploaded_to_project, fmt_bytes, iterable_to_stream
 from entityservice.database import DBConn
 from entityservice.views.auth_checks import abort_if_project_doesnt_exist, abort_if_invalid_dataprovider_token, \
     abort_if_invalid_results_token, get_authorization_token_type_or_abort
@@ -152,13 +152,20 @@ def project_binaryclks_post(project_id):
         # https://github.com/zalando/connexion/issues/592
         # stream = get_stream()
         stream = BytesIO(request.data)
-        expected_bytes = binary_format(size).size * count
+        binary_formatter = binary_format(size)
+
+        def entity_id_injector(filter_stream):
+            for entity_id in range(count):
+                yield binary_formatter.pack(entity_id, filter_stream.read(size))
+
+        data_with_ids = b''.join(entity_id_injector(stream))
+        expected_bytes = size * count
         log.debug(f"Stream size is {len(request.data)} B, and we expect {expected_bytes} B")
         if len(request.data) != expected_bytes:
             safe_fail_request(400, "Uploaded data did not match the expected size. Check request headers are correct")
 
        try:
-            receipt_token = upload_clk_data_binary(project_id, dp_id, stream, count, size)
+            receipt_token = upload_clk_data_binary(project_id, dp_id, BytesIO(data_with_ids), count, size)
        except ValueError:
            safe_fail_request(400, "Uploaded data did not match the expected size. Check request headers are correct.")
 
@@ -174,6 +181,7 @@ def project_binaryclks_post(project_id):
         db.set_dataprovider_upload_state(conn, dp_id, state='done')
     return {'message': 'Updated', 'receipt_token': receipt_token}, 201
 
+
 def project_clks_post(project_id):
     """
     Update a project to provide encoded PII data.