diff --git a/.gitignore b/.gitignore index fc01e66..f7e05bb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ *.egg-info dist/ build/ +.idea/ +.vscode/tags + diff --git a/mongorm/BaseDocument.py b/mongorm/BaseDocument.py index 70ce56f..1b2e772 100644 --- a/mongorm/BaseDocument.py +++ b/mongorm/BaseDocument.py @@ -30,9 +30,23 @@ def _fromMongo( self, data, overwrite=True ): setattr(self, name, pythonValue) return self - + + # The following three methods are used for pickling/unpickling. + # If it weren't for the fact that __getattr__ returns None + # for non-existing attributes (rather than raising an AttributeError), + # we would not need to define them. + + def __getstate__( self ): + return self.__dict__ + + def __setstate__( self, state ): + self.__dict__.update( state ) + + def __getnewargs__( self ): + return ( ) + def __setattr__( self, name, value ): - assert (name[0] == '_' and hasattr(self, name)) or name in self._fields, \ + assert name[0] == '_' or (name in self._fields), \ "Field '%s' does not exist in document '%s'" \ % (name, self.__class__.__name__) @@ -51,21 +65,29 @@ def __setattr__( self, name, value ): def __getattr__( self, name ): if name not in self._values and self._is_lazy and \ '_id' in self._data and self._data['_id'] is not None: + + if name not in self._fields: + raise AttributeError + # field is being accessed and the object is currently in lazy mode # may need to retrieve rest of document field = self._fields[name] if field.dbField not in self._data: # field not retrieved from database! load whole document. 
result = connection.getDatabase( )[self._collection].find_one( { '_id': self._data['_id'] } ) + if result is None: + raise self.DoesNotExist + self._fromMongo( result, overwrite=False ) + self._is_lazy = False - default = None field = self._fields.get( name, None ) - if field is not None: - default = field.getDefault( ) + if not name in self._values: + default = None + if field is not None: + default = field.getDefault( ) + self._values[name] = default value = self._values.get( name ) @@ -75,7 +97,8 @@ def __getattr__( self, name ): def _resyncFromPython( self ): # before we go any further, re-sync from python values where needed for (name,field) in self._fields.iteritems( ): - if field._resyncAtSave: + requiresDefaultCall = (name not in self._values and callable(field.default)) + if field._resyncAtSave or requiresDefaultCall: dbField = field.dbField pythonValue = getattr(self, name) self._data[dbField] = field.fromPython( pythonValue ) diff --git a/mongorm/Document.py b/mongorm/Document.py index 63f77af..2bad2df 100644 --- a/mongorm/Document.py +++ b/mongorm/Document.py @@ -1,4 +1,5 @@ import pymongo +import warnings from mongorm.BaseDocument import BaseDocument from mongorm.connection import getDatabase @@ -22,7 +23,7 @@ def __ne__(self, other): def __init__( self, **kwargs ): super(Document, self).__init__( **kwargs ) - def save( self, forceInsert=False, safe=True ): + def save( self, forceInsert=False, **kwargs ): database = getDatabase( ) collection = database[self._collection] @@ -30,11 +31,18 @@ def save( self, forceInsert=False, safe=True ): if '_id' in self._data and self._data['_id'] is None: del self._data['_id'] + + # safe not supported in pymongo 3.0+, use w for write concern instead + if 'safe' in kwargs: + kwargs['w'] = 1 if kwargs['safe'] else 0 + del kwargs['safe'] + warnings.warn('{} safe not supported in pymongo 3.0+, use w for write concern instead'.format(collection.full_name), DeprecationWarning) + try: if forceInsert: - newId = 
collection.insert( self._data, safe=safe ) + newId = collection.insert( self._data, **kwargs ) else: - newId = collection.save( self._data, safe=safe ) + newId = collection.save( self._data, **kwargs ) except pymongo.errors.OperationFailure, err: message = 'Could not save document (%s)' if u'duplicate key' in unicode(err): diff --git a/mongorm/DocumentMetaclass.py b/mongorm/DocumentMetaclass.py index 04cab4b..7588163 100644 --- a/mongorm/DocumentMetaclass.py +++ b/mongorm/DocumentMetaclass.py @@ -1,3 +1,4 @@ + from mongorm.fields.BaseField import BaseField from mongorm.queryset.QuerySetManager import QuerySetManager from mongorm.DocumentRegistry import DocumentRegistry @@ -7,7 +8,7 @@ from mongorm.errors import DoesNotExist, MultipleObjectsReturned -from mongorm.connection import getDatabase +from mongorm import connection import sys @@ -58,11 +59,19 @@ def __new__( cls, name, bases, attrs ): field.dbField = attrName fields[attrName] = field del attrs[attrName] - + + def indexConverter( fieldName ): + if fieldName in fields: + return fields[fieldName].optimalIndex( ) + return fieldName + for field,value in fields.iteritems( ): if value.primaryKey: assert primaryKey is None, "Can only have one primary key per document" primaryKey = field + if value.unique: + keyOpts = { 'unique': True } + connection.stagedIndexes.append( (collection, [(indexConverter( field ), 1)], keyOpts) ) # add a primary key if none exists and one is required if needsPrimaryKey and primaryKey is None: @@ -77,19 +86,13 @@ def __new__( cls, name, bases, attrs ): if 'indexes' in meta: indexes = meta['indexes'] - _database = getDatabase( ) - _collection = _database[collection] - for index in indexes: if not isinstance(index, (list,tuple)): index = [index] - def indexConverter( fieldName ): - if fieldName in fields: - return fields[fieldName].optimalIndex( ) - return fieldName + pyMongoIndexKeys = sortListToPyMongo( index, indexConverter ) - _collection.ensure_index( pyMongoIndexKeys ) - + 
connection.stagedIndexes.append( (collection, pyMongoIndexKeys, {}) ) + # add a query set manager if none exists already if 'objects' not in attrs: attrs['objects'] = QuerySetManager( ) @@ -114,12 +117,15 @@ def indexConverter( fieldName ): newClass._addException( 'MultipleObjectsReturned', bases, defaultBase=MultipleObjectsReturned, module=module ) - + # register the document for name-based reference DocumentRegistry.registerDocument( name, newClass ) + + if connection.database is not None: + connection.ensureIndexes() return newClass - + def _addException( self, name, bases, defaultBase, module ): baseExceptions = tuple( getattr(base, name) \ for base in bases if hasattr(base, name) diff --git a/mongorm/__init__.py b/mongorm/__init__.py index 297a567..33dd4c6 100644 --- a/mongorm/__init__.py +++ b/mongorm/__init__.py @@ -5,6 +5,7 @@ from mongorm.EmbeddedDocument import EmbeddedDocument from mongorm.fields.DictField import DictField +from mongorm.fields.SafeDictField import SafeDictField from mongorm.fields.IntegerField import IntegerField, IntField from mongorm.fields.ObjectIdField import ObjectIdField from mongorm.fields.StringField import StringField diff --git a/mongorm/connection.py b/mongorm/connection.py index 1a1b2fb..cf933d9 100644 --- a/mongorm/connection.py +++ b/mongorm/connection.py @@ -1,45 +1,117 @@ -from pymongo import Connection - +from pymongo import MongoClient, MongoReplicaSetClient, IndexModel +from pymongo.collection import Collection connection = None database = None connectionSettings = None -def connect( database, **kwargs ): - global connectionSettings - +pymongoWrapper = None + +stagedIndexes = [] +droppedIndexes = [] + +def connect( databaseName, autoEnsure=False, wrapPymongo=None, **kwargs ): + global database, connection, connectionSettings, autoEnsureIndexes, pymongoWrapper + + pymongoWrapper = wrapPymongo + + autoEnsureIndexes = autoEnsure + connectionSettings = {} connectionSettings.update( kwargs ) connectionSettings.update( { - 
'database': database + 'database': databaseName } ) + # Reset database & connection + connection = None + database = None + + # Initialise connection and ensure indexes if configured + getDatabase() + def getConnection( ): global database, connection, connectionSettings - + assert connectionSettings is not None, "No database specified: call mongorm.connect() before use" - + if connection is None: connectionArgs = {} - - for key in ['host', 'port']: + + # read_preference Not supported in pymongo 3.0+. + # it should be an option on get_database, get_collection, with_options + for key in ['host', 'port', 'replicaSet', 'username', 'password']: if key in connectionSettings: connectionArgs[key] = connectionSettings[key] - - connection = Connection( **connectionArgs ) - + + client = MongoReplicaSetClient if 'replicaSet' in connectionArgs else MongoClient + + connection = client( **connectionArgs ) + return connection +def index_name_from_index_fields(index_fields): + return '_'.join([ + field + '_1' if direction is 1 else '_-1' + for (field, direction) in index_fields + ]) + +def ensureIndexes( ): + global stagedIndexes, droppedIndexes + + if not autoEnsureIndexes: + return + + assert database is not None, "Must be connected to database before ensuring indexes" + + # Ensure indexes on the documents + indexes_by_collection = {} + for collection, key_or_list, kwargs in stagedIndexes: + if collection not in indexes_by_collection: + indexes_by_collection[collection] = [] + + indexes_by_collection[collection].append((key_or_list, kwargs)) + + for collection in indexes_by_collection: + # Create the collection if necessary + if collection not in database.collection_names(): + Collection(database, collection, create=True) + + indexInfo = database[collection].index_information() + dropped_indexes = [] + + for index_fields, kwargs in indexes_by_collection[collection]: + expected_index_name = index_name_from_index_fields(index_fields) + + existing_index = 
indexInfo.get(expected_index_name) + if existing_index is not None: + hasChanged = False + if kwargs.get('unique', False) or existing_index.get('unique', False): + if kwargs.get('unique', False) != existing_index.get('unique', False): + hasChanged = True + + if hasChanged: + database[collection].drop_index(expected_index_name) + dropped_indexes.append(expected_index_name) + + database[collection].create_indexes([ + IndexModel(index_fields, **kwargs) + for index_fields, kwargs in indexes_by_collection[collection] + if ( + index_name_from_index_fields(index_fields) not in indexes_by_collection[collection] or + index_name_from_index_fields(index_fields) in dropped_indexes) + ]) + + stagedIndexes = [] + def getDatabase( ): global database, connectionSettings - + if database is None: connection = getConnection( ) databaseName = connectionSettings['database'] database = connection[databaseName] - - if 'username' in connectionSettings and \ - 'password' in connectionSettings: - database.authenticate( connectionSettings['username'], connectionSettings['password'] ) - - return database \ No newline at end of file + + ensureIndexes() + + return database diff --git a/mongorm/fields/BaseField.py b/mongorm/fields/BaseField.py index 5e151ef..9923acb 100644 --- a/mongorm/fields/BaseField.py +++ b/mongorm/fields/BaseField.py @@ -19,6 +19,8 @@ def toQuery( self, pythonValue, dereferences=[] ): return self.fromPython( pythonValue ) def getDefault( self ): + if callable(self.default): + return self.default() return self.default def setOwnerDocument( self, ownerDocument ): diff --git a/mongorm/fields/DateTimeField.py b/mongorm/fields/DateTimeField.py index a1080aa..8be8bd5 100644 --- a/mongorm/fields/DateTimeField.py +++ b/mongorm/fields/DateTimeField.py @@ -2,12 +2,43 @@ from datetime import datetime +try: + from pytz import utc +except ImportError: + PYTZ = False +else: + PYTZ = True + +try: + from iso8601 import parse_date, ParseError +except ImportError: + ISO8601 = False +else: + 
ISO8601 = True + class DateTimeField(BaseField): def fromPython( self, pythonValue, dereferences=[], modifier=None ): - if pythonValue is not None and not isinstance(pythonValue, datetime): - raise ValueError, "Value must be a datetime object not %r" % (pythonValue,) - + if ISO8601 and isinstance(pythonValue, basestring): + try: + pythonValue = parse_date( pythonValue ) + except ParseError: + # oh well we tried + pass + + if isinstance(pythonValue, datetime): + if PYTZ and pythonValue.tzinfo is not None: + pythonValue = pythonValue.astimezone( utc ) + pythonValue = pythonValue.replace( tzinfo=None ) + + # mongo doesn't handle microseconds + pythonValue = pythonValue.replace( microsecond=(pythonValue.microsecond//1000)*1000 ) + return pythonValue def toPython( self, bsonValue ): - return bsonValue \ No newline at end of file + if PYTZ and bsonValue is not None: + if bsonValue.tzinfo is None: + bsonValue = utc.localize( bsonValue ) + elif bsonValue.tzinfo != utc: + bsonValue = bsonValue.astimezone(utc) + return bsonValue diff --git a/mongorm/fields/DecimalField.py b/mongorm/fields/DecimalField.py index ad1eaf7..8f81cac 100644 --- a/mongorm/fields/DecimalField.py +++ b/mongorm/fields/DecimalField.py @@ -2,8 +2,6 @@ from decimal import Decimal -from datetime import datetime - class DecimalField(BaseField): def fromPython( self, pythonValue, dereferences=[], modifier=None ): if isinstance(pythonValue, (basestring, int, float)): @@ -13,4 +11,4 @@ def fromPython( self, pythonValue, dereferences=[], modifier=None ): return str(pythonValue) def toPython( self, bsonValue ): - return Decimal(bsonValue) \ No newline at end of file + return Decimal(bsonValue) diff --git a/mongorm/fields/ListField.py b/mongorm/fields/ListField.py index f7a859a..32d3a00 100644 --- a/mongorm/fields/ListField.py +++ b/mongorm/fields/ListField.py @@ -17,7 +17,7 @@ def toQuery( self, pythonValue, dereferences=[] ): return self.fromPython( pythonValue ) def fromPython( self, pythonValue, 
dereferences=[], modifier=None ): - if modifier in ('push', 'pop'): + if modifier in ('push', 'pull'): return self.itemClass.fromPython(pythonValue, dereferences) # list modifiers act on single instances return [ self.itemClass.fromPython(value, dereferences) for value in pythonValue ] @@ -31,4 +31,4 @@ def toPython( self, bsonValue ): def setOwnerDocument( self, ownerDocument ): super(ListField, self).setOwnerDocument( ownerDocument ) - self.itemClass.setOwnerDocument( ownerDocument ) \ No newline at end of file + self.itemClass.setOwnerDocument( ownerDocument ) diff --git a/mongorm/fields/ReferenceField.py b/mongorm/fields/ReferenceField.py index c09e33c..7bc27ae 100644 --- a/mongorm/fields/ReferenceField.py +++ b/mongorm/fields/ReferenceField.py @@ -1,7 +1,4 @@ -try: - from pymongo import objectid, dbref -except ImportError: - from bson import objectid, dbref +from bson import objectid, dbref import bson.errors from mongorm.fields.BaseField import BaseField @@ -12,7 +9,7 @@ class ReferenceField(BaseField): def __init__( self, documentClass, *args, **kwargs ): super(ReferenceField, self).__init__( *args, **kwargs ) - + self._use_ref_id = kwargs.get('use_ref_id', False) self.inputDocumentClass = documentClass def _getClassInfo( self ): @@ -37,7 +34,11 @@ def fromPython( self, pythonValue, dereferences=[], modifier=None ): if pythonValue is None: return None - if not isinstance(pythonValue, self.documentClass): + if isinstance(pythonValue, dbref.DBRef): + return { + '_ref': pythonValue + } + elif not isinstance(pythonValue, self.documentClass): # try mapping to an objectid try: objectId = objectid.ObjectId( str( pythonValue ) ) @@ -62,9 +63,15 @@ def fromPython( self, pythonValue, dereferences=[], modifier=None ): def toQuery( self, pythonValue, dereferences=[] ): if pythonValue is None: return None - return { - '_ref': self.fromPython( pythonValue )['_ref'] - } + # Note: this is only specific for cosmosdb which doesn't support dbref + if self._use_ref_id: + return 
{ + '_ref.$id': self.fromPython( pythonValue )['_ref'].id + } + else: + return { + '_ref': self.fromPython( pythonValue )['_ref'] + } def toPython( self, bsonValue ): self._getClassInfo( ) @@ -72,6 +79,8 @@ def toPython( self, bsonValue ): if bsonValue is None: return None + documentClass = None + if isinstance(bsonValue, dbref.DBRef): # old style (mongoengine) dbRef = bsonValue @@ -86,9 +95,13 @@ def toPython( self, bsonValue ): if '_cls' in bsonValue: # mongoengine GenericReferenceField compatibility documentName = bsonValue['_cls'] - else: + elif '_types' in bsonValue: documentName = bsonValue['_types'][0] + else: + return dbRef + documentClass = DocumentRegistry.getDocument( documentName ) + initialData = { '_id': dbRef.id, } @@ -97,4 +110,4 @@ def toPython( self, bsonValue ): return documentClass( )._fromMongo( initialData ) def optimalIndex( self ): - return self.dbField + '._ref' \ No newline at end of file + return self.dbField + '._ref' diff --git a/mongorm/fields/SafeDictField.py b/mongorm/fields/SafeDictField.py new file mode 100644 index 0000000..a9e3185 --- /dev/null +++ b/mongorm/fields/SafeDictField.py @@ -0,0 +1,40 @@ +from mongorm.fields.DictField import DictField + +from collections import deque +from operator import methodcaller +from copy import deepcopy + +def deepCoded( dictionary, coder ): + dictionary = deepcopy( dictionary ) # leave the original intact + toCode = deque( [dictionary] ) + while toCode: + nextDictionary = toCode.popleft( ) + for key, value in nextDictionary.items( ): # can't be iteritems as we're changing the dict + if isinstance(key, basestring): + # Keys have to be strings in mongo so this should always occur + del nextDictionary[key] + nextDictionary[coder( key )] = value + if isinstance(value, dict): + toCode.append( value ) + return dictionary + +def encode( string ): + if isinstance(string, unicode): + string = string.encode( 'utf-8' ) + return string.encode( 'hex' ) + +def decode( string ): + return string.decode( 'hex' 
).decode( 'utf-8' ) + +class SafeDictField(DictField): + def fromPython( self, *args, **kwargs ): + result = super(SafeDictField, self).fromPython( *args, **kwargs ) + return deepCoded( result, encode ) + + def toPython( self, *args, **kwargs ): + result = super(SafeDictField, self).toPython( *args, **kwargs ) + return deepCoded( result, decode ) + + def toQuery( self, pythonValue, dereferences=[] ): + encodedDereferences = [encode( dereference ) for dereference in dereferences] + return super(SafeDictField, self).toQuery( pythonValue, encodedDereferences ) diff --git a/mongorm/fields/StringField.py b/mongorm/fields/StringField.py index 352678b..0c328f8 100644 --- a/mongorm/fields/StringField.py +++ b/mongorm/fields/StringField.py @@ -2,8 +2,11 @@ class StringField(BaseField): def fromPython( self, pythonValue, dereferences=[], modifier=None ): - return unicode(pythonValue) + if pythonValue is not None: + pythonValue = unicode(pythonValue) + return pythonValue def toPython( self, bsonValue ): - return unicode(bsonValue) - \ No newline at end of file + if bsonValue is not None: + bsonValue = unicode(bsonValue) + return bsonValue diff --git a/mongorm/queryset/Q.py b/mongorm/queryset/Q.py index 5010f85..62cea2e 100644 --- a/mongorm/queryset/Q.py +++ b/mongorm/queryset/Q.py @@ -17,10 +17,14 @@ def toMongo( self, document, forUpdate=False, modifier=None ): # mongodb logic operator - value is a list of Qs newSearch[name] = [ value.toMongo( document ) for value in value ] continue + + if name.startswith('$') and isinstance(value, basestring): + newSearch[name] = value + continue fieldName = name - MONGO_COMPARISONS = ['gt', 'lt', 'lte', 'gte', 'exists', 'ne', 'all', 'in', 'elemMatch'] + MONGO_COMPARISONS = ['gt', 'lt', 'lte', 'gte', 'exists', 'ne', 'all', 'in', 'nin', 'elemMatch'] REGEX_COMPARISONS = { 'contains': ( '%s', '' ), 'icontains': ( '%s', 'i' ), @@ -37,6 +41,7 @@ def toMongo( self, document, forUpdate=False, modifier=None ): 'imatches': ( None, 'i' ), } 
ALL_COMPARISONS = MONGO_COMPARISONS + REGEX_COMPARISONS.keys() + ARRAY_VALUE_COMPARISONS = ['all', 'in', 'nin'] comparison = None dereferences = [] @@ -58,10 +63,32 @@ def toMongo( self, document, forUpdate=False, modifier=None ): field = document._fields[fieldName] if not forUpdate: - searchValue = field.toQuery( value, dereferences=dereferences ) + if comparison in ARRAY_VALUE_COMPARISONS: + searchValues = [field.toQuery( item, dereferences=dereferences ) for item in value] + if searchValues and isinstance(searchValues[0], dict): + searchValue = {} + for dictValue in searchValues: + for key in dictValue: + # using setdefault instead of defaultdict in case python < 2.5 + searchValue.setdefault(key, []).append( dictValue[key] ) + else: + searchValue = searchValues + else: + searchValue = field.toQuery( value, dereferences=dereferences ) targetSearchKey = field.dbField else: - searchValue = field.fromPython( value, dereferences=dereferences, modifier=modifier ) + if comparison in ARRAY_VALUE_COMPARISONS: + searchValues = [field.fromPython( item, dereferences=dereferences, modifier=modifier ) for item in value] + if searchValues and isinstance(searchValues[0], dict): + searchValue = {} + for dictValue in searchValues: + for key in dictValue: + # using setdefault instead of defaultdict in case python < 2.5 + searchValue.setdefault(key, []).append( dictValue[key] ) + else: + searchValue = searchValues + else: + searchValue = field.fromPython( value, dereferences=dereferences, modifier=modifier ) targetSearchKey = '.'.join( [field.dbField] + dereferences) valueMapper = lambda value: value @@ -89,7 +116,10 @@ def toMongo( self, document, forUpdate=False, modifier=None ): if isinstance(searchValue, dict): if not forUpdate: for name,value in searchValue.iteritems( ): - key = targetSearchKey + '.' + name + if name: + key = targetSearchKey + '.' 
+ name + else: + key = targetSearchKey newSearch[key] = valueMapper(value) else: newSearch[targetSearchKey] = valueMapper(searchValue) @@ -116,12 +146,12 @@ def do_merge( self, other, op ): if len(self.query) == 0: return other if len(other.query) == 0: return self - if op in self.query: + if op in self.query and len(self.query) == 1: items = self.query[op] + [other] - elif op in other.query: + elif op in other.query and len(self.query) == 1: items = other.query[op] + [self] else: items = [ self, other ] newQuery = { op: items } - return Q( _query=newQuery ) \ No newline at end of file + return Q( _query=newQuery ) diff --git a/mongorm/queryset/QuerySet.py b/mongorm/queryset/QuerySet.py index 60377d4..03450c6 100644 --- a/mongorm/queryset/QuerySet.py +++ b/mongorm/queryset/QuerySet.py @@ -6,109 +6,149 @@ from mongorm.blackMagic import serialiseTypesForDocumentType +PROJECTIONS = frozenset(['slice']) + class QuerySet(object): - def __init__( self, document, collection, query=None, orderBy=None, onlyFields=None ): + def __init__( self, document, collection, query=None, orderBy=None, fields=None, timeout=True, readPref=None, types=None ): self.document = document self.documentTypes = serialiseTypesForDocumentType( document ) self.collection = collection self.orderBy = [] - self.onlyFields = onlyFields + self._fields = fields + self.timeout = timeout + self.readPref = readPref + self.types = [] if orderBy is not None: self.orderBy = orderBy[:] self._savedCount = None self._savedItems = None - self._savedBuiltItems = None if query is None: self.query = Q( ) else: self.query = query - + if types: + for subclass in types: + if not issubclass(subclass, self.document): + raise TypeError, "'%s' is not a subclass of '%s'" % (subclass, self.document) + self.types.append( subclass ) + def _getNewInstance( self, data ): documentName = data.get( '_types', [self.document.__name__] )[0] documentClass = DocumentRegistry.getDocument( documentName ) assert issubclass( documentClass, 
self.document ) return documentClass( )._fromMongo( data ) - + + def _get_kwargs( self ): + return { + 'query': self.query, + 'orderBy': self.orderBy, + 'fields': self._fields, + 'timeout': self.timeout, + 'readPref': self.readPref, + 'types': self.types, + } + def get( self, query=None, **search ): if query is None: query = Q( **search ) - newQuery = self.query & query + self.query &= query #self._mergeSearch( search ) #print 'get:', newQuery.toMongo( self.document ) - + # limit of 2 so we know if multiple matched without running a count() - result = list( self.collection.find( newQuery.toMongo( self.document ), limit=2 ) ) - + result = list( self._do_find( limit=2 ) ) + if len(result) == 0: raise self.document.DoesNotExist( ) - + if len(result) == 2: raise self.document.MultipleObjectsReturned( ) - + return self._getNewInstance( result[0] ) - + def all( self ): return self - + + def close( self ): + if self._savedItems: + self._savedItems.close() + def filter( self, query=None, **search ): if query is None: query = Q( **search ) - newQuery = self.query & query - #print 'filter:', newQuery.toMongo( self.document ) - return QuerySet( self.document, self.collection, query=newQuery, orderBy=self.orderBy, onlyFields=self.onlyFields ) - + kwargs = self._get_kwargs( ) + kwargs['query'] &= query + return QuerySet( self.document, self.collection, **kwargs ) + + def no_timeout( self ): + kwargs = self._get_kwargs( ) + kwargs['timeout'] = False + return QuerySet( self.document, self.collection, **kwargs ) + + def read_preference( self, readPref ): + kwargs = self._get_kwargs( ) + kwargs['readPref'] = readPref + return QuerySet( self.document, self.collection, **kwargs ) + + def subtypes( self, *types ): + kwargs = self._get_kwargs( ) + kwargs['types'] = types + return QuerySet( self.document, self.collection, **kwargs ) + def count( self ): if self._savedCount is None: if self._savedItems is None: - self._savedCount = self.collection.find( self._get_query( ) ).count( ) + 
self._savedCount = self.collection.count( self._get_query( ) ) else: self._savedCount = self._savedItems.count( ) - + return self._savedCount - + def __len__( self ): return self.count( ) - + def delete( self ): self.collection.remove( self.query.toMongo( self.document ) ) - + def _prepareActions( self, **actions ): updates = {} - + for action, value in actions.iteritems( ): assert '__' in action, 'Action "%s" not legal for update' % (action,) modifier, fieldName = action.split( '__', 1 ) - assert modifier in ['set', 'inc', 'dec', 'push', 'pushAll'], 'Unknown modifier "%s"' % modifier - + assert modifier in ['set', 'unset', 'setOnInsert', 'inc', 'dec', 'push', 'pushAll', 'pull', 'pullAll'], 'Unknown modifier "%s"' % modifier + if '$'+modifier not in updates: updates['$'+modifier] = {} - + translatedName = fieldName.replace('__', '.') - + mongoValues = Q( { fieldName: value } ).toMongo( self.document, forUpdate=True, modifier=modifier ) #print mongoValues mongoValue = mongoValues[translatedName] - + updates['$'+modifier].update( { translatedName: mongoValue } ) - + return updates - + def update( self, upsert=False, safeUpdate=False, modifyAndReturn=False, returnAfterUpdate=False, updateAllDocuments=False, **actions ): """Performs an update on the collection, using MongoDB atomic modifiers. - + If upsert is specified, the document will be created if it doesn't exist. + + DEPRECATED: If safeUpdate is specified, the success of the update will be checked and the number of modified documents will be returned. - + If modifyAndReturn is specified, a findAndModify operation will be executed instead of an update operation. The *original* document instance (before any - modifications) will be returned, unless returnAfterUpdate is True. If no + modifications) will be returned, unless returnAfterUpdate is True. If no document matched the specified query, None will be returned.""" - + updates = self._prepareActions( **actions ) - + # XXX: why was this here? 
we shouldn't be forcing this #if '$set' not in updates: # updates['$set'] = {} @@ -117,23 +157,25 @@ def update( self, upsert=False, safeUpdate=False, modifyAndReturn=False, returnA #print 'query:', self.query.toMongo( self.document ) #print 'update:', updates - + query = self._get_query( forUpsert=True ) #print query, 'query' #print updates, 'update' - + # {'_types': {$all:['BaseThingUpsert']}, 'name': 'upsert1'} # {'$set': {'value': 42}, '$addToSet': {'_types': {$each: ['BTI', 'BaseThingUpsert']}}} - + updates['$addToSet'] = { '_types': { '$each': self.documentTypes } } - + if not modifyAndReturn: # standard 'update' - ret = self.collection.update( query, updates, upsert=upsert, safe=safeUpdate, multi=updateAllDocuments ) + # safe not supported in pymongo 3.0+, use w for write concern instead + w = 1 if safeUpdate else 0 + ret = self.collection.update( query, updates, upsert=upsert, w=w, multi=updateAllDocuments ) if ret is None: return None if 'n' in ret: @@ -146,61 +188,98 @@ def update( self, upsert=False, safeUpdate=False, modifyAndReturn=False, returnA upsert=upsert, new=returnAfterUpdate, ) - - if len(result) == 0: + + if result is None or len(result) == 0: return None else: return self._getNewInstance( result ) - + def order_by( self, *fields ): + kwargs = self._get_kwargs( ) newOrderBy = self.orderBy[:] newOrderBy.extend( fields ) - return QuerySet( self.document, self.collection, query=self.query, orderBy=newOrderBy, onlyFields=self.onlyFields ) - + kwargs['orderBy'] = newOrderBy + return QuerySet( self.document, self.collection, **kwargs ) + def only( self, *fields ): - onlyFields = set(fields) - return QuerySet( self.document, self.collection, query=self.query, orderBy=self.orderBy, onlyFields=onlyFields ) - + kwargs = self._get_kwargs( ) + kwargs['fields'] = dict(self._fields or {}, **dict.fromkeys( fields, True )) + return QuerySet( self.document, self.collection, **kwargs ) + def ignore( self, *fields ): - current = set(self.document._fields.keys()) 
- if self.onlyFields is not None: - current = set(self.onlyFields) - return self.only( *list(current - set(fields)) ) - + kwargs = self._get_kwargs( ) + kwargs['fields'] = dict(self._fields or {}, **dict.fromkeys( fields, False )) + return QuerySet( self.document, self.collection, **kwargs ) + + def fields( self, **projections ): + kwargs = self._get_kwargs( ) + kwargs['fields'] = dict(self._fields or {}) + for field, value in projections.iteritems( ): + if '__' in field: + fieldName, sep, projection = field.rpartition( '__' ) + if projection in PROJECTIONS: + field = fieldName + value = {'$%s' % projection: value} + kwargs['fields'][field] = value + return QuerySet( self.document, self.collection, **kwargs ) + def _do_find( self, **kwargs ): if 'sort' not in kwargs: sorting = sortListToPyMongo( self.orderBy ) - + if len(sorting) > 0: kwargs['sort'] = sorting - - if self.onlyFields is not None: - kwargs['fields'] = self.onlyFields - + + # fields not supported in pymongo 3.0+, use projection instead + if 'fields' in kwargs: + kwargs['projection'] = kwargs['fields'] + del kwargs['fields'] + elif self._fields is not None: + kwargs['projection'] = self._fields + + # timeout not supported in pymongo 3.0+, use no_cursor_timeout instead + if 'timeout' in kwargs: + kwargs['no_cursor_timeout'] = not kwargs['timeout'] + else: + kwargs['no_cursor_timeout'] = not self.timeout + search = self._get_query( ) - return self.collection.find( search, **kwargs ) - + + if '_types' in search and 'projection' in kwargs and not kwargs['projection'].get( '_types' ) and all(kwargs['projection'].itervalues( )): + kwargs['projection']['_types'] = True + + # read_preference not supported in pymongo 3.0+, use with_options() instead + if 'read_preference' in kwargs: + read_preference = kwargs['read_preference'] + del kwargs['read_preference'] + else: + read_preference = self.readPref + + if read_preference: + collection = self.collection.with_options(read_preference=read_preference) + else: + 
collection = self.collection + + return collection.find( search, **kwargs ) + def _get_query( self, forUpsert=False ): search = self.query.toMongo( self.document ) types = self.documentTypes - if len(types) > 1: # only filter when looking at a subclass + if self.types: + search['_types'] = {'$in': [subtype.__name__ for subtype in self.types]} + elif len(types) > 1: # only filter when looking at a subclass if forUpsert: search['_types'] = {'$all':[self.document.__name__]} # filter by the type that was used else: search['_types'] = self.document.__name__ # filter by the type that was used return search - + def __iter__( self ): #print 'iter:', self.query.toMongo( self.document ), self.collection if self._savedItems is None: self._savedItems = self._do_find( ) - self._savedBuiltItems = [] - for i,item in enumerate(self._savedItems): - if i >= len(self._savedBuiltItems): - self._savedBuiltItems.append( self._getNewInstance( item ) ) - - yield self._savedBuiltItems[i] - + return (self._getNewInstance( item ) for item in self._savedItems.clone( )) + def __getitem__( self, index ): if isinstance(index, int): getOne = True @@ -209,15 +288,15 @@ def __getitem__( self, index ): elif isinstance(index, slice): getOne = False skip = index.start or 0 - limit = index.stop - skip + limit = index.stop - skip if index.stop is not None else 0 assert index.step is None, "Slicing with step not supported by mongorm" else: assert False, "item not an index" - + #print self.query.toMongo( self.document ) #items = self.collection.find( self.query.toMongo( self.document ), skip=skip, limit=limit ) items = self._do_find( skip=skip, limit=limit ) - + if getOne: try: item = items[0] @@ -231,36 +310,36 @@ def _yieldItems(): document = self._getNewInstance( item ) yield document return _yieldItems( ) - + def first( self ): try: return self[0] except IndexError: return None - + def __call__( self, **search ): return self.filter( **search ) - + def ensure_indexed( self ): """Ensures that the most 
optimal index for the query so far is actually in the database. - + Call this whenever a query is deemed expensive.""" - + indexKeys = [] - + indexKeys.extend( self._queryToIndex( self.query.toMongo( self.document ) ) ) - + indexKeys.extend( sortListToPyMongo( self.orderBy ) ) - + uniqueKeys = [] for key in indexKeys: if key not in uniqueKeys: uniqueKeys.append( key ) - + self.collection.ensure_index( uniqueKeys ) - + return self - + def _queryToIndex( self, query ): for key, value in query.iteritems( ): if key in ('$and', '$or'): @@ -270,4 +349,4 @@ def _queryToIndex( self, query ): elif key.startswith( '$' ): continue # skip, it's a mongo operator and we can't search it? else: - yield (key, pymongo.ASCENDING) # FIXME: work out direction better? \ No newline at end of file + yield (key, pymongo.ASCENDING) # FIXME: work out direction better? diff --git a/mongorm/queryset/QuerySetManager.py b/mongorm/queryset/QuerySetManager.py index 182dc0b..6bf086f 100644 --- a/mongorm/queryset/QuerySetManager.py +++ b/mongorm/queryset/QuerySetManager.py @@ -4,15 +4,15 @@ class QuerySetManager(object): def __init__( self ): self.collection = None - + def __get__( self, instance, owner ): if instance is not None: # Document class being accessed, not an object return self - + if self.collection is None: database = getDatabase( ) self.collection = database[owner._collection] - - return QuerySet( owner, self.collection ) - \ No newline at end of file + + from mongorm.connection import pymongoWrapper + return QuerySet( owner, pymongoWrapper(self.collection) ) diff --git a/setup.py b/setup.py index ca6ae6d..c7ececc 100644 --- a/setup.py +++ b/setup.py @@ -1,15 +1,15 @@ from setuptools import setup, find_packages setup(name='mongorm', - version='0.1', + version='0.5.9.4', packages=find_packages(), - author='Theo Julienne', - author_email='theo@icy.com.au', - url='http://www.icy.com.au/???', + author='Theo Julienne, John Garland', + author_email='theo@icy.com.au, john@openlearning.com', + 
url='https://github.com/OpenLearningNet/mongorm', license='MIT', include_package_data=True, description='Mongorm', long_description='Mongorm', platforms=['any'], - install_requires=['pymongo', 'pysignals'], + install_requires=['pymongo >= 2.4.2', 'pysignals'], ) diff --git a/tests/test_connect.py b/tests/test_connect.py new file mode 100644 index 0000000..a058c61 --- /dev/null +++ b/tests/test_connect.py @@ -0,0 +1,12 @@ +import mongorm + +def test_connect( ): + mongorm.connect( 'test_mongorm' ) + assert mongorm.connection.getDatabase( ).name == 'test_mongorm' + +def test_reconnect( ): + mongorm.connect( 'test_mongorm' ) + assert mongorm.connection.getDatabase( ).name == 'test_mongorm' + + mongorm.connect( 'test_mongorm_2' ) + assert mongorm.connection.getDatabase( ).name == 'test_mongorm_2' diff --git a/tests/test_equality.py b/tests/test_equality.py index e078dba..2bb1498 100644 --- a/tests/test_equality.py +++ b/tests/test_equality.py @@ -1,5 +1,23 @@ +# -*- coding: utf8 -*- + from mongorm import * -from pymongo.dbref import DBRef +from bson.dbref import DBRef + +try: + from pytz import timezone, UTC +except ImportError: + PYTZ = False +else: + PYTZ = True + +try: + from iso8601 import parse_date +except ImportError: + ISO8601 = False +else: + ISO8601 = True + +from datetime import datetime def teardown_module(module): DocumentRegistry.clear( ) @@ -8,20 +26,135 @@ def test_equality( ): """Tests to make sure comparisons work. 
Equality compares database identity, not value similarity.""" connect( 'test_mongorm' ) - + class TestDocument(Document): s = StringField( ) - + a = TestDocument( s="Hello" ) a.save( ) - + b = TestDocument( s="Hello" ) b.save( ) assert not (a == b) assert a != b - + c = TestDocument.objects.get(pk=a.id) assert c == a - assert not (c != a) \ No newline at end of file + assert not (c != a) + +def test_equality_with_none( ): + """Tests to make sure comparisons with None work.""" + connect( 'test_mongorm' ) + + class TestDocument(Document): + s = StringField( ) + + a = TestDocument( ) + a.save( ) + + assert a.s is None + + a.s = "" + a.save( ) + + assert a.s == "" + + a.s = None + a.save( ) + + assert a.s is None + +def test_equality_with_unicode( ): + """Tests to make sure comparisons with unicode work.""" + connect( 'test_mongorm' ) + + class TestDocument(Document): + s = StringField( ) + + a = TestDocument( s=u"déjà vu" ) + a.save( ) + + assert a.s == u"déjà vu" + assert a.s != "déjà vu" + assert a.s != "deja vu" + +def test_equality_with_datetime( ): + """Tests to make sure comparisons with datetime objects work.""" + connect( 'test_mongorm' ) + + class TestDateTime(Document): + timestamp = DateTimeField( ) + + # Get current UTC time to the nearest millisecond + now = datetime.utcnow( ) + now = now.replace( microsecond=(now.microsecond//1000)*1000 ) + + if PYTZ: + now = UTC.localize( now ) + tz = timezone( "Australia/Sydney" ) + now = now.astimezone( tz ) + now = tz.normalize( now ) + + t = TestDateTime( timestamp=now ) + t.save( ) + + assert t.timestamp == now + + if PYTZ: + now = now.astimezone( UTC ) + + t = TestDateTime.objects.get( pk=t.id ) + + assert t.timestamp == now + + try: + t.timestamp = now.isoformat( ) + t.save( ) + except ValueError: + assert not ISO8601 + else: + assert ISO8601 + + if ISO8601: + t = TestDateTime.objects.get( pk=t.id ) + + assert t.timestamp == now + +def test_equality_with_old_datetime( ): + """Tests to make sure comparisons with 
datetime objects + before the epoch work.""" + connect( 'test_mongorm' ) + + class TestDateTime(Document): + timestamp = DateTimeField( ) + + bestDaysOfMyLife = datetime( 1969, 1, 1, 12, 0, 0, 9000 ) + + if PYTZ: + tz = timezone( "Australia/Sydney" ) + bestDaysOfMyLife = tz.localize( bestDaysOfMyLife ) + bestDaysOfMyLife = tz.normalize( bestDaysOfMyLife ) + + t = TestDateTime( timestamp=bestDaysOfMyLife ) + t.save( ) + + assert t.timestamp == bestDaysOfMyLife + + t = TestDateTime.objects.get( pk=t.id ) + + assert t.timestamp == bestDaysOfMyLife + + try: + t.timestamp = bestDaysOfMyLife.isoformat( ) + t.save( ) + except ValueError: + assert not ISO8601 + else: + assert ISO8601 + + if ISO8601: + t = TestDateTime.objects.get( pk=t.id ) + + assert t.timestamp == bestDaysOfMyLife diff --git a/tests/test_get.py b/tests/test_get.py index 12be8e3..1d44572 100644 --- a/tests/test_get.py +++ b/tests/test_get.py @@ -1,5 +1,6 @@ from mongorm import * -from pymongo.dbref import DBRef +from bson.dbref import DBRef +from bson.objectid import ObjectId from pytest import raises @@ -28,4 +29,20 @@ class TestDocument(Document): item2.save() with raises(TestDocument.MultipleObjectsReturned): - TestDocument.objects.get(s="hello") \ No newline at end of file + TestDocument.objects.get(s="hello") + +def test_non_existing_document( ): + """Tests to make sure non-existing documents raise the correct error.""" + connect( 'test_mongorm' ) + + class TestDocument(Document): + s = StringField( ) + + TestDocument.objects.delete( ) + + item = TestDocument( ) + item._is_lazy = True + item._data['_id'] = 123 + + with raises(TestDocument.DoesNotExist): + item.s diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index 8892e13..64eed80 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -1,6 +1,9 @@ from mongorm import * -from pymongo.objectid import ObjectId +try: + from pymongo.objectid import ObjectId +except ImportError: + from bson.objectid import ObjectId 
def setup_module(module): DocumentRegistry.clear( ) diff --git a/tests/test_pickle.py b/tests/test_pickle.py new file mode 100644 index 0000000..c8c0310 --- /dev/null +++ b/tests/test_pickle.py @@ -0,0 +1,90 @@ +from mongorm import Document, DocumentRegistry, StringField, connect + +try: + import cPickle as pickle +except ImportError: + import pickle + +class TestPickledDocument(Document): + s = StringField( ) + +def setup_module( module ): + DocumentRegistry.registerDocument( "TestPickledDocument", TestPickledDocument ) + +def teardown_module( module ): + TestPickledDocument.objects.delete( ) + DocumentRegistry.clear( ) + +def test_pickle( ): + """Tests to make sure pickling works.""" + connect( 'test_mongorm' ) + + assert DocumentRegistry.hasDocument( "TestPickledDocument" ) + + cucumber = TestPickledDocument( s="spam" ) + cucumber.save( ) + + assert cucumber == TestPickledDocument.objects.get( s="spam" ) + + gherkin = pickle.dumps( cucumber ) + + assert pickle.loads( gherkin ) == cucumber + + assert pickle.loads( gherkin ) == TestPickledDocument.objects.get( s="spam" ) + +def test_binary_pickle( ): + """Tests to make sure binary pickling works.""" + connect( 'test_mongorm' ) + + assert DocumentRegistry.hasDocument( "TestPickledDocument" ) + + cucumber = TestPickledDocument( s="eggs" ) + cucumber.save( ) + + assert cucumber == TestPickledDocument.objects.get( s="eggs" ) + + gherkin = pickle.dumps( cucumber, pickle.HIGHEST_PROTOCOL ) + + assert pickle.loads( gherkin ) == cucumber + + assert pickle.loads( gherkin ) == TestPickledDocument.objects.get( s="eggs" ) + +def test_deleted_pickle( ): + """Tests to make sure deleted objects can be unpickled.""" + connect( 'test_mongorm' ) + + assert DocumentRegistry.hasDocument( "TestPickledDocument" ) + + cucumber = TestPickledDocument( s="onions" ) + cucumber.save( ) + + assert cucumber == TestPickledDocument.objects.get( s="onions" ) + + gherkin = pickle.dumps( cucumber, pickle.HIGHEST_PROTOCOL ) + + 
TestPickledDocument.objects.filter( pk=cucumber.id ).delete( ) + assert TestPickledDocument.objects.filter( pk=cucumber.id ).count( ) == 0 + + assert pickle.loads( gherkin ).s == "onions" + assert pickle.loads( gherkin ) == cucumber + +def test_modified_pickle( ): + """Tests to make sure pickled objects are updated.""" + connect( 'test_mongorm' ) + + assert DocumentRegistry.hasDocument( "TestPickledDocument" ) + + cucumber = TestPickledDocument( s="cabbage" ) + cucumber.save( ) + + assert cucumber == TestPickledDocument.objects.get( s="cabbage" ) + + gherkin = pickle.dumps( cucumber, pickle.HIGHEST_PROTOCOL ) + + cucumber.s = "kimchi" + cucumber.save( ) + + assert pickle.loads( gherkin ) == cucumber + + assert TestPickledDocument.objects.filter( s="cabbage" ).count( ) == 0 + assert pickle.loads( gherkin ) == TestPickledDocument.objects.get( s="kimchi" ) diff --git a/tests/test_queries.py b/tests/test_queries.py index d0533c3..e329c6d 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -1,4 +1,8 @@ +# -*- coding: utf8 -*- + from mongorm import * +from pymongo import ReadPreference +from pytest import raises def teardown_module(module): DocumentRegistry.clear( ) @@ -147,6 +151,72 @@ class TestAndOr(Document): ]}, ]} +def test_do_merge_or( ): + """Tests to make sure do_merge works with 'or' operator""" + connect( 'test_mongorm' ) + + class TestAndOr(Document): + name = StringField( ) + path = StringField( ) + index = ListField( StringField( ) ) + + query = Q( name="spam" ) | Q( name="eggs" ) + assert query.toMongo( TestAndOr ) == { + '$or': [{'name': "spam"}, {'name': "eggs"}] + } + + query &= Q( path=u"Green Midget Café" ) + assert query.toMongo( TestAndOr ) == { + '$or': [{'name': "spam"}, {'name': "eggs"}], + 'path': u"Green Midget Café" + } + + query |= Q( index='11' ) + assert query.toMongo( TestAndOr ) == { + '$or': [{ + '$or': [{'name': "spam"}, {'name': "eggs"}], + 'path': u"Green Midget Café" + }, { + 'index': '11' + }] + } + +def 
test_do_merge_and( ): + """Tests to make sure do_merge works with 'and' operator""" + connect( 'test_mongorm' ) + + class TestAndOr(Document): + name = StringField( ) + path = StringField( ) + index = ListField( StringField( ) ) + + query = Q( name="spam" ) & Q( name="eggs" ) + assert query.toMongo( TestAndOr ) == { + '$and': [ + {'name': "spam"}, {'name': "eggs"} + ] + } + + query &= Q( path=u"Green Midget Café" ) + assert query.toMongo( TestAndOr ) == { + '$and': [ + {'name': "spam"}, {'name': "eggs"} + ], + 'path': u"Green Midget Café" + } + + query &= Q( index='123' ) & Q( index='456' ) + assert query.toMongo( TestAndOr ) == { + '$and': [{ + '$and': [ + {'name': "spam"}, {'name': "eggs"} + ], + 'path': u"Green Midget Café" + }, { + '$and': [{'index': "123"}, {'index': "456"}] + }] + } + def test_referencefield_none( ): """Make sure ReferenceField can be searched for None""" connect( 'test_mongorm' ) @@ -187,4 +257,347 @@ class TestPush(Document): pushAll__names=['123', '456'] ) == { '$pushAll': {'names': ['123', '456']} - } \ No newline at end of file + } + +def test_in_operator( ): + """Tests in operator works with lists""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + assert Q( name__in=[] ).toMongo( Test ) \ + == {'name': {'$in': []}} + + assert Q( name__in=['eggs', 'spam'] ).toMongo( Test ) \ + == {'name': {'$in': ['eggs', 'spam']}} + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( ) + + assert Test.objects.filter( name__in=[] ).count( ) == 0 + assert Test.objects.filter( name__in=['spam'] ).count( ) == 1 + assert Test.objects.filter( name__in=['eggs'] ).count( ) == 1 + assert Test.objects.filter( name__in=['spam', 'eggs'] ).count( ) == 2 + +def test_in_iter_operator( ): + """Tests in operator works with iterators""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + assert Q( name__in={} ).toMongo( 
Test ) \ + == {'name': {'$in': []}} + + assert Q( name__in=set(['eggs', 'spam']) ).toMongo( Test ) \ + == {'name': {'$in': ['eggs', 'spam']}} + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( ) + + def test_gen( ): + for item in ('eggs', 'spam'): + yield item + + assert Test.objects.filter( name__in=() ).count( ) == 0 + assert Test.objects.filter( name__in={'spam': True} ).count( ) == 1 + assert Test.objects.filter( name__in=frozenset(['eggs']) ).count( ) == 1 + assert Test.objects.filter( name__in=test_gen() ).count( ) == 2 + +def test_in_operator_with_ref( ): + """Tests in operator works with references""" + connect( 'test_mongorm' ) + + class TestUser(Document): + name = StringField( ) + + class TestOrder(Document): + user = ReferenceField( TestUser ) + breakfast = StringField( ) + + # Clear objects so that counts will be correct + TestUser.objects.all( ).delete( ) + TestOrder.objects.all( ).delete( ) + + man = TestUser( name="Eric Idle" ) + wife = TestUser( name="Graham Chapman" ) + man.save( ) + wife.save( ) + + assert TestUser.objects.filter( name__in=["Eric Idle", "Graham Chapman"] ).count( ) == 2 + + TestOrder( user=man, breakfast="spam spam spam beans spam" ).save( ) + TestOrder( user=wife, breakfast="bacon and eggs" ).save( ) + + assert TestOrder.objects.filter( user=man ).count( ) == 1 + assert TestOrder.objects.filter( user=wife ).count( ) == 1 + assert TestOrder.objects.filter( user__in=[man, wife] ).count( ) == 2 + + TestOrder( user=man, breakfast="spam spam spam spam spam" ).save( ) + + assert TestOrder.objects.filter( user__in=[man, wife] ).count( ) == 3 + assert TestOrder.objects.filter( breakfast__in=["spam spam spam", "bacon and eggs"] ).count( ) == 1 + assert TestOrder.objects.filter( breakfast__in=["spam spam spam spam spam", "bacon and eggs"] ).count( ) == 2 + assert TestOrder.objects.filter( breakfast__in=[ + "spam spam spam beans spam", + "spam spam 
spam spam spam", + "bacon and eggs" + ] ).count( ) == 3 + assert TestOrder.objects.filter( user__in=[wife], breakfast__in=["spam spam spam spam spam"] ).count( ) == 0 + assert TestOrder.objects.filter( user__in=[man], breakfast__in=["spam spam spam spam spam"] ).count( ) == 1 + assert TestOrder.objects.filter( user__in=[man, wife], breakfast__in=[ + "spam spam spam spam spam", + "spam spam spam beans spam", + "bacon and eggs" + ] ).count( ) == 3 + +def test_multiple_iteration( ): + """Tests multiple iterators work""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + # Add some objects to the collection in case + Test( name="John" ).save( ) + Test( name="Eric" ).save( ) + Test( name="Graham" ).save( ) + + assert Test.objects.count( ) >= 3 + + query = Test.objects.all( ) + it1 = iter(query) + it2 = iter(query) + + for i in xrange(Test.objects.count( )): + assert next(it1) == next(it2) + +def test_secondary_read_pref( ): + """Tests read_preference works""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + # Add some objects to the collection in case + Test( name="John" ).save( ) + Test( name="Eric" ).save( ) + Test( name="Graham" ).save( ) + + assert Test.objects.read_preference( 'secondary' ).count( ) >= 3 + assert Test.objects.filter( name="John" ).read_preference( ReadPreference.SECONDARY )[0].name == "John" + +def test_slice_projection( ): + """Tests slice projection works""" + connect( 'test_mongorm' ) + + class TestArray(Document): + names = ListField( StringField( ) ) + + # Add some objects to the collection in case + chaps = TestArray( names=["John", "Eric", "Graham"] ) + chaps.save( ) + + assert TestArray.objects.filter( pk=chaps.id ).fields( names__slice=1 )[0].names == ["John"] + assert TestArray.objects.fields( names__slice=1 ).get( pk=chaps.id ).names == ["John"] + assert TestArray.objects.fields( names__slice=-1 ).get( pk=chaps.id ).names == ["Graham"] + assert TestArray.objects.fields( 
names__slice=[1, 1] ).get( pk=chaps.id ).names == ["Eric"] + assert TestArray.objects.fields( names__slice=4 ).get( pk=chaps.id ).names == ["John", "Eric", "Graham"] + +def test_nin_operator( ): + """Tests nin (not in) operator works with lists""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + assert Q( name__nin=[] ).toMongo( Test ) \ + == {'name': {'$nin': []}} + + assert Q( name__nin=['eggs', 'spam'] ).toMongo( Test ) \ + == {'name': {'$nin': ['eggs', 'spam']}} + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( ) + + assert Test.objects.filter( name__nin=[] ).count( ) == 2 + assert Test.objects.filter( name__nin=['spam'] ).count( ) == 1 + assert Test.objects.filter( name__nin=['eggs'] ).count( ) == 1 + assert Test.objects.filter( name__nin=['spam', 'eggs'] ).count( ) == 0 + +def test_nin_iter_operator( ): + """Tests nin (not in) operator works with iterators""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + assert Q( name__nin={} ).toMongo( Test ) \ + == {'name': {'$nin': []}} + + assert Q( name__nin=set(['eggs', 'spam']) ).toMongo( Test ) \ + == {'name': {'$nin': ['eggs', 'spam']}} + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( ) + + def test_gen( ): + for item in ('eggs', 'spam'): + yield item + + assert Test.objects.filter( name__nin=() ).count( ) == 2 + assert Test.objects.filter( name__nin={'spam': True} ).count( ) == 1 + assert Test.objects.filter( name__nin=frozenset(['eggs']) ).count( ) == 1 + assert Test.objects.filter( name__nin=test_gen() ).count( ) == 0 + +def test_nin_operator_with_ref( ): + """Tests nin (not in) operator works with references""" + connect( 'test_mongorm' ) + + class TestUser(Document): + name = StringField( ) + + class TestOrder(Document): + user = ReferenceField( TestUser ) + 
breakfast = StringField( ) + + # Clear objects so that counts will be correct + TestUser.objects.all( ).delete( ) + TestOrder.objects.all( ).delete( ) + + man = TestUser( name="Eric Idle" ) + wife = TestUser( name="Graham Chapman" ) + man.save( ) + wife.save( ) + + assert TestUser.objects.filter( name__nin=[] ).count( ) == 2 + assert TestUser.objects.filter( name__nin=["Eric Idle", "Graham Chapman"] ).count( ) == 0 + + TestOrder( user=man, breakfast="spam spam spam beans spam" ).save( ) + TestOrder( user=wife, breakfast="bacon and eggs" ).save( ) + + assert TestOrder.objects.filter( user__ne=man ).count( ) == 1 + assert TestOrder.objects.filter( user__ne=wife ).count( ) == 1 + assert TestOrder.objects.filter( user__nin=[man, wife] ).count( ) == 0 + + TestOrder( user=man, breakfast="spam spam spam spam spam" ).save( ) + + assert TestOrder.objects.filter( user__nin=[man, wife] ).count( ) == 0 + assert TestOrder.objects.filter( breakfast__nin=["spam spam spam", "bacon and eggs"] ).count( ) == 2 + assert TestOrder.objects.filter( breakfast__nin=["spam spam spam spam spam", "bacon and eggs"] ).count( ) == 1 + assert TestOrder.objects.filter( breakfast__nin=[ + "spam spam spam beans spam", + "spam spam spam spam spam", + "bacon and eggs" + ] ).count( ) == 0 + assert TestOrder.objects.filter( user__nin=[wife], breakfast__nin=["spam spam spam spam spam"] ).count( ) == 1 + assert TestOrder.objects.filter( user__nin=[man], breakfast__nin=["spam spam spam spam spam"] ).count( ) == 1 + assert TestOrder.objects.filter( user__nin=[man, wife], breakfast__nin=[ + "spam spam spam spam spam", + "spam spam spam beans spam", + "bacon and eggs" + ] ).count( ) == 0 + +def test_dict_queries( ): + """Tests dicts as a whole can be queried.""" + connect( 'test_mongorm' ) + + class TestDict(Document): + data = DictField( ) + + assert Q( data={} ).toMongo( TestDict ) == {'data': {}} + assert Q( data__gt={} ).toMongo( TestDict ) == {'data': {'$gt': {}}} + + # Clear objects so that counts will 
be correct + TestDict.objects.all( ).delete( ) + + TestDict( data={} ).save( ) + TestDict( data={"has": "something"} ).save( ) + + assert TestDict.objects.all( ).count( ) == 2 + + assert TestDict.objects.filter( data={} ).count( ) == 1 + assert TestDict.objects.filter( data__gt={} ).count( ) == 1 + +def test_subtype_queries( ): + """Tests querying objects based on their type.""" + connect( 'test_mongorm' ) + + class TestDocument(Document): + data = StringField( ) + + class TestSubDocumentA(TestDocument): + pass + + class TestSubDocumentB(TestDocument): + pass + + class TestSubDocumentC(TestDocument): + pass + + class TestOtherDocument(Document): + data = StringField( ) + + # Clear objects so that counts will be correct + TestDocument.objects.all( ).delete( ) + + TestSubDocumentA( data='spam' ).save( ) + TestSubDocumentB( data='spam' ).save( ) + TestSubDocumentC( data='spam' ).save( ) + + assert TestDocument.objects.all( ).count( ) == 3 + assert TestSubDocumentA.objects.all( ).count( ) == 1 + assert TestSubDocumentB.objects.all( ).count( ) == 1 + assert TestSubDocumentC.objects.all( ).count( ) == 1 + + assert TestDocument.objects.subtypes( TestSubDocumentA ).all( ).count( ) == 1 + assert TestDocument.objects.subtypes( TestSubDocumentB ).all( ).count( ) == 1 + assert TestDocument.objects.subtypes( TestSubDocumentC ).all( ).count( ) == 1 + + assert type(TestDocument.objects.subtypes( TestSubDocumentA )[0]) == TestSubDocumentA + assert type(TestDocument.objects.subtypes( TestSubDocumentB )[0]) == TestSubDocumentB + assert type(TestDocument.objects.subtypes( TestSubDocumentC )[0]) == TestSubDocumentC + + assert type(TestDocument.objects.subtypes( TestSubDocumentA ).only( '_id' )[0]) == TestSubDocumentA + assert type(TestDocument.objects.subtypes( TestSubDocumentB ).only( '_id' )[0]) == TestSubDocumentB + assert type(TestDocument.objects.subtypes( TestSubDocumentC ).only( '_id' )[0]) == TestSubDocumentC + + assert type(TestDocument.objects.subtypes( TestSubDocumentA 
).ignore( 'data' )[0]) == TestSubDocumentA + assert type(TestDocument.objects.subtypes( TestSubDocumentB ).ignore( 'data' )[0]) == TestSubDocumentB + assert type(TestDocument.objects.subtypes( TestSubDocumentC ).ignore( 'data' )[0]) == TestSubDocumentC + + assert TestDocument.objects.subtypes( TestSubDocumentA, TestSubDocumentB ).all( ).count( ) == 2 + assert TestDocument.objects.subtypes( TestSubDocumentA, TestSubDocumentB, TestSubDocumentC ).all( ).count( ) == 3 + + with raises( TypeError ): + TestOtherDocument.objects.subtypes( TestSubDocumentA ).count( ) + + assert TestDocument.objects.subtypes( TestSubDocumentA ).all( ).order_by( 'data' ).count( ) == 1 + +def test_datetime_queries( ): + """Tests queries with datetime fields.""" + + from datetime import datetime + + class TestDateTime(Document): + timestamp = DateTimeField( ) + + now = datetime.now( ).replace( microsecond=0 ) + + assert Q( timestamp=now ).toMongo( TestDateTime) == {'timestamp': now} + assert Q( timestamp__gte=now ).toMongo( TestDateTime) == {'timestamp': {'$gte': now}} + assert Q( timestamp__lt=now ).toMongo( TestDateTime) == {'timestamp': {'$lt': now}} + assert Q( timestamp__exists=True ).toMongo( TestDateTime) == {'timestamp': {'$exists': True}} diff --git a/tests/test_safety.py b/tests/test_safety.py new file mode 100644 index 0000000..bd9c02e --- /dev/null +++ b/tests/test_safety.py @@ -0,0 +1,75 @@ +# -*- coding: utf8 -*- + +from mongorm import * + +class TestUnsafeKeys(Document): + data = SafeDictField( ) + +def teardown_module(module): + TestUnsafeKeys.objects.all( ).delete( ) + DocumentRegistry.clear( ) + +def check_safe_dict_with_data( data ): + doc = TestUnsafeKeys( data=data ) + doc.save( ) + assert doc.data == data + +def test_safe_dict_save_dot_key( ): + check_safe_dict_with_data( {'.': ''} ) + +def test_safe_dict_save_dollar_key( ): + check_safe_dict_with_data( {'$': ''} ) + +def test_safe_dict_save_nested_dot_key( ): + check_safe_dict_with_data( { + '': { + '': { + '.': None + } 
+ } + } ) + +def test_safe_dict_save_nested_dollar_key( ): + check_safe_dict_with_data( { + '': { + '': { + '$': None + } + } + } ) + +def test_safe_dict_unicode( ): + check_safe_dict_with_data( { + u"$$$": None, + u"...": None, + u"déjà vu": True + } ) + +def test_safe_dict_raw( ): + check_safe_dict_with_data( { + r".": None, + r"$": None, + r"s/DictField/Safe&/g": True + } ) + +def test_safe_dict_query( ): + assert TestUnsafeKeys.objects.filter( data__attributes__course__name='test' ).count( ) == 0 + TestUnsafeKeys( data={ + 'attributes': { + 'course': { + 'name': 'test' + } + } + } ).save( ) + assert TestUnsafeKeys.objects.filter( data__attributes__course__name='test' ).count( ) == 1 + +def test_safe_dict_unsafe_query( ): + assert TestUnsafeKeys.objects.filter( Q( {'data__$$$__...__???': True} ) ).count( ) == 0 + TestUnsafeKeys( data={ + '$$$': { + '...': { + '???': True + } + } + } ).save( ) + assert TestUnsafeKeys.objects.filter( Q( {'data__$$$__...__???': True} ) ).count( ) == 1 diff --git a/tests/test_slice.py b/tests/test_slice.py new file mode 100644 index 0000000..dfe3e88 --- /dev/null +++ b/tests/test_slice.py @@ -0,0 +1,85 @@ +from mongorm import * + +def teardown_module(module): + DocumentRegistry.clear( ) + +def test_getitem_one( ): + """Tests you can get one item from query.""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( ) + + assert Test.objects.count( ) == 2 + assert Test.objects.order_by( 'name' )[0].name == 'eggs' + assert Test.objects.order_by( 'name' )[1].name == 'spam' + +def test_getitem_multiple( ): + """Tests you can get multiple items from query.""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( 
) + + assert Test.objects.count( ) == 2 + assert [x.name for x in Test.objects.order_by( 'name' )[0:2]] == ['eggs', 'spam'] + +def test_getitem_no_start( ): + """Tests you can get items from query without specifying a start index.""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( ) + + assert Test.objects.count( ) == 2 + assert [x.name for x in Test.objects.order_by( 'name' )[:2]] == ['eggs', 'spam'] + +def test_getitem_no_end( ): + """Tests you can get items from query without specifying an end index.""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( ) + + assert Test.objects.count( ) == 2 + assert [x.name for x in Test.objects.order_by( 'name' )[0:]] == ['eggs', 'spam'] + +def test_getitem_no_indices( ): + """Tests you can get items from query without specifying any index.""" + connect( 'test_mongorm' ) + + class Test(Document): + name = StringField( ) + + # Clear objects so that counts will be correct + Test.objects.all( ).delete( ) + + Test( name='spam' ).save( ) + Test( name='eggs' ).save( ) + + assert Test.objects.count( ) == 2 + assert [x.name for x in Test.objects.order_by( 'name' )[:]] == ['eggs', 'spam'] diff --git a/tests/test_update.py b/tests/test_update.py index 7247df4..d531c83 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -1,5 +1,5 @@ from mongorm import * -from pymongo.dbref import DBRef +from bson.dbref import DBRef def teardown_module(module): DocumentRegistry.clear( ) @@ -11,7 +11,7 @@ class TestA(Document): assert Q( { 'data__123': 'test' } ).toMongo( TestA, forUpdate=True ) == { 'data.123': 'test' } - # children of a dictfield shouldn't be motified + # children of a dictfield 
shouldn't be modified fieldName = 'data__123' value = {"XXX": "YYY"} assert Q( { fieldName: value } ).toMongo( TestA, forUpdate=True )[fieldName.replace('__', '.')] \ @@ -54,4 +54,137 @@ class TestB(Document): assert TestB.objects._prepareActions( set__genericval=doc ) == {'$set': {'genericval': {'_types': ['TestB'], '_ref': DBRef('testb', doc.id)}}} - \ No newline at end of file + +def test_push_pull_operators( ): + """Tests to make sure the push & pull operators work""" + + class TestPushPull(Document): + values = ListField( StringField( ) ) + + # Clear all objects so that counts will be correct + TestPushPull.objects.all( ).delete( ) + + # Check correct mongo is being produced + + assert TestPushPull.objects._prepareActions( + push__values='spam' + ) == {'$push': {'values': 'spam'}} + + assert TestPushPull.objects._prepareActions( + pushAll__values=['spam', 'eggs'] + ) == {'$pushAll': {'values': ['spam', 'eggs']}} + + assert TestPushPull.objects._prepareActions( + pull__values='spam' + ) == {'$pull': {'values': 'spam'}} + + assert TestPushPull.objects._prepareActions( + pullAll__values=['spam', 'eggs'] + ) == {'$pullAll': {'values': ['spam', 'eggs']}} + + # OK let's check with some real data + + a = TestPushPull( values=[] ).save( ) + assert a.values == [] + + b = TestPushPull( values=['eggs'] ).save( ) + assert b.values == ['eggs'] + + assert TestPushPull.objects.update( + safeUpdate=True, + updateAllDocuments=True, + push__values='spam' + ) == 2 + + assert TestPushPull.objects.get( pk=a.id ).values == ['spam'] + assert TestPushPull.objects.get( pk=b.id ).values == ['eggs', 'spam'] + + assert TestPushPull.objects.update( + safeUpdate=True, + updateAllDocuments=True, + pushAll__values=[] + ) == 2 + + assert TestPushPull.objects.get( pk=a.id ).values == ['spam'] + assert TestPushPull.objects.get( pk=b.id ).values == ['eggs', 'spam'] + + assert TestPushPull.objects.update( + safeUpdate=True, + updateAllDocuments=True, + pullAll__values=[] + ) == 2 + + assert 
TestPushPull.objects.get( pk=a.id ).values == ['spam'] + assert TestPushPull.objects.get( pk=b.id ).values == ['eggs', 'spam'] + + assert TestPushPull.objects.update( + safeUpdate=True, + updateAllDocuments=True, + pull__values='spam' + ) == 2 + + assert TestPushPull.objects.get( pk=a.id ).values == [] + assert TestPushPull.objects.get( pk=b.id ).values == ['eggs'] + + assert TestPushPull.objects.update( + safeUpdate=True, + updateAllDocuments=True, + pull__values='eggs' + ) == 2 + + assert TestPushPull.objects.get( pk=a.id ).values == [] + assert TestPushPull.objects.get( pk=b.id ).values == [] + + assert TestPushPull.objects.update( + safeUpdate=True, + updateAllDocuments=True, + pushAll__values=['spam', 'eggs'] + ) == 2 + + assert TestPushPull.objects.get( pk=a.id ).values == ['spam', 'eggs'] + assert TestPushPull.objects.get( pk=b.id ).values == ['spam', 'eggs'] + + assert TestPushPull.objects.update( + safeUpdate=True, + updateAllDocuments=True, + pullAll__values=['spam', 'eggs'] + ) == 2 + + assert TestPushPull.objects.get( pk=a.id ).values == [] + assert TestPushPull.objects.get( pk=b.id ).values == [] + +def test_setOnInsert( ): + """Tests to make sure $setOnInsert works""" + + class TestC(Document): + name = StringField( ) + version = IntField( ) + + # Clear objects to reset counts + TestC.objects.all( ).delete( ) + + assert TestC.objects._prepareActions( + setOnInsert__name='spam', + inc__version=1 + ) == {'$setOnInsert': {'name': 'spam'}, '$inc': {'version': 1}} + + assert TestC.objects.filter( name='spam' ).count( ) == 0 + assert TestC.objects.filter( name='spam' ).update( + safeUpdate=True, + upsert=True, + modifyAndReturn=True, + setOnInsert__name='spam', + inc__version=1 + ) is None + assert TestC.objects.filter( name='spam' ).count( ) == 1 + + c = TestC.objects.filter( name='spam' ).update( + safeUpdate=True, + upsert=True, + modifyAndReturn=True, + returnAfterUpdate=True, + setOnInsert__name='eggs', + inc__version=1 + ) + assert c.name == 'spam' + 
assert c.version == 2