-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathcomparing.py
More file actions
355 lines (292 loc) · 14.2 KB
/
comparing.py
File metadata and controls
355 lines (292 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import array
import heapq
import itertools
import operator
import time
import anonlink
import minio
import opentracing
import psycopg2
from celery import chord
from entityservice.async_worker import celery, logger
from entityservice.cache.encodings import remove_from_cache
from entityservice.cache.progress import save_current_progress
from entityservice.encoding_storage import get_encoding_chunk
from entityservice.errors import RunDeleted, InactiveRun
from entityservice.database import (
check_project_exists, check_run_exists, DBConn, get_dataprovider_ids,
get_project_column, get_project_dataset_sizes,
get_project_encoding_size, get_run, insert_similarity_score_file,
update_run_mark_failure, get_block_metadata)
from entityservice.models.run import progress_run_stage as progress_stage
from entityservice.object_store import connect_to_object_store
from entityservice.settings import Config
from entityservice.tasks.base_task import TracedTask, celery_bug_fix, on_chord_error
from entityservice.tasks.solver import solver_task
from entityservice.tasks import mark_run_complete
from entityservice.tasks.assert_valid_run import assert_valid_run
from entityservice.utils import generate_code, iterable_to_stream
def check_run_active(conn, project_id, run_id):
    """Ensure the project and the run are both still present in the database.

    The run lookup is only performed when the project lookup succeeds.

    :param conn: an open database connection.
    :param project_id: the project to look up.
    :param run_id: the run (within ``project_id``) to look up.
    :raises InactiveRun: if either the project or the run is missing.
    """
    project_found = check_project_exists(conn, project_id)
    if project_found and check_run_exists(conn, project_id, run_id):
        return
    raise InactiveRun("Skipping as project or run not found in database.")
@celery.task(base=TracedTask, ignore_result=True, args_as_tags=('project_id', 'run_id'))
def create_comparison_jobs(project_id, run_id, parent_span=None):
    """Schedule all the entity comparisons as sub tasks for a run.

    At a high level this task:

    - checks if the project and run have been deleted and if so aborts.
    - retrieves metadata: the number and size of the datasets, the encoding size,
      and the number and size of blocks.
    - splits the work into independent "chunks" and schedules them to run in celery
    - schedules the follow up task to run after all the comparisons have been computed.

    :param project_id: the project whose data providers are compared.
    :param run_id: the run the comparisons belong to.
    :param parent_span: A serialized opentracing span context. Not read in
        this body; presumably consumed by ``TracedTask`` (verify).
    :raises InactiveRun: if the project or run no longer exists.
    :raises AssertionError: if fewer than two data providers are registered.
    """
    log = logger.bind(pid=project_id, run_id=run_id)
    current_span = create_comparison_jobs.span

    with DBConn() as conn:
        check_run_active(conn, project_id, run_id)

        dp_ids = get_dataprovider_ids(conn, project_id)
        number_of_datasets = len(dp_ids)
        assert number_of_datasets >= 2, "Expected at least 2 data providers"
        log.info(f"Scheduling comparison of CLKs from data provider ids: "
                 f"{', '.join(map(str, dp_ids))}")

        # Retrieve required metadata
        dataset_sizes = get_project_dataset_sizes(conn, project_id)

        # {dp_id -> {block_id -> block_size}}
        # e.g. {33: {'1': 100}, 34: {'1': 100}, 35: {'1': 100}}
        dp_block_sizes = {dp_id: dict(get_block_metadata(conn, dp_id)) for dp_id in dp_ids}

        log.info("Finding blocks in common between dataproviders")
        # block_id -> List(pairs of dp ids)
        # e.g. {'1': [(26, 27), (26, 28), (27, 28)]}
        # NOTE(review): `blocks` is not consumed by the chunking below, which
        # still splits on whole dataset sizes.
        blocks = {}
        for dp1, dp2 in itertools.combinations(dp_ids, 2):
            # Get the intersection of blocks between these two dataproviders
            common_block_ids = set(dp_block_sizes[dp1]).intersection(set(dp_block_sizes[dp2]))
            for block_id in common_block_ids:
                blocks.setdefault(block_id, []).append((dp1, dp2))

        if len(dataset_sizes) < 2:
            log.warning("Unexpected number of dataset sizes in db. Stopping")
            update_run_mark_failure(conn, run_id)
            return

        encoding_size = get_project_encoding_size(conn, project_id)

        log.info(f"Computing similarity for "
                 f"{' x '.join(map(str, dataset_sizes))} entities")
        current_span.log_kv({"event": 'get-dataset-sizes', 'sizes': dataset_sizes})

        # Pass the threshold to the comparison tasks to minimize their db lookups
        threshold = get_run(conn, run_id)['threshold']

    log.debug("Chunking computation task")
    chunk_infos = tuple(anonlink.concurrency.split_to_chunks(
        Config.CHUNK_SIZE_AIM,
        dataset_sizes=dataset_sizes))

    # Add the db ids to the chunk information.
    for chunk_info in chunk_infos:
        for chunk_dp_info in chunk_info:
            chunk_dp_index = chunk_dp_info['datasetIndex']
            chunk_dp_info['dataproviderId'] = dp_ids[chunk_dp_index]

    log.info(f"Chunking into {len(chunk_infos)} computation tasks")
    current_span.log_kv({"event": "chunking", 'num_chunks': len(chunk_infos)})
    span_serialized = create_comparison_jobs.get_serialized_span()

    # Prepare the Celery Chord that will compute all the similarity scores:
    scoring_tasks = [compute_filter_similarity.si(
        chunk_info,
        project_id,
        run_id,
        threshold,
        encoding_size,
        span_serialized
    ) for chunk_info in chunk_infos]

    # A chord with a single header task needs a companion task
    # (see ``celery_bug_fix`` in ``entityservice.tasks.base_task``).
    if len(scoring_tasks) == 1:
        scoring_tasks.append(celery_bug_fix.si())

    callback_task = aggregate_comparisons.s(project_id, run_id, parent_span=span_serialized).on_error(
        on_chord_error.s(run_id=run_id))
    # The chord's result handle is not needed: this task ignores its result
    # and the callback carries the work forward.
    chord(scoring_tasks)(callback_task)
@celery.task(base=TracedTask, args_as_tags=('project_id', 'run_id', 'threshold'))
def compute_filter_similarity(chunk_info, project_id, run_id, threshold, encoding_size, parent_span=None):
    """Compute filter similarity between a chunk of filters in dataprovider 1,
    and a chunk of filters in dataprovider 2.

    :param chunk_info:
        A chunk returned by ``anonlink.concurrency.split_to_chunks``: a pair
        of per-dataset chunk descriptions (dicts that ``create_comparison_jobs``
        annotates with ``'dataproviderId'``).
    :param project_id: the project the run belongs to.
    :param run_id: the run being computed; also used for progress reporting.
    :param threshold: similarity threshold passed to
        ``anonlink.concurrency.process_chunk``.
    :param encoding_size: The size in bytes of each encoded entry
    :param parent_span: A serialized opentracing span context.
    :returns: A 3-tuple ``(num_results, result size in bytes,
        results_filename_in_object_store)``. If no candidate pair is found
        the tuple is ``(0, None, None)`` and nothing is uploaded. Returns
        ``None`` when anonlink raises ``NotImplementedError`` for these
        encodings; ``aggregate_comparisons`` rejects such a result.
    :raises minio.ResponseError: if storing the results in the object store fails.
    """
    log = logger.bind(pid=project_id, run_id=run_id)
    task_span = compute_filter_similarity.span

    def new_child_span(name):
        return compute_filter_similarity.tracer.start_active_span(
            name,
            child_of=compute_filter_similarity.span)

    log.debug("Computing similarity for a chunk of filters")
    log.debug("Checking that the resource exists (in case of run being canceled/deleted)")
    assert_valid_run(project_id, run_id, log)

    chunk_info_dp1, chunk_info_dp2 = chunk_info

    # The DB connection is only needed while fetching the encodings.
    with DBConn() as conn:
        with new_child_span('fetching-encodings'):
            log.debug("Fetching and deserializing chunk of filters for dataprovider 1")
            chunk_with_ids_dp1, chunk_dp1_size = get_encoding_chunk(conn, chunk_info_dp1, encoding_size)
            # TODO: use the entity ids!
            entity_ids_dp1, chunk_dp1 = zip(*chunk_with_ids_dp1)

            log.debug("Fetching and deserializing chunk of filters for dataprovider 2")
            chunk_with_ids_dp2, chunk_dp2_size = get_encoding_chunk(conn, chunk_info_dp2, encoding_size)
            # TODO: use the entity ids!
            entity_ids_dp2, chunk_dp2 = zip(*chunk_with_ids_dp2)

    log.debug('Both chunks are fetched and deserialized')
    task_span.log_kv({'size1': chunk_dp1_size, 'size2': chunk_dp2_size})

    with new_child_span('comparing-encodings'):
        log.debug("Calculating filter similarity")
        try:
            chunk_results = anonlink.concurrency.process_chunk(
                chunk_info,
                (chunk_dp1, chunk_dp2),
                anonlink.similarities.dice_coefficient_accelerated,
                threshold,
                k=min(chunk_dp1_size, chunk_dp2_size))
        except NotImplementedError:
            log.warning("Encodings couldn't be compared using anonlink.")
            # A None result makes aggregate_comparisons abort the run.
            return None
        log.debug('Encoding similarities calculated')

    with new_child_span('update-comparison-progress'):
        # Update the number of comparisons completed
        comparisons_computed = chunk_dp1_size * chunk_dp2_size
        save_current_progress(comparisons_computed, run_id)

    with new_child_span('save-comparison-results-to-minio'):
        # chunk_results is a (similarities, dataset indices, record indices) triple;
        # only the similarity array is needed to count the candidate pairs.
        sims, _, _ = chunk_results
        num_results = len(sims)

        if num_results:
            result_filename = Config.SIMILARITY_SCORES_FILENAME_FMT.format(
                generate_code(12))
            task_span.log_kv({"edges": num_results})
            log.info("Writing {} intermediate results to file: {}".format(num_results, result_filename))

            bytes_iter, file_size \
                = anonlink.serialization.dump_candidate_pairs_iter(chunk_results)
            iter_stream = iterable_to_stream(bytes_iter)

            mc = connect_to_object_store()
            try:
                mc.put_object(
                    Config.MINIO_BUCKET, result_filename, iter_stream, file_size)
            except minio.ResponseError:
                log.warning("Failed to store result in minio")
                raise
        else:
            # Nothing above threshold: no file is written for this chunk.
            result_filename = None
            file_size = None

    # Report the number of candidate pairs. len(chunk_results) would always be 3,
    # the length of the triple, not the number of links.
    log.info("Comparisons: {}, Links above threshold: {}".format(comparisons_computed, num_results))
    return num_results, file_size, result_filename
def _put_placeholder_empty_file(mc, log):
    """Upload a similarity-scores file that holds zero candidate pairs.

    ``aggregate_comparisons`` uses this when no chunk produced a results file.
    The file is serialized with ``dump_candidate_pairs_iter``, the same routine
    used for real chunk results.

    :param mc: object store client (from ``connect_to_object_store``).
    :param log: bound logger used to report an upload failure.
    :returns: ``(0, file_size, file_name)``, the same triple shape the chunk
        tasks return.
    :raises minio.ResponseError: re-raised if the upload fails.
    """
    # One empty double array followed by two pairs of empty unsigned-int arrays,
    # matching the (sims, (dset_is0, dset_is1), (rec_is0, rec_is1)) layout.
    no_pairs = (
        array.array('d'),
        (array.array('I'), array.array('I')),
        (array.array('I'), array.array('I')),
    )
    serialized_iter, serialized_size = \
        anonlink.serialization.dump_candidate_pairs_iter(no_pairs)

    object_name = Config.SIMILARITY_SCORES_FILENAME_FMT.format(generate_code(12))
    upload_stream = iterable_to_stream(serialized_iter)

    try:
        mc.put_object(Config.MINIO_BUCKET, object_name, upload_stream, serialized_size)
    except minio.ResponseError:
        log.warning("Failed to store empty result in minio.")
        raise

    return 0, serialized_size, object_name
def _close_object_response(response):
    """Close an object-store ``get_object`` response and release its connection."""
    response.close()
    # urllib3 responses (what minio's get_object returns) hand the pooled
    # connection back only on release_conn(); guard in case a client returns
    # a plain stream instead.
    release_conn = getattr(response, 'release_conn', None)
    if release_conn is not None:
        release_conn()


def _merge_files(mc, log, file0, file1):
    """Merge two stored similarity-scores files into a new object.

    Both inputs are streamed from the object store, merged with
    ``anonlink.serialization.merge_streams_iter`` and uploaded under a
    freshly generated name. The two input objects are then deleted;
    deletion failures are only logged.

    :param mc: object store client (from ``connect_to_object_store``).
    :param log: bound logger.
    :param file0: ``(num_results, file_size, file_name)`` of the first input.
    :param file1: ``(num_results, file_size, file_name)`` of the second input.
    :returns: ``(num0 + num1, merged_file_size, merged_file_name)``.
    :raises minio.ResponseError: if uploading the merged file fails.
    """
    num0, filesize0, filename0 = file0
    num1, filesize1, filename1 = file1
    total_num = num0 + num1

    # Track every opened response so each is closed even if a later
    # get_object, the merge, or the upload fails.
    responses = []
    try:
        for filename in (filename0, filename1):
            responses.append(mc.get_object(Config.MINIO_BUCKET, filename))
        file0_stream, file1_stream = responses

        merged_file_iter, merged_file_size \
            = anonlink.serialization.merge_streams_iter(
                (file0_stream, file1_stream), sizes=(filesize0, filesize1))
        merged_file_name = Config.SIMILARITY_SCORES_FILENAME_FMT.format(
            generate_code(12))
        merged_file_stream = iterable_to_stream(merged_file_iter)
        try:
            # The input streams are read lazily while this upload consumes
            # the merged stream, so they must stay open until it returns.
            mc.put_object(Config.MINIO_BUCKET, merged_file_name,
                          merged_file_stream, merged_file_size)
        except minio.ResponseError:
            log.warning("Failed to store merged result in minio.")
            raise
    finally:
        for response in responses:
            _close_object_response(response)

    # remove_objects yields one error per object it failed to delete.
    for del_err in mc.remove_objects(
            Config.MINIO_BUCKET, (filename0, filename1)):
        log.warning(f"Failed to delete result file "
                    f"{del_err.object_name}. {del_err}")

    return total_num, merged_file_size, merged_file_name
def _insert_similarity_into_db(db, log, run_id, merged_filename):
    """Record the merged similarity-scores file name against ``run_id``.

    :param db: an open database connection (as yielded by ``DBConn``).
    :param log: bound logger.
    :param run_id: the run the scores belong to.
    :param merged_filename: object-store name of the merged scores file.
    :raises RunDeleted: if the insert raises ``psycopg2.IntegrityError``,
        which is treated as the project having been deleted meanwhile.
        The original error is kept as ``__cause__``.
    """
    try:
        result_id = insert_similarity_score_file(
            db, run_id, merged_filename)
    except psycopg2.IntegrityError as err:
        log.info("Error saving similarity score filename to database. "
                 "The project may have been deleted.")
        raise RunDeleted(run_id) from err
    log.debug(f"Saved path to similarity scores file to db with id "
              f"{result_id}")
@celery.task(
    base=TracedTask,
    ignore_result=True,
    autoretry_for=(minio.ResponseError,),
    retry_backoff=True,
    args_as_tags=('project_id', 'run_id'))
def aggregate_comparisons(similarity_result_files, project_id, run_id, parent_span=None):
    """Merge the per-chunk similarity files into one and start the next stage.

    This is the chord callback scheduled by ``create_comparison_jobs``.

    The task proceeds in four steps:

    1. Collect the non-empty chunk results.
    2. Repeatedly merge the two smallest of them until one file remains,
       uploading an empty placeholder if there were none.
    3. Record the merged file in the database.
    4. Either mark the run complete (``result_type == "similarity_scores"``)
       or promote the run and hand the file to ``solver_task``.

    :param similarity_result_files: the chord header results; each entry is
        a ``(num_results, file_size, file_name)`` triple.
    :param project_id: the project being processed.
    :param run_id: the run being processed.
    :param parent_span: A serialized opentracing span context. Not read in
        this body; presumably consumed by ``TracedTask`` (verify).
    :raises TypeError: if the results list, or any entry in it, is ``None``.
    """
    log = logger.bind(pid=project_id, run_id=run_id)
    if similarity_result_files is None:
        raise TypeError("Inappropriate argument type - missing results files.")

    pending = []
    for result in similarity_result_files:
        if result is None:
            log.warning("Missing results during aggregation. Stopping processing.")
            raise TypeError("Inappropriate argument type - results missing at aggregation step.")
        count, size, name = result
        if not count:
            # Chunks without candidate pairs did not upload anything.
            assert size is None
            assert name is None
            continue
        assert size is not None
        assert name is not None
        pending.append((count, size, name))

    # Min-heap keyed on the result count first, so each merge below combines
    # the two files with the fewest results.
    heapq.heapify(pending)

    total_bytes = sum(size for _, size, _ in pending)
    log.debug(f"Aggregating result chunks from {len(pending)} files, "
              f"total size: {total_bytes}")

    mc = connect_to_object_store()
    while len(pending) > 1:
        smallest = heapq.heappop(pending)
        second_smallest = heapq.heappop(pending)
        heapq.heappush(pending, _merge_files(mc, log, smallest, second_smallest))

    if not pending:
        # No results. Let's chuck in an empty file.
        pending.append(_put_placeholder_empty_file(mc, log))

    [(_, merged_filesize, merged_filename)] = pending
    log.info(f"Similarity score results in {merged_filename} in bucket "
             f"{Config.MINIO_BUCKET} take up {merged_filesize} bytes.")

    with DBConn() as db:
        result_type = get_project_column(db, project_id, 'result_type')
        _insert_similarity_into_db(db, log, run_id, merged_filename)
        scores_only = result_type == "similarity_scores"
        if scores_only:
            # Post similarity computation cleanup
            dp_ids = get_dataprovider_ids(db, project_id)
        else:
            # we promote the run to the next stage
            progress_stage(db, run_id)
            dataset_sizes = get_project_dataset_sizes(db, project_id)

    # DB now committed, we can fire off tasks that depend on the new db state
    if not scores_only:
        solver_task.delay(
            merged_filename, project_id, run_id, dataset_sizes,
            aggregate_comparisons.get_serialized_span())
        return

    log.debug("Removing clk filters from redis cache")
    for dp_id in dp_ids:
        remove_from_cache(dp_id)
    # Complete the run
    log.info("Marking run as complete")
    mark_run_complete.delay(run_id, aggregate_comparisons.get_serialized_span())