diff --git a/rest_framework_json_api/exceptions.py b/rest_framework_json_api/exceptions.py index 7ffaf256..f6b21ad6 100644 --- a/rest_framework_json_api/exceptions.py +++ b/rest_framework_json_api/exceptions.py @@ -2,12 +2,13 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import exceptions, status -from rest_framework_json_api import renderers, utils +from rest_framework_json_api import utils def rendered_with_json_api(view): + from rest_framework_json_api.renderers import JSONRenderer for renderer_class in getattr(view, 'renderer_classes', []): - if issubclass(renderer_class, renderers.JSONRenderer): + if issubclass(renderer_class, JSONRenderer): return True return False diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 1c74ffdf..6059d2a2 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -2,7 +2,7 @@ Renderers """ import copy -from collections import OrderedDict +from collections import OrderedDict, defaultdict import inflection from django.db.models import Manager @@ -13,6 +13,7 @@ import rest_framework_json_api from rest_framework_json_api import utils +from rest_framework_json_api.relations import ResourceRelatedField class JSONRenderer(renderers.JSONRenderer): @@ -313,12 +314,12 @@ def extract_relation_instance(cls, field_name, field, resource_instance, seriali return relation_instance @classmethod - def extract_included(cls, fields, resource, resource_instance, included_resources): + def extract_included(cls, fields, resource, resource_instance, included_resources, + included_cache): # this function may be called with an empty record (example: Browsable Interface) if not resource_instance: return - included_data = list() current_serializer = fields.serializer context = current_serializer.context included_serializers = utils.get_included_serializers(current_serializer) @@ -350,9 +351,6 @@ def extract_included(cls, fields, resource, resource_instance, included_resource if isinstance(relation_instance, Manager): relation_instance = relation_instance.all() - new_included_resources = [key.replace('%s.' % field_name, '', 1) - for key in included_resources - if field_name == key.split('.')[0]] serializer_data = resource.get(field_name) if isinstance(field, relations.ManyRelatedField): @@ -365,10 +363,22 @@ def extract_included(cls, fields, resource, resource_instance, included_resource continue many = field._kwargs.get('child_relation', None) is not None + + if isinstance(field, ResourceRelatedField) and not many: + already_included = serializer_data['type'] in included_cache and \ + serializer_data['id'] in included_cache[serializer_data['type']] + + if already_included: + continue + serializer_class = included_serializers[field_name] field = serializer_class(relation_instance, many=many, context=context) serializer_data = field.data + new_included_resources = [key.replace('%s.' % field_name, '', 1) + for key in included_resources + if field_name == key.split('.')[0]] + if isinstance(field, ListSerializer): serializer = field.child relation_type = utils.get_resource_type_from_serializer(serializer) @@ -387,48 +397,45 @@ def extract_included(cls, fields, resource, resource_instance, included_resource nested_resource_instance, context=serializer.context ) ) - included_data.append( - cls.build_json_resource_obj( - serializer_fields, - serializer_resource, - nested_resource_instance, - resource_type, - getattr(serializer, '_poly_force_type_resolution', False) - ) + new_item = cls.build_json_resource_obj( + serializer_fields, + serializer_resource, + nested_resource_instance, + resource_type, + getattr(serializer, '_poly_force_type_resolution', False) ) - included_data.extend( - cls.extract_included( - serializer_fields, - serializer_resource, - nested_resource_instance, - new_included_resources - ) + included_cache[new_item['type']][new_item['id']] = \ + utils.format_keys(new_item) + cls.extract_included( + serializer_fields, + serializer_resource, + nested_resource_instance, + new_included_resources, + included_cache, ) if isinstance(field, Serializer): - relation_type = utils.get_resource_type_from_serializer(field) # Get the serializer fields serializer_fields = utils.get_serializer_fields(field) if serializer_data: - included_data.append( - cls.build_json_resource_obj( - serializer_fields, serializer_data, - relation_instance, relation_type, - getattr(field, '_poly_force_type_resolution', False)) + new_item = cls.build_json_resource_obj( + serializer_fields, + serializer_data, + relation_instance, + relation_type, + getattr(field, '_poly_force_type_resolution', False) ) - included_data.extend( - cls.extract_included( - serializer_fields, - serializer_data, - relation_instance, - new_included_resources - ) + included_cache[new_item['type']][new_item['id']] = utils.format_keys(new_item) + cls.extract_included( + serializer_fields, + serializer_data, + relation_instance, + new_included_resources, + included_cache, ) - return utils.format_keys(included_data) - @classmethod def extract_meta(cls, serializer, resource): if hasattr(serializer, 'child'): @@ -529,9 +536,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None): ) json_api_data = data - json_api_included = list() # initialize json_api_meta with pagination meta or an empty dict json_api_meta = data.get('meta', {}) if isinstance(data, dict) else {} + included_cache = defaultdict(dict) if data and 'results' in data: serializer_data = data["results"] @@ -573,11 +580,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None): json_resource_obj.update({'meta': utils.format_keys(meta)}) json_api_data.append(json_resource_obj) - included = self.extract_included( - fields, resource, resource_instance, included_resources + self.extract_included( + fields, resource, resource_instance, included_resources, included_cache ) - if included: - json_api_included.extend(included) else: fields = utils.get_serializer_fields(serializer) force_type_resolution = getattr(serializer, '_poly_force_type_resolution', False) @@ -591,11 +596,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None): if meta: json_api_data.update({'meta': utils.format_keys(meta)}) - included = self.extract_included( - fields, serializer_data, resource_instance, included_resources + self.extract_included( + fields, serializer_data, resource_instance, included_resources, included_cache ) - if included: - json_api_included.extend(included) # Make sure we render data in a specific order render_data = OrderedDict() @@ -610,20 +613,11 @@ def render(self, data, accepted_media_type=None, renderer_context=None): else: render_data['data'] = json_api_data - if len(json_api_included) > 0: - # Iterate through compound documents to remove duplicates - seen = set() - unique_compound_documents = list() - for included_dict in json_api_included: - type_tuple = tuple((included_dict['type'], included_dict['id'])) - if type_tuple not in seen: - seen.add(type_tuple) - unique_compound_documents.append(included_dict) - - # Sort the items by type then by id - render_data['included'] = sorted( - unique_compound_documents, key=lambda item: (item['type'], item['id']) - ) + if included_cache: + render_data['included'] = list() + for included_type in sorted(included_cache.keys()): + for included_id in sorted(included_cache[included_type].keys()): + render_data['included'].append(included_cache[included_type][included_id]) if json_api_meta: render_data['meta'] = utils.format_keys(json_api_meta)