diff --git a/datajoint/__init__.py b/datajoint/__init__.py index dffe9af36..eab54e89c 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -46,6 +46,8 @@ class key: config.add_history('No config file found, using default settings.') else: config.load(config_file) + del config_file + del config_files # override login credentials with environment variables @@ -58,6 +60,7 @@ class key: for k in mapping: config.add_history('Updated login credentials from %s' % k) config.update(mapping) +del mapping logger.setLevel(log_levels[config['loglevel']]) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 0cc2162fe..55c2d4437 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -22,6 +22,7 @@ class AutoPopulate: must define the property `key_source`, and must define the callback method `make`. """ _key_source = None + _allow_insert = False @property def key_source(self): @@ -148,6 +149,7 @@ def handler(signum, frame): else: logger.info('Populating: ' + str(key)) call_count += 1 + self._allow_insert = True try: make(dict(key)) except (KeyboardInterrupt, SystemExit, Exception) as error: @@ -172,6 +174,8 @@ def handler(signum, frame): self.connection.commit_transaction() if reserve_jobs: jobs.complete(self.target.table_name, self._job_key(key)) + finally: + self._allow_insert = False # place back the original signal handler if reserve_jobs: diff --git a/datajoint/table.py b/datajoint/table.py index fb17f600d..188b52e61 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -152,7 +152,8 @@ def insert1(self, row, **kwargs): """ self.insert((row,), **kwargs) - def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, ignore_errors=False): + def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, ignore_errors=False, + allow_direct_insert=None): """ Insert a collection of rows. @@ -161,6 +162,7 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields :param replace: If True, replaces the existing tuple. :param skip_duplicates: If True, silently skip duplicate inserts. :param ignore_extra_fields: If False, fields that are not in the heading raise error. + :param allow_direct_insert: applies only in auto-populated tables. Set True to insert outside populate calls. Example:: >>> relation.insert([ @@ -172,6 +174,11 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields warnings.warn('Use of `ignore_errors` in `insert` and `insert1` is deprecated. Use try...except... ' 'to explicitly handle any errors', stacklevel=2) + # prohibit direct inserts into auto-populated tables + if not (allow_direct_insert or getattr(self, '_allow_insert', True)): # _allow_insert is only present in AutoPopulate + raise DataJointError( + 'Auto-populate tables can only be inserted into from their make methods during populate calls.') + heading = self.heading if inspect.isclass(rows) and issubclass(rows, Query): # instantiate if a class rows = rows() diff --git a/tests/schema.py b/tests/schema.py index 255e32ee8..7610a566a 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -275,7 +275,7 @@ class DecimalPrimaryKey(dj.Lookup): @schema class IndexRich(dj.Manual): definition = """ - -> Experiment + -> Subject --- -> [unique, nullable] User.proj(first="username") first_date : date diff --git a/tests/test_autopopulate.py b/tests/test_autopopulate.py index 83da52e24..d3a382b8c 100644 --- a/tests/test_autopopulate.py +++ b/tests/test_autopopulate.py @@ -1,8 +1,6 @@ -from nose.tools import assert_raises, assert_equal, \ - assert_false, assert_true, assert_list_equal, \ - assert_tuple_equal, assert_dict_equal, raises - +from nose.tools import assert_equal, assert_false, assert_true, raises from . import schema +from datajoint import DataJointError class TestPopulate: @@ -10,7 +8,7 @@ class TestPopulate: Test base relations: insert, delete """ - def __init__(self): + def setUp(self): self.user = schema.User() self.subject = schema.Subject() self.experiment = schema.Experiment() @@ -18,9 +16,11 @@ def __init__(self): self.ephys = schema.Ephys() self.channel = schema.Ephys.Channel() + def tearDown(self): # delete automatic tables just in case self.channel.delete_quick() self.ephys.delete_quick() + self.trial.Condition.delete_quick() self.trial.delete_quick() self.experiment.delete_quick() @@ -49,3 +49,18 @@ def test_populate(self): self.ephys.populate() assert_true(self.ephys) assert_true(self.channel) + + def test_allow_direct_insert(self): + assert_true(self.subject, 'root tables are empty') + key = self.subject.fetch('KEY')[0] + key['experiment_id'] = 1000 + key['experiment_date'] = '2018-10-30' + self.experiment.insert1(key, allow_direct_insert=True) + + @raises(DataJointError) + def test_allow_insert(self): + assert_true(self.subject, 'root tables are empty') + key = self.subject.fetch('KEY')[0] + key['experiment_id'] = 1001 + key['experiment_date'] = '2018-10-30' + self.experiment.insert1(key) diff --git a/tests/test_declare.py b/tests/test_declare.py index d6d87619e..d521dc688 100644 --- a/tests/test_declare.py +++ b/tests/test_declare.py @@ -105,8 +105,8 @@ def test_dependencies(): assert_true(experiment.full_table_name in set(user.children(primary=False))) assert_equal(set(experiment.parents(primary=False)), {user.full_table_name}) - assert_equal(set(subject.children(primary=True)), {experiment.full_table_name}) - assert_equal(set(experiment.parents(primary=True)), {subject.full_table_name}) + assert_true(experiment.full_table_name in subject.descendants()) + assert_true(subject.full_table_name in experiment.ancestors()) assert_true(trial.full_table_name in experiment.descendants()) assert_true(experiment.full_table_name in trial.ancestors())