Skip to content

Speed up JSONRenderer.extract_included #412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions rest_framework_json_api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import exceptions, status

from rest_framework_json_api import renderers, utils
from rest_framework_json_api import utils


def rendered_with_json_api(view):
from rest_framework_json_api.renderers import JSONRenderer
for renderer_class in getattr(view, 'renderer_classes', []):
if issubclass(renderer_class, renderers.JSONRenderer):
if issubclass(renderer_class, JSONRenderer):
return True
return False

Expand Down
112 changes: 53 additions & 59 deletions rest_framework_json_api/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Renderers
"""
import copy
from collections import OrderedDict
from collections import OrderedDict, defaultdict

import inflection
from django.db.models import Manager
Expand All @@ -13,6 +13,7 @@

import rest_framework_json_api
from rest_framework_json_api import utils
from rest_framework_json_api.relations import ResourceRelatedField


class JSONRenderer(renderers.JSONRenderer):
Expand Down Expand Up @@ -313,12 +314,12 @@ def extract_relation_instance(cls, field_name, field, resource_instance, seriali
return relation_instance

@classmethod
def extract_included(cls, fields, resource, resource_instance, included_resources):
def extract_included(cls, fields, resource, resource_instance, included_resources,
included_cache):
# this function may be called with an empty record (example: Browsable Interface)
if not resource_instance:
return

included_data = list()
current_serializer = fields.serializer
context = current_serializer.context
included_serializers = utils.get_included_serializers(current_serializer)
Expand Down Expand Up @@ -350,9 +351,6 @@ def extract_included(cls, fields, resource, resource_instance, included_resource
if isinstance(relation_instance, Manager):
relation_instance = relation_instance.all()

new_included_resources = [key.replace('%s.' % field_name, '', 1)
for key in included_resources
if field_name == key.split('.')[0]]
serializer_data = resource.get(field_name)

if isinstance(field, relations.ManyRelatedField):
Expand All @@ -365,10 +363,22 @@ def extract_included(cls, fields, resource, resource_instance, included_resource
continue

many = field._kwargs.get('child_relation', None) is not None

if isinstance(field, ResourceRelatedField) and not many:
already_included = serializer_data['type'] in included_cache and \
serializer_data['id'] in included_cache[serializer_data['type']]

if already_included:
continue

serializer_class = included_serializers[field_name]
field = serializer_class(relation_instance, many=many, context=context)
serializer_data = field.data

new_included_resources = [key.replace('%s.' % field_name, '', 1)
for key in included_resources
if field_name == key.split('.')[0]]

if isinstance(field, ListSerializer):
serializer = field.child
relation_type = utils.get_resource_type_from_serializer(serializer)
Expand All @@ -387,48 +397,45 @@ def extract_included(cls, fields, resource, resource_instance, included_resource
nested_resource_instance, context=serializer.context
)
)
included_data.append(
cls.build_json_resource_obj(
serializer_fields,
serializer_resource,
nested_resource_instance,
resource_type,
getattr(serializer, '_poly_force_type_resolution', False)
)
new_item = cls.build_json_resource_obj(
serializer_fields,
serializer_resource,
nested_resource_instance,
resource_type,
getattr(serializer, '_poly_force_type_resolution', False)
)
included_data.extend(
cls.extract_included(
serializer_fields,
serializer_resource,
nested_resource_instance,
new_included_resources
)
included_cache[new_item['type']][new_item['id']] = \
utils.format_keys(new_item)
cls.extract_included(
serializer_fields,
serializer_resource,
nested_resource_instance,
new_included_resources,
included_cache,
)

if isinstance(field, Serializer):

relation_type = utils.get_resource_type_from_serializer(field)

# Get the serializer fields
serializer_fields = utils.get_serializer_fields(field)
if serializer_data:
included_data.append(
cls.build_json_resource_obj(
serializer_fields, serializer_data,
relation_instance, relation_type,
getattr(field, '_poly_force_type_resolution', False))
new_item = cls.build_json_resource_obj(
serializer_fields,
serializer_data,
relation_instance,
relation_type,
getattr(field, '_poly_force_type_resolution', False)
)
included_data.extend(
cls.extract_included(
serializer_fields,
serializer_data,
relation_instance,
new_included_resources
)
included_cache[new_item['type']][new_item['id']] = utils.format_keys(new_item)
cls.extract_included(
serializer_fields,
serializer_data,
relation_instance,
new_included_resources,
included_cache,
)

return utils.format_keys(included_data)

@classmethod
def extract_meta(cls, serializer, resource):
if hasattr(serializer, 'child'):
Expand Down Expand Up @@ -529,9 +536,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
)

json_api_data = data
json_api_included = list()
# initialize json_api_meta with pagination meta or an empty dict
json_api_meta = data.get('meta', {}) if isinstance(data, dict) else {}
included_cache = defaultdict(dict)

if data and 'results' in data:
serializer_data = data["results"]
Expand Down Expand Up @@ -573,11 +580,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
json_resource_obj.update({'meta': utils.format_keys(meta)})
json_api_data.append(json_resource_obj)

included = self.extract_included(
fields, resource, resource_instance, included_resources
self.extract_included(
fields, resource, resource_instance, included_resources, included_cache
)
if included:
json_api_included.extend(included)
else:
fields = utils.get_serializer_fields(serializer)
force_type_resolution = getattr(serializer, '_poly_force_type_resolution', False)
Expand All @@ -591,11 +596,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
if meta:
json_api_data.update({'meta': utils.format_keys(meta)})

included = self.extract_included(
fields, serializer_data, resource_instance, included_resources
self.extract_included(
fields, serializer_data, resource_instance, included_resources, included_cache
)
if included:
json_api_included.extend(included)

# Make sure we render data in a specific order
render_data = OrderedDict()
Expand All @@ -610,20 +613,11 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
else:
render_data['data'] = json_api_data

if len(json_api_included) > 0:
# Iterate through compound documents to remove duplicates
seen = set()
unique_compound_documents = list()
for included_dict in json_api_included:
type_tuple = tuple((included_dict['type'], included_dict['id']))
if type_tuple not in seen:
seen.add(type_tuple)
unique_compound_documents.append(included_dict)

# Sort the items by type then by id
render_data['included'] = sorted(
unique_compound_documents, key=lambda item: (item['type'], item['id'])
)
if included_cache:
render_data['included'] = list()
for included_type in sorted(included_cache.keys()):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling `.keys()` on this line and the next is not necessary: `sorted()` iterates over a dictionary's keys by default, so `sorted(included_cache)` yields the same sorted list of keys.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, this could be compacted with a nested list comprehension:

render_data['included'] = [
    instance for instance_type, type_values in sorted(included_cache.items())
             for id, instance in sorted(type_values.items())]

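This works because sorting `.items()` orders the `(key, value)` tuples by key first; dictionary keys are unique, so the ordering is decided entirely by the keys and the values are never compared. This version eliminates a bunch of append operations. There's probably a small perf gain here, but I bet it's tiny. Here's an example to illustrate: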

>>> instance_cache = {'c': {'b': 6, 'a': 5}, 'a': {'b': 2, 'a': 1}, 'b': {'b': 4, 'a': 3}}
>>> [instance for instance_type, type_values in sorted(instance_cache.items()) for id, instance in sorted(type_values.items())]
[1, 2, 3, 4, 5, 6]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't mind, I prefer explicitness in code. Most people know that iterating over a dictionary in Python yields its keys, but using keys() makes it obvious for everyone – seniors and juniors alike, whatever programming language background they come from.

I feel similarly about that nested list comprehension. I didn't know that a thing like that existed in Python, but I have to say it is very difficult to read. Seems arcane.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can appreciate your point of view here. Thanks for at least considering it.

for included_id in sorted(included_cache[included_type].keys()):
render_data['included'].append(included_cache[included_type][included_id])

if json_api_meta:
render_data['meta'] = utils.format_keys(json_api_meta)
Expand Down