|
2 | 2 | import logging |
3 | 3 | import os |
4 | 4 | import re |
| 5 | +import zipfile |
5 | 6 | from datetime import timedelta |
6 | 7 | from operator import itemgetter |
7 | 8 | from random import randrange |
@@ -553,6 +554,83 @@ def get(self, request, *args, **kwargs): |
553 | 554 | user_submit_ip_logger = logging.getLogger('judge.user_submit_ip_logger') |
554 | 555 |
|
555 | 556 |
|
| 557 | +def validate_zip_content(zip_file, required_files, forbidden_files=None): |
| 558 | + try: |
| 559 | + with zipfile.ZipFile(zip_file) as z: |
| 560 | + files = set(z.namelist()) |
| 561 | + missing = [f for f in required_files if f not in files] |
| 562 | + if missing: |
| 563 | + return _('Missing files in zip: %s') % ', '.join(missing) |
| 564 | + |
| 565 | + if forbidden_files: |
| 566 | + present_forbidden = [f for f in forbidden_files if f in files] |
| 567 | + if present_forbidden: |
| 568 | + return _('Forbidden files in zip: %s') % ', '.join(present_forbidden) |
| 569 | + except zipfile.BadZipFile: |
| 570 | + return _('Invalid zip file') |
| 571 | + return None |
| 572 | + |
| 573 | + |
| 574 | +def validate_task1_public_test(zip_file): |
| 575 | + return validate_zip_content(zip_file, |
| 576 | + required_files=['public_captions_prediction_result.txt', 'task1_cap_training.py'], |
| 577 | + forbidden_files=['task1_cap_checkpoint.pt']) |
| 578 | + |
| 579 | + |
| 580 | +def validate_task1_final_public(zip_file): |
| 581 | + return validate_zip_content(zip_file, |
| 582 | + required_files=['public_captions_prediction_result.txt', 'task1_cap_inference.py', |
| 583 | + 'task1_cap_checkpoint.pt']) |
| 584 | + |
| 585 | + |
| 586 | +def validate_task1_private_test(zip_file): |
| 587 | + return validate_zip_content(zip_file, |
| 588 | + required_files=['private_captions_prediction_result.txt', 'task1_cap_training.py'], |
| 589 | + forbidden_files=['task1_cap_checkpoint.pt']) |
| 590 | + |
| 591 | + |
| 592 | +def validate_task1_final_private(zip_file): |
| 593 | + return validate_zip_content(zip_file, |
| 594 | + required_files=['private_captions_prediction_result.txt', 'task1_cap_inference.py', |
| 595 | + 'task1_cap_checkpoint.pt']) |
| 596 | + |
| 597 | + |
| 598 | +def validate_task2_public_test(zip_file): |
| 599 | + return validate_zip_content(zip_file, |
| 600 | + required_files=['public_activity.txt', 'attack_output.zip', 'task2_training.py'], |
| 601 | + forbidden_files=['model.pth']) |
| 602 | + |
| 603 | + |
| 604 | +def validate_task2_final_public(zip_file): |
| 605 | + return validate_zip_content(zip_file, |
| 606 | + required_files=['public_activity.txt', 'attack_output.zip', 'model.py', |
| 607 | + 'task2_inference.py', 'model.pth']) |
| 608 | + |
| 609 | + |
| 610 | +def validate_task2_private_test(zip_file): |
| 611 | + return validate_zip_content(zip_file, |
| 612 | + required_files=['private_activity.txt', 'attack_output.zip', 'task2_training.py'], |
| 613 | + forbidden_files=['model.pth']) |
| 614 | + |
| 615 | + |
| 616 | +def validate_task2_final_private(zip_file): |
| 617 | + return validate_zip_content(zip_file, |
| 618 | + required_files=['private_activity.txt', 'attack_output.zip', 'model.py', |
| 619 | + 'task2_inference.py', 'model.pth']) |
| 620 | + |
| 621 | + |
| 622 | +CLIENT_CHECKERS = { |
| 623 | + 'final_nlp_public': validate_task1_public_test, |
| 624 | + 'final_nlp_public_final': validate_task1_final_public, |
| 625 | + 'final_nlp_private': validate_task1_private_test, |
| 626 | + 'final_nlp_private_final': validate_task1_final_private, |
| 627 | + 'final_cv_public': validate_task2_public_test, |
| 628 | + 'final_cv_public_final': validate_task2_final_public, |
| 629 | + 'final_cv_private': validate_task2_private_test, |
| 630 | + 'final_cv_private_final': validate_task2_final_private, |
| 631 | +} |
| 632 | + |
| 633 | + |
556 | 634 | class ProblemSubmit(LoginRequiredMixin, ProblemMixin, TitleMixin, SingleObjectFormView): |
557 | 635 | template_name = 'problem/submit.html' |
558 | 636 | form_class = ProblemSubmitForm |
@@ -648,6 +726,20 @@ def form_valid(self, form): |
648 | 726 | return generic_message(self.request, _('Banned from submitting'), |
649 | 727 | _('You have been declared persona non grata for this problem. ' |
650 | 728 | 'You are permanently barred from submitting to this problem.')) |
| 729 | + |
| 730 | + # Check for client checkers |
| 731 | + submission_file = form.files.get('submission_file', None) |
| 732 | + if self.object.code in CLIENT_CHECKERS: |
| 733 | + checker = CLIENT_CHECKERS[self.object.code] |
| 734 | + if not submission_file: |
| 735 | + form.add_error('submission_file', _('This problem requires a file submission.')) |
| 736 | + return self.form_invalid(form) |
| 737 | + |
| 738 | + error_msg = checker(submission_file) |
| 739 | + if error_msg: |
| 740 | + form.add_error('submission_file', error_msg) |
| 741 | + return self.form_invalid(form) |
| 742 | + |
651 | 743 | # Must check for zero and not None. None means infinite submissions remaining. |
652 | 744 | if self.remaining_submission_count == 0: |
653 | 745 | return generic_message(self.request, _('Too many submissions'), |
|
0 commit comments