diff --git a/example/models.py b/example/models.py index 7ded4ebf..d0dcf3e4 100644 --- a/example/models.py +++ b/example/models.py @@ -76,7 +76,7 @@ def __str__(self): @python_2_unicode_compatible class Comment(BaseModel): - entry = models.ForeignKey(Entry) + entry = models.ForeignKey(Entry, related_name='comments') body = models.TextField() author = models.ForeignKey( Author, diff --git a/example/serializers.py b/example/serializers.py index c5eb1a18..743d10e6 100644 --- a/example/serializers.py +++ b/example/serializers.py @@ -55,7 +55,7 @@ def __init__(self, *args, **kwargs): body_format = serializers.SerializerMethodField() # many related from model comments = relations.ResourceRelatedField( - source='comment_set', many=True, read_only=True) + many=True, read_only=True) # many related from serializer suggested = relations.SerializerMethodResourceRelatedField( source='get_suggested', model=Entry, many=True, read_only=True, diff --git a/example/tests/integration/test_includes.py b/example/tests/integration/test_includes.py index 607f48cc..622c8c13 100644 --- a/example/tests/integration/test_includes.py +++ b/example/tests/integration/test_includes.py @@ -18,7 +18,7 @@ def test_included_data_on_list(multiple_entries, client, query='?include=comment assert [x.get('type') for x in included] == ['comments', 'comments'], 'List included types are incorrect' comment_count = len([resource for resource in included if resource["type"] == "comments"]) - expected_comment_count = sum([entry.comment_set.count() for entry in multiple_entries]) + expected_comment_count = sum([entry.comments.count() for entry in multiple_entries]) assert comment_count == expected_comment_count, 'List comment count is incorrect' @@ -33,7 +33,7 @@ def test_included_data_on_detail(single_entry, client, query='?include=comments' assert [x.get('type') for x in included] == ['comments'], 'Detail included types are incorrect' comment_count = len([resource for resource in included if resource["type"] == "comments"]) - expected_comment_count = single_entry.comment_set.count() + expected_comment_count = single_entry.comments.count() assert comment_count == expected_comment_count, 'Detail comment count is incorrect' @@ -81,16 +81,16 @@ def test_deep_included_data_on_list(multiple_entries, client): ], 'List included types are incorrect' comment_count = len([resource for resource in included if resource["type"] == "comments"]) - expected_comment_count = sum([entry.comment_set.count() for entry in multiple_entries]) + expected_comment_count = sum([entry.comments.count() for entry in multiple_entries]) assert comment_count == expected_comment_count, 'List comment count is incorrect' author_count = len([resource for resource in included if resource["type"] == "authors"]) expected_author_count = sum( - [entry.comment_set.filter(author__isnull=False).count() for entry in multiple_entries]) + [entry.comments.filter(author__isnull=False).count() for entry in multiple_entries]) assert author_count == expected_author_count, 'List author count is incorrect' author_bio_count = len([resource for resource in included if resource["type"] == "authorBios"]) - expected_author_bio_count = sum([entry.comment_set.filter( + expected_author_bio_count = sum([entry.comments.filter( author__bio__isnull=False).count() for entry in multiple_entries]) assert author_bio_count == expected_author_bio_count, 'List author bio count is incorrect' @@ -107,7 +107,7 @@ def test_deep_included_data_on_list(multiple_entries, client): author_count = len([resource for resource in included if resource["type"] == "authors"]) expected_author_count = sum( [entry.authors.count() for entry in multiple_entries] + - [entry.comment_set.filter(author__isnull=False).count() for entry in multiple_entries]) + [entry.comments.filter(author__isnull=False).count() for entry in multiple_entries]) assert author_count == expected_author_count, 'List author count is incorrect' @@ -122,9 +122,9 @@ def test_deep_included_data_on_detail(single_entry, client): 'Detail included types are incorrect' comment_count = len([resource for resource in included if resource["type"] == "comments"]) - expected_comment_count = single_entry.comment_set.count() + expected_comment_count = single_entry.comments.count() assert comment_count == expected_comment_count, 'Detail comment count is incorrect' author_bio_count = len([resource for resource in included if resource["type"] == "authorBios"]) - expected_author_bio_count = single_entry.comment_set.filter(author__bio__isnull=False).count() + expected_author_bio_count = single_entry.comments.filter(author__bio__isnull=False).count() assert author_bio_count == expected_author_bio_count, 'Detail author bio count is incorrect' diff --git a/example/tests/test_relations.py b/example/tests/test_relations.py index adbf4984..dc252e7e 100644 --- a/example/tests/test_relations.py +++ b/example/tests/test_relations.py @@ -104,7 +104,7 @@ def test_deserialize_many_to_many_relation(self): author_pks = Author.objects.values_list('pk', flat=True) authors = [{'type': type_string, 'id': pk} for pk in author_pks] - serializer = EntryModelSerializer(data={'authors': authors, 'comment_set': []}) + serializer = EntryModelSerializer(data={'authors': authors, 'comments': []}) self.assertTrue(serializer.is_valid()) self.assertEqual(len(serializer.validated_data['authors']), Author.objects.count()) @@ -112,9 +112,9 @@ def test_deserialize_many_to_many_relation(self): self.assertIsInstance(author, Author) def test_read_only(self): - serializer = EntryModelSerializer(data={'authors': [], 'comment_set': [{'type': 'Comments', 'id': 2}]}) + serializer = EntryModelSerializer(data={'authors': [], 'comments': [{'type': 'Comments', 'id': 2}]}) serializer.is_valid(raise_exception=True) - self.assertNotIn('comment_set', serializer.validated_data) + self.assertNotIn('comments', serializer.validated_data) def test_invalid_resource_id_object(self): comment = {'body': 'testing 123', 'entry': {'type': 'entry'}, 'author': {'id': '5'}} @@ -136,8 +136,8 @@ class EntryFKSerializer(serializers.Serializer): class EntryModelSerializer(serializers.ModelSerializer): authors = ResourceRelatedField(many=True, queryset=Author.objects) - comment_set = ResourceRelatedField(many=True, read_only=True) + comments = ResourceRelatedField(many=True, read_only=True) class Meta: model = Entry - fields = ('authors', 'comment_set') + fields = ('authors', 'comments') diff --git a/example/tests/test_views.py b/example/tests/test_views.py index d802ef52..fcf35e99 100644 --- a/example/tests/test_views.py +++ b/example/tests/test_views.py @@ -158,7 +158,7 @@ def test_post_to_one_relationship_should_fail(self): assert response.status_code == 405, response.content.decode() def test_post_to_many_relationship_with_no_change(self): - url = '/entries/{}/relationships/comment_set'.format(self.first_entry.id) + url = '/entries/{}/relationships/comments'.format(self.first_entry.id) request_data = { 'data': [{'type': format_resource_type('Comment'), 'id': str(self.first_comment.id)}, ] } @@ -166,7 +166,7 @@ def test_post_to_many_relationship_with_no_change(self): assert response.status_code == 204, response.content.decode() def test_post_to_many_relationship_with_change(self): - url = '/entries/{}/relationships/comment_set'.format(self.first_entry.id) + url = '/entries/{}/relationships/comments'.format(self.first_entry.id) request_data = { 'data': [{'type': format_resource_type('Comment'), 'id': str(self.second_comment.id)}, ] } @@ -201,7 +201,7 @@ def test_delete_relationship_overriding_with_none(self): assert response.data['author'] == None def test_delete_to_many_relationship_with_no_change(self): - url = '/entries/{}/relationships/comment_set'.format(self.first_entry.id) + url = '/entries/{}/relationships/comments'.format(self.first_entry.id) request_data = { 'data': [{'type': format_resource_type('Comment'), 'id': str(self.second_comment.id)}, ] } @@ -209,7 +209,7 @@ def test_delete_to_many_relationship_with_no_change(self): assert response.status_code == 204, response.content.decode() def test_delete_one_to_many_relationship_with_not_null_constraint(self): - url = '/entries/{}/relationships/comment_set'.format(self.first_entry.id) + url = '/entries/{}/relationships/comments'.format(self.first_entry.id) request_data = { 'data': [{'type': format_resource_type('Comment'), 'id': str(self.first_comment.id)}, ] } diff --git a/example/tests/unit/test_renderers.py b/example/tests/unit/test_renderers.py index b3aeef78..eff616e1 100644 --- a/example/tests/unit/test_renderers.py +++ b/example/tests/unit/test_renderers.py @@ -16,7 +16,7 @@ class DummyTestSerializer(serializers.ModelSerializer): a single embedded relation ''' related_models = RelatedModelSerializer( - source='comment_set', many=True, read_only=True) + source='comments', many=True, read_only=True) class Meta: model = Entry diff --git a/example/views.py b/example/views.py index 988cda66..54330fc5 100644 --- a/example/views.py +++ b/example/views.py @@ -5,7 +5,7 @@ import rest_framework_json_api.metadata import rest_framework_json_api.parsers import rest_framework_json_api.renderers -from rest_framework_json_api.views import RelationshipView +from rest_framework_json_api.views import ModelViewSet, RelationshipView from example.models import Blog, Entry, Author, Comment from example.serializers import ( BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer) @@ -15,12 +15,12 @@ HTTP_422_UNPROCESSABLE_ENTITY = 422 -class BlogViewSet(viewsets.ModelViewSet): +class BlogViewSet(ModelViewSet): queryset = Blog.objects.all() serializer_class = BlogSerializer -class JsonApiViewSet(viewsets.ModelViewSet): +class JsonApiViewSet(ModelViewSet): """ This is an example on how to configure DRF-jsonapi from within a class. It allows using DRF-jsonapi alongside @@ -54,7 +54,7 @@ class BlogCustomViewSet(JsonApiViewSet): serializer_class = BlogSerializer -class EntryViewSet(viewsets.ModelViewSet): +class EntryViewSet(ModelViewSet): queryset = Entry.objects.all() resource_name = 'posts' @@ -62,12 +62,12 @@ def get_serializer_class(self): return EntrySerializer -class AuthorViewSet(viewsets.ModelViewSet): +class AuthorViewSet(ModelViewSet): queryset = Author.objects.all() serializer_class = AuthorSerializer -class CommentViewSet(viewsets.ModelViewSet): +class CommentViewSet(ModelViewSet): queryset = Comment.objects.all() serializer_class = CommentSerializer diff --git a/rest_framework_json_api/views.py b/rest_framework_json_api/views.py index ee368851..ae75035e 100644 --- a/rest_framework_json_api/views.py +++ b/rest_framework_json_api/views.py @@ -6,13 +6,17 @@ from django.db.models.manager import Manager if django.VERSION < (1, 9): from django.db.models.fields.related import ( - ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor, + ForeignRelatedObjectsDescriptor as ReverseManyToOneDescriptor, ManyRelatedObjectsDescriptor as ManyToManyDescriptor, + ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor, + SingleRelatedObjectDescriptor as ReverseOneToOneDescriptor, ) else: from django.db.models.fields.related_descriptors import ( ForwardManyToOneDescriptor, ManyToManyDescriptor, + ReverseManyToOneDescriptor, + ReverseOneToOneDescriptor, ) from rest_framework import generics, viewsets from rest_framework.response import Response @@ -32,7 +36,7 @@ class ModelViewSet(viewsets.ModelViewSet): def get_queryset(self, *args, **kwargs): - qs = super().get_queryset(*args, **kwargs) + qs = super(ModelViewSet, self).get_queryset(*args, **kwargs) included_resources = get_included_resources(self.request) for included in included_resources: @@ -44,16 +48,30 @@ def get_queryset(self, *args, **kwargs): break field = getattr(level_model, level) field_class = field.__class__ - if not ( + + is_forward_relation = ( issubclass(field_class, ForwardManyToOneDescriptor) or issubclass(field_class, ManyToManyDescriptor) - ): + ) + is_reverse_relation = ( + issubclass(field_class, ReverseManyToOneDescriptor) + or issubclass(field_class, ReverseOneToOneDescriptor) + ) + if not (is_forward_relation or is_reverse_relation): break if level == levels[-1]: included_model = field else: - level_model = field.get_queryset().model + if django.VERSION < (1, 9): + model_field = field.related + else: + model_field = field.field + + if is_forward_relation: + level_model = model_field.related_model + else: + level_model = model_field.model if included_model is not None: qs = qs.prefetch_related(included.replace('.', '__'))