Merged
Changes from 9 commits
5 changes: 2 additions & 3 deletions backend/entityservice/cache/encodings.py
@@ -35,15 +35,14 @@ def get_deserialized_filter(dp_id):
else:
logger.debug("Looking up popcounts and filename from database")
with DBConn() as db:
serialized_filters_file = get_filter_metadata(db, dp_id)

serialized_filters_file, encoding_size = get_filter_metadata(db, dp_id)
mc = connect_to_object_store()
logger.debug("Getting filters from object store")

# Note this uses already calculated popcounts unlike
# serialization.deserialize_filters()
raw_data_response = mc.get_object(config.MINIO_BUCKET, serialized_filters_file)
python_filters = binary_unpack_filters(raw_data_response)
python_filters = binary_unpack_filters(raw_data_response.stream(binary_format(encoding_size).size), encoding_size=encoding_size)

set_deserialized_filter(dp_id, python_filters)
return python_filters
17 changes: 9 additions & 8 deletions backend/entityservice/database/selections.py
@@ -260,32 +260,33 @@ def get_dataprovider_id(db, update_token):
return query_db(db, sql_query, [update_token], one=True)['id']


def get_bloomingdata_column(db, dp_id, column):
assert column in {'ts', 'token', 'file', 'state', 'count'}
def get_bloomingdata_columns(db, dp_id, columns):
for column in columns:
assert column in {'ts', 'token', 'file', 'state', 'count', 'encoding_size'}
sql_query = """
SELECT {}
FROM bloomingdata
WHERE dp = %s
""".format(column)
""".format(', '.join(columns))
result = query_db(db, sql_query, [dp_id], one=True)
if result is None:
raise DataProviderDeleted(dp_id)
else:
return result[column]
return [result[column] for column in columns]


def get_filter_metadata(db, dp_id):
"""
:return: The filename of the raw clks.
:return: The filename and the encoding size of the raw clks.
"""
return get_bloomingdata_column(db, dp_id, 'file').strip()
filename, encoding_size = get_bloomingdata_columns(db, dp_id, ['file', 'encoding_size'])
return filename.strip(), encoding_size


def get_number_of_hashes(db, dp_id):
"""
:return: The count of the uploaded encodings.
"""
return get_bloomingdata_column(db, dp_id, 'count')
return get_bloomingdata_columns(db, dp_id, ['count'])[0]


def get_project_schema_encoding_size(db, project_id):
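A side note on the query construction in get_bloomingdata_columns above: column names cannot be bound as SQL parameters, so the per-column assert against the whitelist is what keeps the .format(', '.join(columns)) safe, while dp_id is still passed as a bound value. A minimal sketch of the same pattern (helper name invented for illustration):

ALLOWED_COLUMNS = {'ts', 'token', 'file', 'state', 'count', 'encoding_size'}

def build_bloomingdata_query(columns):
    # Identifiers can't be parameterized, so whitelist them before string formatting.
    for column in columns:
        assert column in ALLOWED_COLUMNS
    return "SELECT {} FROM bloomingdata WHERE dp = %s".format(', '.join(columns))

# build_bloomingdata_query(['file', 'encoding_size'])
# -> 'SELECT file, encoding_size FROM bloomingdata WHERE dp = %s'
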
2 changes: 1 addition & 1 deletion backend/entityservice/error_checking.py
@@ -30,6 +30,6 @@ def check_dataproviders_encoding(project_id, encoding_size):

def handle_invalid_encoding_data(project_id, dp_id):
with DBConn() as conn:
filename = get_filter_metadata(conn, dp_id)
filename, _ = get_filter_metadata(conn, dp_id)
update_encoding_metadata(conn, 'DELETED', dp_id, state='error')
delete_minio_objects.delay([filename], project_id)
35 changes: 25 additions & 10 deletions backend/entityservice/serialization.py
@@ -41,11 +41,16 @@ def binary_format(encoding_size):

The binary format string can be understood as:
- "!" Use network byte order (big-endian).
- "I" store the entity ID as an unsigned int
- "<encoding size>s" Store the n (e.g. 128) raw bytes of the bitarray

https://docs.python.org/3/library/struct.html

:param encoding_size: the encoding size of one filter in number of bytes, excluding the entity ID info
:return:
A Struct object which can read and write the binary format.
"""
bit_packing_fmt = f"!{encoding_size}s"
bit_packing_fmt = f"!I{encoding_size}s"
Collaborator:
We are just going to pretend the old format doesn't exist, aren't we? Should we take the opportunity to add a version byte to the format?

Collaborator (Author):
It's an internal format, so yes, I pretend that it never existed. I'm old, I forget quickly.
A version byte for each filter would mean a lot of version bytes in the object store. This should rather be solved via a header. But again, since this is only used internally in the service, we will never have to differentiate between the versions anyway, and thus the versioning is redundant.

Collaborator:
I was meaning a header for the file rather than for each encoding. But sure, not worth it right now.

bit_packing_struct = struct.Struct(bit_packing_fmt)
return bit_packing_struct
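
Purely as a sketch of the file-level header with a version byte floated in the thread above; this is not part of the PR, and the magic bytes, version number and layout here are made up:

import struct

# Hypothetical 8-byte file header: 2 magic bytes, 1 version byte, 1 pad byte,
# then the encoding size as an unsigned int; the packed encodings would follow.
FILE_HEADER = struct.Struct("!2sBxI")
MAGIC, VERSION = b"ES", 1

def write_header(buffer, encoding_size):
    buffer.write(FILE_HEADER.pack(MAGIC, VERSION, encoding_size))

def read_header(buffer):
    magic, version, encoding_size = FILE_HEADER.unpack(buffer.read(FILE_HEADER.size))
    if magic != MAGIC:
        raise ValueError("not a packed encodings file")
    return version, encoding_size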

@@ -54,31 +59,41 @@ def binary_pack_filters(filters, encoding_size):
"""Efficient packing of bloomfilters.

:param filters:
An iterable of bytes as produced by deserialize_bytes.

An iterable of tuples, where
- the first element is the entity ID as an unsigned int, and
- the second element is 'encoding_size' bytes as produced by deserialize_bytes.
:param encoding_size: the encoding size of one filter in number of bytes, excluding the entity ID info
:return:
An iterable of bytes.
"""
bit_packing_struct = binary_format(encoding_size)

for hash_bytes in filters:
yield bit_packing_struct.pack(hash_bytes)
yield bit_packing_struct.pack(*hash_bytes)
Collaborator:
⭐️



def binary_unpack_one(data, bit_packing_struct):
clk_bytes, = bit_packing_struct.unpack(data)
return clk_bytes
entity_id, clk_bytes, = bit_packing_struct.unpack(data)
return entity_id, clk_bytes


def binary_unpack_filters(streamable_data, max_bytes=None, encoding_size=None):
def binary_unpack_filters(data_iterable, max_bytes=None, encoding_size=None):
"""
Unpack filters that were packed with the 'binary_pack_filters' method.

:param data_iterable: an iterable of binary packed filters.
:param max_bytes: if present, only read up to 'max_bytes' bytes.
:param encoding_size: the encoding size of one filter in number of bytes, excluding the entity ID info
:return: a list of (entity_id, filter) tuples.
"""
assert encoding_size is not None
bit_packed_element = binary_format(encoding_size)
bit_packed_element_size = bit_packed_element.size
filters = []
bytes_consumed = 0

logger.info(f"Unpacking stream of encodings with size {encoding_size} - packed as {bit_packed_element_size}")
for raw_bytes in streamable_data.stream(bit_packed_element_size):
logger.info(f"Iterating over encodings of size {encoding_size} - packed as {bit_packed_element_size}")
for raw_bytes in data_iterable:
filters.append(binary_unpack_one(raw_bytes, bit_packed_element))

bytes_consumed += bit_packed_element_size
@@ -185,6 +200,6 @@ def get_chunk_from_object_store(chunk_info, encoding_size=128):
bit_packed_element_size * chunk_range_start,
chunk_bytes)

chunk_data = binary_unpack_filters(chunk_stream, chunk_bytes, encoding_size)
chunk_data = binary_unpack_filters(chunk_stream.stream(bit_packed_element_size), chunk_bytes, encoding_size)

return chunk_data, chunk_length
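
For reference, a quick check of what the new "!I{encoding_size}s" format implies for element sizes, using only the struct module (a sketch, not service code):

import struct

encoding_size = 128
packed = struct.Struct(f"!I{encoding_size}s")
# 4 bytes of unsigned-int entity ID precede the raw encoding bytes.
assert packed.size == encoding_size + 4  # 132

record = packed.pack(7, bytes(encoding_size))
entity_id, clk_bytes = packed.unpack(record)
assert (entity_id, len(clk_bytes)) == (7, encoding_size)
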
8 changes: 5 additions & 3 deletions backend/entityservice/tasks/comparing.py
@@ -61,9 +61,8 @@ def create_comparison_jobs(project_id, run_id, parent_span=None):
current_span.log_kv({"event": 'get-dataset-sizes'})

filters_object_filenames = tuple(
get_filter_metadata(conn, dp_id) for dp_id in dp_ids)
get_filter_metadata(conn, dp_id)[0] for dp_id in dp_ids)
current_span.log_kv({"event": 'get-metadata'})

log.debug("Chunking computation task")

chunk_infos = tuple(anonlink.concurrency.split_to_chunks(
@@ -125,10 +124,13 @@ def compute_filter_similarity(chunk_info, project_id, run_id, threshold, encodin
t0 = time.time()
log.debug("Fetching and deserializing chunk of filters for dataprovider 1")
chunk_dp1, chunk_dp1_size = get_chunk_from_object_store(chunk_info_dp1, encoding_size)

#TODO: use the entity ids!
entity_ids_dp1, chunk_dp1 = zip(*chunk_dp1)
t1 = time.time()
log.debug("Fetching and deserializing chunk of filters for dataprovider 2")
chunk_dp2, chunk_dp2_size = get_chunk_from_object_store(chunk_info_dp2, encoding_size)
# TODO: use the entity ids!
entity_ids_dp2, chunk_dp2 = zip(*chunk_dp2)
Comment on lines +127 to +133
Collaborator:
(what you said)

Collaborator (Author):
Yes, but this is another task. I want this PR to be single purpose.
Integrating the use of the entity IDs will most likely also require changes to the helper functions in the anonlink library. I didn't want to open that can of worms just yet.

t2 = time.time()
span.log_kv({'event': 'chunks are fetched and deserialized'})
log.debug("Calculating filter similarity")
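Since get_chunk_from_object_store now returns (entity_id, encoding) pairs, the zip(*...) calls above simply split each chunk into two parallel tuples, with the entity IDs unused for now as the TODOs note. A toy illustration:

chunk = [(0, b'\x01' * 4), (1, b'\x02' * 4), (2, b'\x03' * 4)]
entity_ids, encodings = zip(*chunk)
assert entity_ids == (0, 1, 2)
assert encodings == (b'\x01' * 4, b'\x02' * 4, b'\x03' * 4)
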
5 changes: 2 additions & 3 deletions backend/entityservice/tasks/encoding_uploading.py
@@ -41,14 +41,13 @@ def handle_raw_upload(project_id, dp_id, receipt_token, parent_span=None):
uploaded_encoding_size = len(first_hash_bytes)

def filter_generator():
log.debug("Deserializing json filters")
i = 0
yield first_hash_bytes
yield i, first_hash_bytes
for i, line in enumerate(text_stream, start=1):
hash_bytes = deserialize_bytes(line)
if len(hash_bytes) != uploaded_encoding_size:
raise ValueError("Encodings were not all the same size")
yield hash_bytes
yield i, hash_bytes

log.info(f"Processed {i + 1} hashes")

7 changes: 3 additions & 4 deletions backend/entityservice/tests/test_project_uploads.py
@@ -2,13 +2,12 @@
import os
import pytest

from entityservice.serialization import binary_pack_filters
from entityservice.tests.config import url
from entityservice.tests.util import (
create_project_upload_data, create_project_upload_fake_data,
generate_clks, generate_json_serialized_clks,
get_expected_number_parties, get_run_result, post_run,
upload_binary_data, upload_binary_data_from_file)
upload_binary_data, upload_binary_data_from_file, binary_pack_for_upload)


def test_project_single_party_data_uploaded(requests, valid_project_params):
@@ -73,13 +72,13 @@ def test_project_binary_data_upload_with_different_encoded_size(
**valid_project_params
}).json()

common = next(binary_pack_filters(generate_clks(1, encoding_size),
common = next(binary_pack_for_upload(generate_clks(1, encoding_size),
encoding_size))

data = []
for i in range(expected_number_parties):
generated_clks = generate_clks(499, encoding_size)
packed_clks = binary_pack_filters(generated_clks, encoding_size)
packed_clks = binary_pack_for_upload(generated_clks, encoding_size)
packed_joined = b''.join(packed_clks)
packed_with_common = (
packed_joined + common if i == 0 else common + packed_joined)
(file name not shown in the captured diff header)
@@ -3,9 +3,8 @@

import anonlink

from entityservice.serialization import binary_pack_filters
from entityservice.tests.util import (
create_project_upload_data, delete_project, get_run_result, post_run)
create_project_upload_data, delete_project, get_run_result, post_run, binary_pack_for_upload)

DATA_FILENAME = 'test-multiparty-results-correctness-data.pkl'
DATA_PATH = pathlib.Path(__file__).parent / 'testdata' / DATA_FILENAME
@@ -29,7 +28,7 @@ def test_groups_correctness(requests):
filter_size = len(filters[0][0])
assert all(len(filter_) == filter_size
for dataset in filters for filter_ in dataset)
packed_filters = [b''.join(binary_pack_filters(f, filter_size))
packed_filters = [b''.join(binary_pack_for_upload(f, filter_size))
for f in filters]
project_data, _ = create_project_upload_data(
requests, packed_filters, result_type='groups',
25 changes: 21 additions & 4 deletions backend/entityservice/tests/test_serialization.py
@@ -1,13 +1,13 @@
import io
import unittest
import random

import json
import random
import unittest
from array import array

import anonlink

from entityservice.serialization import deserialize_bytes, generate_scores
from entityservice.serialization import deserialize_bytes, generate_scores, binary_pack_filters, \
binary_unpack_filters, binary_unpack_one, binary_format
from entityservice.tests.util import serialize_bytes, generate_bytes


@@ -69,6 +69,23 @@ def test_sims_to_json_empty(self):
self.assertIn('similarity_scores', json_obj)
assert len(json_obj["similarity_scores"]) == 0

def test_binary_pack_filters(self):
encoding_size = 128
filters = [(random.randint(0, 2 ** 32 - 1), generate_bytes(encoding_size)) for _ in range(10)]
packed_filters = binary_pack_filters(filters, encoding_size)
bin_format = binary_format(encoding_size)
for filter, packed_filter in zip(filters, packed_filters):
assert len(packed_filter) == encoding_size + 4
unpacked = binary_unpack_one(packed_filter, bin_format)
assert filter == unpacked

def test_binary_unpack_filters(self):
encoding_size = 128
filters = [(random.randint(0, 2 ** 32 - 1), generate_bytes(encoding_size)) for _ in range(10)]
laundered_filters = binary_unpack_filters(binary_pack_filters(filters, encoding_size),
encoding_size=encoding_size)
assert filters == laundered_filters


if __name__ == "__main__":
unittest.main()
17 changes: 17 additions & 0 deletions backend/entityservice/tests/util.py
@@ -4,6 +4,7 @@
import math
import os
import random
import struct
import time
import tempfile
from contextlib import contextmanager
@@ -54,6 +55,10 @@ def generate_clks(count, size):
return res


def generate_clks_with_id(count, size):
return zip(range(count), generate_clks(count, size))


def generate_json_serialized_clks(count, size=128):
clks = generate_clks(count, size)
return [serialize_bytes(hash_bytes) for hash_bytes in clks]
@@ -436,3 +441,15 @@ def upload_binary_data_from_file(requests, file_path, project_id, token, count,

def get_expected_number_parties(project_params):
return project_params.get('number_parties', 2)


def binary_upload_format(encoding_size):
bit_packing_fmt = f"!{encoding_size}s"
bit_packing_struct = struct.Struct(bit_packing_fmt)
return bit_packing_struct


def binary_pack_for_upload(filters, encoding_size):
bit_packing_struct = binary_upload_format(encoding_size)
for hash_bytes in filters:
yield bit_packing_struct.pack(hash_bytes)
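
The helpers above deliberately keep the old client-facing upload format: clients still upload bare encodings ("!{size}s", no entity ID), while the service now stores "!I{size}s" internally, which is why the tests switched from binary_pack_filters to binary_pack_for_upload. A quick size comparison as a sketch, using only struct:

import struct

encoding_size = 128
upload_fmt = struct.Struct(f"!{encoding_size}s")     # what clients upload (binary_pack_for_upload)
internal_fmt = struct.Struct(f"!I{encoding_size}s")  # what the service stores (binary_pack_filters)

assert upload_fmt.size == 128
assert internal_fmt.size == 132  # 4 extra bytes per encoding for the entity ID
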
14 changes: 11 additions & 3 deletions backend/entityservice/views/project.py
@@ -10,7 +10,7 @@
from entityservice.tasks import handle_raw_upload, check_for_executable_runs, remove_project
from entityservice.tracing import serialize_span
from entityservice.utils import safe_fail_request, get_json, generate_code, get_stream, \
clks_uploaded_to_project, fmt_bytes
clks_uploaded_to_project, fmt_bytes, iterable_to_stream
from entityservice.database import DBConn
from entityservice.views.auth_checks import abort_if_project_doesnt_exist, abort_if_invalid_dataprovider_token, \
abort_if_invalid_results_token, get_authorization_token_type_or_abort
@@ -152,13 +152,20 @@ def project_binaryclks_post(project_id):
# https://github.com/zalando/connexion/issues/592
# stream = get_stream()
stream = BytesIO(request.data)
expected_bytes = binary_format(size).size * count
binary_formatter = binary_format(size)

def entity_id_injector(filter_stream):
for entity_id in range(count):
yield binary_formatter.pack(entity_id, filter_stream.read(size))

data_with_ids = b''.join(entity_id_injector(stream))
expected_bytes = size * count
log.debug(f"Stream size is {len(request.data)} B, and we expect {expected_bytes} B")
if len(request.data) != expected_bytes:
safe_fail_request(400,
"Uploaded data did not match the expected size. Check request headers are correct")
try:
receipt_token = upload_clk_data_binary(project_id, dp_id, stream, count, size)
receipt_token = upload_clk_data_binary(project_id, dp_id, BytesIO(data_with_ids), count, size)
except ValueError:
safe_fail_request(400,
"Uploaded data did not match the expected size. Check request headers are correct.")
@@ -174,6 +181,7 @@ def project_binaryclks_post(project_id):
db.set_dataprovider_upload_state(conn, dp_id, state='done')
return {'message': 'Updated', 'receipt_token': receipt_token}, 201


def project_clks_post(project_id):
"""
Update a project to provide encoded PII data.
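
A toy walk-through of the entity_id_injector above, assuming the request body is exactly count * size raw bytes (here three 4-byte encodings; values are made up):

import struct
from io import BytesIO

size, count = 4, 3
binary_formatter = struct.Struct(f"!I{size}s")
stream = BytesIO(b'\xaa' * size + b'\xbb' * size + b'\xcc' * size)

def entity_id_injector(filter_stream):
    # Prepend a sequential entity ID to each fixed-size encoding read from the stream.
    for entity_id in range(count):
        yield binary_formatter.pack(entity_id, filter_stream.read(size))

data_with_ids = b''.join(entity_id_injector(stream))
assert len(data_with_ids) == count * binary_formatter.size            # 3 * 8 bytes
assert binary_formatter.unpack(data_with_ids[:8]) == (0, b'\xaa' * 4)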