Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Hyperlinked PK optimization. Closes #1872.
  • Loading branch information
lovelydinosaur committed Dec 9, 2014
commit 720a37d3dedc501968bebaca3a339c72392b9c81
64 changes: 37 additions & 27 deletions rest_framework/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,34 @@ def get_queryset(self):
queryset = queryset.all()
return queryset

def use_pk_only_optimization(self):
return False

def get_attribute(self, instance):
if self.use_pk_only_optimization():
try:
# Optimized case, return a mock object only containing the pk attribute.
instance = get_attribute(instance, self.source_attrs[:-1])
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
except AttributeError:
pass

# Standard case, return the object instance.
return get_attribute(instance, self.source_attrs)

def get_iterable(self, instance, source_attrs):
relationship = get_attribute(instance, source_attrs)
return relationship.all() if (hasattr(relationship, 'all')) else relationship
relationship = relationship.all() if (hasattr(relationship, 'all')) else relationship

if self.use_pk_only_optimization():
# Optimized case, return mock objects only containing the pk attribute.
return [
PKOnlyObject(pk=pk)
for pk in relationship.values_list('pk', flat=True)
]

# Standard case, return the object instances.
return relationship

@property
def choices(self):
Expand Down Expand Up @@ -120,6 +145,9 @@ class PrimaryKeyRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
}

def use_pk_only_optimization(self):
return True

def to_internal_value(self, data):
try:
return self.get_queryset().get(pk=data)
Expand All @@ -128,32 +156,6 @@ def to_internal_value(self, data):
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)

def get_attribute(self, instance):
# We customize `get_attribute` here for performance reasons.
# For relationships the instance will already have the pk of
# the related object. We return this directly instead of returning the
# object itself, which would require a database lookup.
try:
instance = get_attribute(instance, self.source_attrs[:-1])
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
except AttributeError:
return get_attribute(instance, self.source_attrs)

def get_iterable(self, instance, source_attrs):
# For consistency with `get_attribute` we're using `serializable_value()`
# here. Typically there won't be any difference, but some custom field
# types might return a non-primitive value for the pk otherwise.
#
# We could try to get smart with `values_list('pk', flat=True)`, which
# would be better in some case, but would actually end up with *more*
# queries if the developer is using `prefetch_related` across the
# relationship.
relationship = super(PrimaryKeyRelatedField, self).get_iterable(instance, source_attrs)
return [
PKOnlyObject(pk=item.serializable_value('pk'))
for item in relationship
]

def to_representation(self, value):
return value.pk

Expand Down Expand Up @@ -184,6 +186,9 @@ def __init__(self, view_name=None, **kwargs):

super(HyperlinkedRelatedField, self).__init__(**kwargs)

def use_pk_only_optimization(self):
return self.lookup_field == 'pk'

def get_object(self, view_name, view_args, view_kwargs):
"""
Return the object corresponding to a matched URL.
Expand Down Expand Up @@ -285,6 +290,11 @@ def __init__(self, view_name=None, **kwargs):
kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)

def use_pk_only_optimization(self):
# We have the complete object instance already. We don't need
# to run the 'only get the pk for this relationship' code.
return False


class SlugRelatedField(RelatedField):
"""
Expand Down
12 changes: 8 additions & 4 deletions tests/test_relations_hyperlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def test_many_to_many_retrieve(self):
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)

def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
Expand All @@ -99,7 +100,8 @@ def test_reverse_many_to_many_retrieve(self):
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)

def test_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
Expand Down Expand Up @@ -197,7 +199,8 @@ def test_foreign_key_retrieve(self):
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(1):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this fix, this is 4.

self.assertEqual(serializer.data, expected)

def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
Expand All @@ -206,7 +209,8 @@ def test_reverse_foreign_key_retrieve(self):
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(3):
self.assertEqual(serializer.data, expected)

def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
Expand Down