diff --git a/.gitignore b/.gitignore index 3177afc7..29fb669d 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,9 @@ pip-delete-this-directory.txt # Pycharm project files .idea/ +# PyTest cache +.cache/ + # Tox .tox/ diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 1c66c927..4493d47e 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -4,10 +4,11 @@ import copy from collections import OrderedDict +from django.db.models import Manager, QuerySet from django.utils import six, encoding from rest_framework import relations from rest_framework import renderers -from rest_framework.serializers import BaseSerializer, ListSerializer, ModelSerializer +from rest_framework.serializers import BaseSerializer, Serializer, ListSerializer, ModelSerializer from rest_framework.settings import api_settings from . import utils @@ -87,16 +88,19 @@ def extract_relationships(fields, resource, resource_instance): source = field.source try: - relation_instance_or_manager = getattr(resource_instance, source) + relation_instance = getattr(resource_instance, source) except AttributeError: # if the field is not defined on the model then we check the serializer # and if no value is there we skip over the field completely serializer_method = getattr(field.parent, source, None) if serializer_method and hasattr(serializer_method, '__call__'): - relation_instance_or_manager = serializer_method(resource_instance) + relation_instance = serializer_method(resource_instance) else: continue + if isinstance(relation_instance, Manager): + relation_instance = relation_instance.all() + relation_type = utils.get_related_resource_type(field) if isinstance(field, relations.HyperlinkedIdentityField): @@ -104,8 +108,8 @@ def extract_relationships(fields, resource, resource_instance): relation_data = list() # Don't try to query an empty relation - relation_queryset = relation_instance_or_manager.all() \ - if relation_instance_or_manager is not None else list() + relation_queryset = relation_instance \ + if relation_instance is not None else list() for related_object in relation_queryset: relation_data.append( @@ -137,7 +141,7 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, (relations.PrimaryKeyRelatedField, relations.HyperlinkedRelatedField)): - relation_id = relation_instance_or_manager.pk if resource.get(field_name) else None + relation_id = relation_instance.pk if resource.get(field_name) else None relation_data = { 'data': ( @@ -176,11 +180,15 @@ def extract_relationships(fields, resource, resource_instance): continue relation_data = list() - for related_object in relation_instance_or_manager.all(): - related_object_type = utils.get_instance_or_manager_resource_type(related_object) + for nested_resource_instance in relation_instance: + nested_resource_instance_type = ( + relation_type or + utils.get_resource_type_from_instance(nested_resource_instance) + ) + relation_data.append(OrderedDict([ - ('type', related_object_type), - ('id', encoding.force_text(related_object.pk)) + ('type', nested_resource_instance_type), + ('id', encoding.force_text(nested_resource_instance.pk)) ])) data.update({ field_name: { @@ -192,15 +200,19 @@ def extract_relationships(fields, resource, resource_instance): }) continue - if isinstance(field, ListSerializer): + if isinstance(field, ListSerializer) and relation_instance is not None: relation_data = list() serializer_data = resource.get(field_name) - resource_instance_queryset = list(relation_instance_or_manager.all()) + resource_instance_queryset = list(relation_instance) if isinstance(serializer_data, list): for position in range(len(serializer_data)): nested_resource_instance = resource_instance_queryset[position] - nested_resource_instance_type = utils.get_resource_type_from_instance(nested_resource_instance) + nested_resource_instance_type = ( + relation_type or + utils.get_resource_type_from_instance(nested_resource_instance) + ) + relation_data.append(OrderedDict([ ('type', nested_resource_instance_type), ('id', encoding.force_text(nested_resource_instance.pk)) @@ -209,16 +221,13 @@ def extract_relationships(fields, resource, resource_instance): data.update({field_name: {'data': relation_data}}) continue - if isinstance(field, ModelSerializer): - relation_model = field.Meta.model - relation_type = utils.format_resource_type(relation_model.__name__) - + if isinstance(field, Serializer): data.update({ field_name: { 'data': ( OrderedDict([ ('type', relation_type), - ('id', encoding.force_text(relation_instance_or_manager.pk)) + ('id', encoding.force_text(relation_instance.pk)) ]) if resource.get(field_name) else None) } }) @@ -256,16 +265,19 @@ def extract_included(fields, resource, resource_instance, included_resources): continue try: - relation_instance_or_manager = getattr(resource_instance, field_name) + relation_instance = getattr(resource_instance, field_name) except AttributeError: try: # For ManyRelatedFields if `related_name` is not set we need to access `foo_set` from `source` - relation_instance_or_manager = getattr(resource_instance, field.child_relation.source) + relation_instance = getattr(resource_instance, field.child_relation.source) except AttributeError: if not hasattr(current_serializer, field.source): continue serializer_method = getattr(current_serializer, field.source) - relation_instance_or_manager = serializer_method(resource_instance) + relation_instance = serializer_method(resource_instance) + + if isinstance(relation_instance, Manager): + relation_instance = relation_instance.all() new_included_resources = [key.replace('%s.' % field_name, '', 1) for key in included_resources @@ -273,21 +285,21 @@ def extract_included(fields, resource, resource_instance, included_resources): serializer_data = resource.get(field_name) if isinstance(field, relations.ManyRelatedField): - serializer_class = included_serializers.get(field_name) - field = serializer_class(relation_instance_or_manager.all(), many=True, context=context) + serializer_class = included_serializers[field_name] + field = serializer_class(relation_instance, many=True, context=context) serializer_data = field.data if isinstance(field, relations.RelatedField): - serializer_class = included_serializers.get(field_name) - if relation_instance_or_manager is None: + if relation_instance is None: continue - field = serializer_class(relation_instance_or_manager, context=context) + serializer_class = included_serializers[field_name] + field = serializer_class(relation_instance, context=context) serializer_data = field.data if isinstance(field, ListSerializer): serializer = field.child relation_type = utils.get_resource_type_from_serializer(serializer) - relation_queryset = list(relation_instance_or_manager.all()) + relation_queryset = list(relation_instance) # Get the serializer fields serializer_fields = utils.get_serializer_fields(serializer) @@ -310,7 +322,7 @@ def extract_included(fields, resource, resource_instance, included_resources): ) ) - if isinstance(field, ModelSerializer): + if isinstance(field, Serializer): relation_type = utils.get_resource_type_from_serializer(field) @@ -320,11 +332,11 @@ def extract_included(fields, resource, resource_instance, included_resources): included_data.append( JSONRenderer.build_json_resource_obj( serializer_fields, serializer_data, - relation_instance_or_manager, relation_type) + relation_instance, relation_type) ) included_data.extend( JSONRenderer.extract_included( - serializer_fields, serializer_data, relation_instance_or_manager, new_included_resources + serializer_fields, serializer_data, relation_instance, new_included_resources ) ) diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 261640c6..36d850e7 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -162,6 +162,12 @@ def format_resource_type(value, format_type=None, pluralize=None): def get_related_resource_type(relation): + try: + return get_resource_type_from_serializer(relation) + except AttributeError: + pass + + relation_model = None if hasattr(relation, '_meta'): relation_model = relation._meta.model elif hasattr(relation, 'model'): @@ -171,38 +177,36 @@ def get_related_resource_type(relation): relation_model = relation.get_queryset().model else: parent_serializer = relation.parent + parent_model = None if hasattr(parent_serializer, 'Meta'): - parent_model = parent_serializer.Meta.model - else: - parent_model = parent_serializer.parent.Meta.model - - if relation.source: - if relation.source != '*': - parent_model_relation = getattr(parent_model, relation.source) + parent_model = getattr(parent_serializer.Meta, 'model', None) + elif hasattr(parent_serializer, 'parent') and hasattr(parent_serializer.parent, 'Meta'): + parent_model = getattr(parent_serializer.parent.Meta, 'model', None) + + if parent_model is not None: + if relation.source: + if relation.source != '*': + parent_model_relation = getattr(parent_model, relation.source) + else: + parent_model_relation = getattr(parent_model, relation.field_name) else: - parent_model_relation = getattr(parent_model, relation.field_name) - else: - parent_model_relation = getattr(parent_model, parent_serializer.field_name) - - if hasattr(parent_model_relation, 'related'): - try: - relation_model = parent_model_relation.related.related_model - except AttributeError: - # Django 1.7 - relation_model = parent_model_relation.related.model - elif hasattr(parent_model_relation, 'field'): - relation_model = parent_model_relation.field.related.model - else: - return get_related_resource_type(parent_model_relation) - return get_resource_type_from_model(relation_model) + parent_model_relation = getattr(parent_model, parent_serializer.field_name) + + if hasattr(parent_model_relation, 'related'): + try: + relation_model = parent_model_relation.related.related_model + except AttributeError: + # Django 1.7 + relation_model = parent_model_relation.related.model + elif hasattr(parent_model_relation, 'field'): + relation_model = parent_model_relation.field.related.model + else: + return get_related_resource_type(parent_model_relation) + if relation_model is None: + raise APIException(_('Could not resolve resource type for relation %s' % relation)) -def get_instance_or_manager_resource_type(resource_instance_or_manager): - if hasattr(resource_instance_or_manager, 'model'): - return get_resource_type_from_manager(resource_instance_or_manager) - if hasattr(resource_instance_or_manager, '_meta'): - return get_resource_type_from_instance(resource_instance_or_manager) - pass + return get_resource_type_from_model(relation_model) def get_resource_type_from_model(model): @@ -218,7 +222,8 @@ def get_resource_type_from_queryset(qs): def get_resource_type_from_instance(instance): - return get_resource_type_from_model(instance._meta.model) + if hasattr(instance, '_meta'): + return get_resource_type_from_model(instance._meta.model) def get_resource_type_from_manager(manager): @@ -226,10 +231,15 @@ def get_resource_type_from_manager(manager): def get_resource_type_from_serializer(serializer): - if hasattr(serializer.Meta, 'resource_name'): - return serializer.Meta.resource_name - else: - return get_resource_type_from_model(serializer.Meta.model) + json_api_meta = getattr(serializer, 'JSONAPIMeta', None) + meta = getattr(serializer, 'Meta', None) + if hasattr(json_api_meta, 'resource_name'): + return json_api_meta.resource_name + elif hasattr(meta, 'resource_name'): + return meta.resource_name + elif hasattr(meta, 'model'): + return get_resource_type_from_model(meta.model) + raise AttributeError() def get_included_serializers(serializer):