64 changes: 57 additions & 7 deletions backend/entityservice/database/insertions.py
@@ -1,3 +1,5 @@
from typing import List

import psycopg2
import psycopg2.extras

@@ -42,17 +44,65 @@ def insert_dataprovider(cur, auth_token, project_id):
return execute_returning_id(cur, sql_query, [project_id, auth_token])


def insert_encoding_metadata(db, clks_filename, dp_id, receipt_token, count):
def insert_blocking_metadata(db, dp_id, blocks):
"""
Insert new entries into the blocks table, one row per block.

:param blocks: A dict mapping block id to the number of encodings per block.
"""
logger.info("Adding blocking metadata to database")
sql_insertion_query = """
INSERT INTO blocks
(dp, block_name, count, state)
VALUES %s
"""

logger.info("Preparing SQL for bulk insert of blocks")
values = [(dp_id, block_id, blocks[block_id], 'pending') for block_id in blocks]

with db.cursor() as cur:
psycopg2.extras.execute_values(cur, sql_insertion_query, values)


def insert_encoding_metadata(db, clks_filename, dp_id, receipt_token, encoding_count, block_count):
logger.info("Adding metadata on encoded entities to database")
sql_insertion_query = """
INSERT INTO bloomingdata
(dp, token, file, count, state)
INSERT INTO uploads
(dp, token, file, count, block_count, state)
VALUES
(%s, %s, %s, %s, %s)
(%s, %s, %s, %s, %s, %s)
"""

with db.cursor() as cur:
cur.execute(sql_insertion_query, [dp_id, receipt_token, clks_filename, count, 'pending'])
cur.execute(sql_insertion_query, [dp_id, receipt_token, clks_filename, encoding_count, block_count, 'pending'])


def insert_encodings_into_blocks(db, dp_id: int, block_ids: List[List[str]], encoding_ids: List[int],
encodings: List[bytes], page_size: int = 4096):
"""
Bulk load blocking and encoding data into the database.
See https://hakibenita.com/fast-load-data-python-postgresql#copy-data-from-a-string-iterator-with-buffer-size

:param page_size:
Maximum number of rows to send in a single SQL statement/network transfer. A larger page size
requires more local memory, but can be faster because it needs fewer network round trips.

"""
encodings_insertion_query = "INSERT INTO encodings (dp, encoding_id, encoding) VALUES %s"
blocks_insertion_query = "INSERT INTO encodingblocks (dp, encoding_id, block_id) VALUES %s"
encoding_data = ((dp_id, eid, encoding) for eid, encoding in zip(encoding_ids, encodings))

def block_data_generator(encoding_ids, block_ids):
for eid, blocks_for_encoding in zip(encoding_ids, block_ids):
for block_id in blocks_for_encoding:
yield (dp_id, eid, block_id)

with db.cursor() as cur:
psycopg2.extras.execute_values(cur, encodings_insertion_query, encoding_data, page_size=page_size)
psycopg2.extras.execute_values(cur,
blocks_insertion_query,
block_data_generator(encoding_ids, block_ids),
page_size=page_size)


def set_dataprovider_upload_state(db, dp_id, state='error'):
@@ -130,7 +180,7 @@ def insert_permutation_mask(conn, project_id, run_id, mask_list):

def update_encoding_metadata(db, clks_filename, dp_id, state):
sql_query = """
UPDATE bloomingdata
UPDATE uploads
SET
state = %s,
file = %s
@@ -149,7 +199,7 @@ def update_encoding_metadata(db, clks_filename, dp_id, state):

def update_encoding_metadata_set_encoding_size(db, dp_id, encoding_size):
sql_query = """
UPDATE bloomingdata
UPDATE uploads
SET
encoding_size = %s
WHERE
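
Taken together, the new insertion helpers split an upload into three steps: record the upload metadata, record the per-block counts, then bulk-load the encodings together with their block membership. The following sketch shows how a caller might wire them up. It is only a sketch: the connection parameters, filename, receipt token and block names are made up, the import path is inferred from the file location in this PR, and in the service dp_id would come from insert_dataprovider.

import psycopg2
from entityservice.database.insertions import (
    insert_blocking_metadata, insert_encoding_metadata, insert_encodings_into_blocks)

# Hypothetical data for one data provider.
db = psycopg2.connect(dbname="entityservice")
dp_id = 1
blocks = {"block-A": 2, "block-B": 1}                # block name -> number of encodings
encoding_ids = [0, 1, 2]
encodings = [b"\x01" * 128, b"\x02" * 128, b"\x03" * 128]
block_ids = [["block-A"], ["block-A"], ["block-B"]]  # blocks each encoding belongs to

insert_encoding_metadata(db, "clks.bin", dp_id, "receipt-token",
                         encoding_count=len(encodings), block_count=len(blocks))
insert_blocking_metadata(db, dp_id, blocks)
insert_encodings_into_blocks(db, dp_id, block_ids, encoding_ids, encodings, page_size=4096)
db.commit()  # the helpers use cursors but leave the commit to the caller

Since execute_values sends the encoding rows page_size at a time, the whole upload goes through in a handful of statements rather than one INSERT per row.
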
84 changes: 54 additions & 30 deletions backend/entityservice/database/selections.py
@@ -11,11 +11,11 @@ def select_dataprovider_id(db, project_id, receipt_token):
Returns None if token is incorrect.
"""
sql_query = """
SELECT dp from dataproviders, bloomingdata
SELECT dp from dataproviders, uploads
WHERE
bloomingdata.dp = dataproviders.id AND
uploads.dp = dataproviders.id AND
dataproviders.project = %s AND
bloomingdata.token = %s
uploads.token = %s
"""
query_result = query_db(db, sql_query, [project_id, receipt_token], one=True)
logger.debug("Looking up data provider with auth. {}".format(query_result))
@@ -61,10 +61,10 @@ def check_run_exists(db, project_id, run_id):
def get_number_parties_uploaded(db, project_id):
sql_query = """
SELECT COUNT(*)
FROM dataproviders, bloomingdata
FROM dataproviders, uploads
WHERE
dataproviders.project = %s AND
bloomingdata.dp = dataproviders.id AND
uploads.dp = dataproviders.id AND
dataproviders.uploaded = 'done'
"""
query_result = query_db(db, sql_query, [project_id], one=True)
@@ -77,24 +77,24 @@ def get_encoding_error_count(db, project_id):
"""
sql_query = """
SELECT count(*)
FROM dataproviders, bloomingdata
FROM dataproviders, uploads
WHERE
dataproviders.project = %s AND
bloomingdata.dp = dataproviders.id AND
bloomingdata.state = 'error'
uploads.dp = dataproviders.id AND
uploads.state = 'error'
"""
return query_db(db, sql_query, [project_id], one=True)['count']


def get_number_parties_ready(db, resource_id):
sql_query = """
SELECT COUNT(*)
FROM dataproviders, bloomingdata
FROM dataproviders, uploads
WHERE
dataproviders.project = %s AND
bloomingdata.dp = dataproviders.id AND
uploads.dp = dataproviders.id AND
dataproviders.uploaded = 'done' AND
bloomingdata.state = 'ready'
uploads.state = 'ready'
"""
query_result = query_db(db, sql_query, [resource_id], one=True)
return query_result['count']
@@ -187,11 +187,12 @@ def get_run_result(db, resource_id):


def get_project_dataset_sizes(db, project_id):
"""Returns the number of encodings in a dataset."""
sql_query = """
SELECT bloomingdata.count
FROM dataproviders, bloomingdata
SELECT uploads.count
FROM dataproviders, uploads
WHERE
bloomingdata.dp=dataproviders.id AND
uploads.dp=dataproviders.id AND
dataproviders.project=%s
ORDER BY dataproviders.id
"""
@@ -203,9 +204,9 @@ def get_uploaded_encoding_sizes(db, project_id):
def get_uploaded_encoding_sizes(db, project_id):
sql_query = """
SELECT dp, encoding_size
FROM dataproviders, bloomingdata
FROM dataproviders, uploads
WHERE
bloomingdata.dp=dataproviders.id AND
uploads.dp=dataproviders.id AND
dataproviders.project=%s
ORDER BY dataproviders.id
"""
@@ -215,10 +216,10 @@ def get_smaller_dataset_size_for_project(db, project_id):

def get_smaller_dataset_size_for_project(db, project_id):
sql_query = """
SELECT MIN(bloomingdata.count) as smaller
FROM dataproviders, bloomingdata
SELECT MIN(uploads.count) as smaller
FROM dataproviders, uploads
WHERE
bloomingdata.dp=dataproviders.id AND
uploads.dp=dataproviders.id AND
dataproviders.project=%s
"""
query_result = query_db(db, sql_query, [project_id], one=True)
@@ -231,10 +232,10 @@ def get_total_comparisons_for_project(db, project_id):
"""
expected_datasets = get_project_column(db, project_id, 'parties')
sql_query = """
SELECT bloomingdata.count as rows
from dataproviders, bloomingdata
SELECT uploads.count as rows
from dataproviders, uploads
where
bloomingdata.dp=dataproviders.id AND
uploads.dp=dataproviders.id AND
dataproviders.project=%s
"""
query_results = query_db(db, sql_query, [project_id])
@@ -260,12 +261,12 @@ def get_dataprovider_id(db, update_token):
return query_db(db, sql_query, [update_token], one=True)['id']


def get_bloomingdata_columns(db, dp_id, columns):
def get_uploads_columns(db, dp_id, columns):
for column in columns:
assert column in {'ts', 'token', 'file', 'state', 'count', 'encoding_size'}
assert column in {'ts', 'token', 'file', 'state', 'block_count', 'count', 'encoding_size'}
sql_query = """
SELECT {}
FROM bloomingdata
FROM uploads
WHERE dp = %s
""".format(', '.join(columns))
result = query_db(db, sql_query, [dp_id], one=True)
@@ -274,19 +275,42 @@ def get_bloomingdata_columns(db, dp_id, columns):
return [result[column] for column in columns]


def get_encodingblock_ids(db, dp_id, block_name=None):
"""Yield all encoding ids in either a single block, or all blocks for a given data provider."""
sql_query = """
SELECT encoding_id
FROM encodingblocks
WHERE dp = %s
{}
""".format("AND block_id = %s" if block_name else "")
# Specifying a name for the cursor creates a server-side cursor, which prevents all of the
# records from being downloaded at once.
cur = db.cursor(f'encodingfetcher-{dp_id}')

args = (dp_id, block_name) if block_name else (dp_id,)

cur.execute(sql_query, args)
while True:
rows = cur.fetchmany(10_000)
if not rows:
break
for row in rows:
yield row[0]

Collaborator:
Wait a minute, does that mean we will hold on to the db connection until all the yielding is done?
Won't you call this like:

with DBConn() as db:
   for id in get_encodingblock_ids(...):
      ...

That doesn't seem like a good idea.

Collaborator (Author):
Why don't you think it would be good to keep the db connection while streaming through the blocks? Establishing a db connection is not free. The point of this sneaky Python cache is that we might not have the memory to store all (e.g. millions) of blocks if we used fetchall(), but if we fetched and yielded each row one at a time with fetchone(), the network overhead would be a killer.
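
For reference, the trade-off being debated comes down to holding one connection open while the generator streams ids from the named (server-side) cursor in fetchmany-sized batches. A minimal sketch of both options, assuming the DBConn context manager from the snippet above plus a placeholder dp_id, block name and process() function:

# Stream directly: the connection stays open until iteration finishes,
# but at most one fetchmany batch of ids is held in memory at a time.
with DBConn() as db:
    for encoding_id in get_encodingblock_ids(db, dp_id, block_name="block-A"):
        process(encoding_id)

# Or materialise one block's ids up front, releasing the connection sooner
# at the cost of holding all of that block's ids in memory.
with DBConn() as db:
    ids_for_block = list(get_encodingblock_ids(db, dp_id, block_name="block-A"))
for encoding_id in ids_for_block:
    process(encoding_id)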



def get_filter_metadata(db, dp_id):
"""
:return: The filename and the encoding size of the raw clks.
"""
filename, encoding_size = get_bloomingdata_columns(db, dp_id, ['file', 'encoding_size'])
filename, encoding_size = get_uploads_columns(db, dp_id, ['file', 'encoding_size'])
return filename.strip(), encoding_size


def get_number_of_hashes(db, dp_id):
def get_encoding_metadata(db, dp_id):
"""
:return: The count of the uploaded encodings.
:return: The number of encodings and number of blocks of the uploaded data.
"""
return get_bloomingdata_columns(db, dp_id, ['count'])[0]
return get_uploads_columns(db, dp_id, ['count', 'block_count'])


def get_project_schema_encoding_size(db, project_id):
@@ -370,7 +394,7 @@ def get_all_objects_for_project(db, project_id):

for dp in dps:
clk_file_ref = query_db(db, """
SELECT file FROM bloomingdata
SELECT file FROM uploads
WHERE dp = %s
""", [dp['id']], one=True)

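On the read side, the renamed helpers expose the new per-upload metadata. A small usage sketch, again assuming an open connection db and an existing data provider id dp_id (both placeholders here):

# Read back metadata for one data provider from the renamed uploads table.
filename, encoding_size = get_filter_metadata(db, dp_id)
encoding_count, block_count = get_encoding_metadata(db, dp_id)
state, = get_uploads_columns(db, dp_id, ['state'])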