Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
prohibit direct inserts outside populate for auto-populated tables
  • Loading branch information
dimitri-yatsenko committed Oct 30, 2018
commit a162c0a418e13fcecaab4edd31ce8606672894ca
3 changes: 3 additions & 0 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AutoPopulate:
must define the property `key_source`, and must define the callback method `make`.
"""
_key_source = None
_allow_insert = False

@property
def key_source(self):
Expand Down Expand Up @@ -149,7 +150,9 @@ def handler(signum, frame):
logger.info('Populating: ' + str(key))
call_count += 1
try:
self._allow_insert = True
make(dict(key))
self._allow_insert = False
except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
Expand Down
14 changes: 13 additions & 1 deletion datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def insert1(self, row, **kwargs):
"""
self.insert((row,), **kwargs)

def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, ignore_errors=False):
def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, ignore_errors=False,
allow_direct_insert=None):
"""
Insert a collection of rows.

Expand All @@ -161,6 +162,7 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
:param replace: If True, replaces the existing tuple.
:param skip_duplicates: If True, silently skip duplicate inserts.
:param ignore_extra_fields: If False, fields that are not in the heading raise error.
:param allow_direct_insert: applies only in auto-populated tables. Set True to insert outside populate calls.

Example::
>>> relation.insert([
Expand All @@ -172,6 +174,16 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
warnings.warn('Use of `ignore_errors` in `insert` and `insert1` is deprecated. Use try...except... '
'to explicitly handle any errors', stacklevel=2)

# prohibit direct inserts into auto-populated tables
try:
allow = allow_direct_insert or self._allow_insert # only present in AutoPopulate
except AttributeError:
pass # for non-AutoPopulate tables
else:
if not allow:
raise DataJointError(
'Auto-populate tables can only be inserted into from their make methods during populate calls.')

heading = self.heading
if inspect.isclass(rows) and issubclass(rows, Query): # instantiate if a class
rows = rows()
Expand Down
2 changes: 1 addition & 1 deletion tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class DecimalPrimaryKey(dj.Lookup):
@schema
class IndexRich(dj.Manual):
definition = """
-> Experiment
-> Subject
---
-> [unique, nullable] User.proj(first="username")
first_date : date
Expand Down
25 changes: 20 additions & 5 deletions tests/test_autopopulate.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
from nose.tools import assert_raises, assert_equal, \
assert_false, assert_true, assert_list_equal, \
assert_tuple_equal, assert_dict_equal, raises

from nose.tools import assert_equal, assert_false, assert_true, raises
from . import schema
from datajoint import DataJointError


class TestPopulate:
"""
Test base relations: insert, delete
"""

def __init__(self):
def setUp(self):
self.user = schema.User()
self.subject = schema.Subject()
self.experiment = schema.Experiment()
self.trial = schema.Trial()
self.ephys = schema.Ephys()
self.channel = schema.Ephys.Channel()

def tearDown(self):
# delete automatic tables just in case
self.channel.delete_quick()
self.ephys.delete_quick()
self.trial.Condition.delete_quick()
self.trial.delete_quick()
self.experiment.delete_quick()

Expand Down Expand Up @@ -49,3 +49,18 @@ def test_populate(self):
self.ephys.populate()
assert_true(self.ephys)
assert_true(self.channel)

def test_allow_direct_insert(self):
assert_true(self.subject, 'root tables are empty')
key = self.subject.fetch('KEY')[0]
key['experiment_id'] = 1000
key['experiment_date'] = '2018-10-30'
self.experiment.insert1(key, allow_direct_insert=True)

@raises(DataJointError)
def test_allow_insert(self):
assert_true(self.subject, 'root tables are empty')
key = self.subject.fetch('KEY')[0]
key['experiment_id'] = 1001
key['experiment_date'] = '2018-10-30'
self.experiment.insert1(key)
4 changes: 2 additions & 2 deletions tests/test_declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def test_dependencies():
assert_true(experiment.full_table_name in set(user.children(primary=False)))
assert_equal(set(experiment.parents(primary=False)), {user.full_table_name})

assert_equal(set(subject.children(primary=True)), {experiment.full_table_name})
assert_equal(set(experiment.parents(primary=True)), {subject.full_table_name})
assert_true(experiment.full_table_name in subject.descendants())
assert_true(subject.full_table_name in experiment.ancestors())

assert_true(trial.full_table_name in experiment.descendants())
assert_true(experiment.full_table_name in trial.ancestors())
Expand Down