Skip to content

Fix/auto prefetch with m2m #333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __str__(self):

@python_2_unicode_compatible
class Comment(BaseModel):
entry = models.ForeignKey(Entry)
entry = models.ForeignKey(Entry, related_name='comments')
body = models.TextField()
author = models.ForeignKey(
Author,
Expand Down
2 changes: 1 addition & 1 deletion example/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, *args, **kwargs):
body_format = serializers.SerializerMethodField()
# many related from model
comments = relations.ResourceRelatedField(
source='comment_set', many=True, read_only=True)
many=True, read_only=True)
# many related from serializer
suggested = relations.SerializerMethodResourceRelatedField(
source='get_suggested', model=Entry, many=True, read_only=True,
Expand Down
16 changes: 8 additions & 8 deletions example/tests/integration/test_includes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_included_data_on_list(multiple_entries, client, query='?include=comment
assert [x.get('type') for x in included] == ['comments', 'comments'], 'List included types are incorrect'

comment_count = len([resource for resource in included if resource["type"] == "comments"])
expected_comment_count = sum([entry.comment_set.count() for entry in multiple_entries])
expected_comment_count = sum([entry.comments.count() for entry in multiple_entries])
assert comment_count == expected_comment_count, 'List comment count is incorrect'


Expand All @@ -33,7 +33,7 @@ def test_included_data_on_detail(single_entry, client, query='?include=comments'
assert [x.get('type') for x in included] == ['comments'], 'Detail included types are incorrect'

comment_count = len([resource for resource in included if resource["type"] == "comments"])
expected_comment_count = single_entry.comment_set.count()
expected_comment_count = single_entry.comments.count()
assert comment_count == expected_comment_count, 'Detail comment count is incorrect'


Expand Down Expand Up @@ -81,16 +81,16 @@ def test_deep_included_data_on_list(multiple_entries, client):
], 'List included types are incorrect'

comment_count = len([resource for resource in included if resource["type"] == "comments"])
expected_comment_count = sum([entry.comment_set.count() for entry in multiple_entries])
expected_comment_count = sum([entry.comments.count() for entry in multiple_entries])
assert comment_count == expected_comment_count, 'List comment count is incorrect'

author_count = len([resource for resource in included if resource["type"] == "authors"])
expected_author_count = sum(
[entry.comment_set.filter(author__isnull=False).count() for entry in multiple_entries])
[entry.comments.filter(author__isnull=False).count() for entry in multiple_entries])
assert author_count == expected_author_count, 'List author count is incorrect'

author_bio_count = len([resource for resource in included if resource["type"] == "authorBios"])
expected_author_bio_count = sum([entry.comment_set.filter(
expected_author_bio_count = sum([entry.comments.filter(
author__bio__isnull=False).count() for entry in multiple_entries])
assert author_bio_count == expected_author_bio_count, 'List author bio count is incorrect'

Expand All @@ -107,7 +107,7 @@ def test_deep_included_data_on_list(multiple_entries, client):
author_count = len([resource for resource in included if resource["type"] == "authors"])
expected_author_count = sum(
[entry.authors.count() for entry in multiple_entries] +
[entry.comment_set.filter(author__isnull=False).count() for entry in multiple_entries])
[entry.comments.filter(author__isnull=False).count() for entry in multiple_entries])
assert author_count == expected_author_count, 'List author count is incorrect'


Expand All @@ -122,9 +122,9 @@ def test_deep_included_data_on_detail(single_entry, client):
'Detail included types are incorrect'

comment_count = len([resource for resource in included if resource["type"] == "comments"])
expected_comment_count = single_entry.comment_set.count()
expected_comment_count = single_entry.comments.count()
assert comment_count == expected_comment_count, 'Detail comment count is incorrect'

author_bio_count = len([resource for resource in included if resource["type"] == "authorBios"])
expected_author_bio_count = single_entry.comment_set.filter(author__bio__isnull=False).count()
expected_author_bio_count = single_entry.comments.filter(author__bio__isnull=False).count()
assert author_bio_count == expected_author_bio_count, 'Detail author bio count is incorrect'
10 changes: 5 additions & 5 deletions example/tests/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,17 @@ def test_deserialize_many_to_many_relation(self):
author_pks = Author.objects.values_list('pk', flat=True)
authors = [{'type': type_string, 'id': pk} for pk in author_pks]

serializer = EntryModelSerializer(data={'authors': authors, 'comment_set': []})
serializer = EntryModelSerializer(data={'authors': authors, 'comments': []})

self.assertTrue(serializer.is_valid())
self.assertEqual(len(serializer.validated_data['authors']), Author.objects.count())
for author in serializer.validated_data['authors']:
self.assertIsInstance(author, Author)

def test_read_only(self):
serializer = EntryModelSerializer(data={'authors': [], 'comment_set': [{'type': 'Comments', 'id': 2}]})
serializer = EntryModelSerializer(data={'authors': [], 'comments': [{'type': 'Comments', 'id': 2}]})
serializer.is_valid(raise_exception=True)
self.assertNotIn('comment_set', serializer.validated_data)
self.assertNotIn('comments', serializer.validated_data)

def test_invalid_resource_id_object(self):
comment = {'body': 'testing 123', 'entry': {'type': 'entry'}, 'author': {'id': '5'}}
Expand All @@ -136,8 +136,8 @@ class EntryFKSerializer(serializers.Serializer):

class EntryModelSerializer(serializers.ModelSerializer):
authors = ResourceRelatedField(many=True, queryset=Author.objects)
comment_set = ResourceRelatedField(many=True, read_only=True)
comments = ResourceRelatedField(many=True, read_only=True)

class Meta:
model = Entry
fields = ('authors', 'comment_set')
fields = ('authors', 'comments')
8 changes: 4 additions & 4 deletions example/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,15 @@ def test_post_to_one_relationship_should_fail(self):
assert response.status_code == 405, response.content.decode()

def test_post_to_many_relationship_with_no_change(self):
url = '/entries/{}/relationships/comment_set'.format(self.first_entry.id)
url = '/entries/{}/relationships/comments'.format(self.first_entry.id)
request_data = {
'data': [{'type': format_resource_type('Comment'), 'id': str(self.first_comment.id)}, ]
}
response = self.client.post(url, data=json.dumps(request_data), content_type='application/vnd.api+json')
assert response.status_code == 204, response.content.decode()

def test_post_to_many_relationship_with_change(self):
url = '/entries/{}/relationships/comment_set'.format(self.first_entry.id)
url = '/entries/{}/relationships/comments'.format(self.first_entry.id)
request_data = {
'data': [{'type': format_resource_type('Comment'), 'id': str(self.second_comment.id)}, ]
}
Expand Down Expand Up @@ -201,15 +201,15 @@ def test_delete_relationship_overriding_with_none(self):
assert response.data['author'] == None

def test_delete_to_many_relationship_with_no_change(self):
url = '/entries/{}/relationships/comment_set'.format(self.first_entry.id)
url = '/entries/{}/relationships/comments'.format(self.first_entry.id)
request_data = {
'data': [{'type': format_resource_type('Comment'), 'id': str(self.second_comment.id)}, ]
}
response = self.client.delete(url, data=json.dumps(request_data), content_type='application/vnd.api+json')
assert response.status_code == 204, response.content.decode()

def test_delete_one_to_many_relationship_with_not_null_constraint(self):
url = '/entries/{}/relationships/comment_set'.format(self.first_entry.id)
url = '/entries/{}/relationships/comments'.format(self.first_entry.id)
request_data = {
'data': [{'type': format_resource_type('Comment'), 'id': str(self.first_comment.id)}, ]
}
Expand Down
2 changes: 1 addition & 1 deletion example/tests/unit/test_renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DummyTestSerializer(serializers.ModelSerializer):
a single embedded relation
'''
related_models = RelatedModelSerializer(
source='comment_set', many=True, read_only=True)
source='comments', many=True, read_only=True)

class Meta:
model = Entry
Expand Down
12 changes: 6 additions & 6 deletions example/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import rest_framework_json_api.metadata
import rest_framework_json_api.parsers
import rest_framework_json_api.renderers
from rest_framework_json_api.views import RelationshipView
from rest_framework_json_api.views import ModelViewSet, RelationshipView
from example.models import Blog, Entry, Author, Comment
from example.serializers import (
BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer)
Expand All @@ -15,12 +15,12 @@
HTTP_422_UNPROCESSABLE_ENTITY = 422


class BlogViewSet(viewsets.ModelViewSet):
class BlogViewSet(ModelViewSet):
queryset = Blog.objects.all()
serializer_class = BlogSerializer


class JsonApiViewSet(viewsets.ModelViewSet):
class JsonApiViewSet(ModelViewSet):
"""
This is an example on how to configure DRF-jsonapi from
within a class. It allows using DRF-jsonapi alongside
Expand Down Expand Up @@ -54,20 +54,20 @@ class BlogCustomViewSet(JsonApiViewSet):
serializer_class = BlogSerializer


class EntryViewSet(viewsets.ModelViewSet):
class EntryViewSet(ModelViewSet):
queryset = Entry.objects.all()
resource_name = 'posts'

def get_serializer_class(self):
return EntrySerializer


class AuthorViewSet(viewsets.ModelViewSet):
class AuthorViewSet(ModelViewSet):
queryset = Author.objects.all()
serializer_class = AuthorSerializer


class CommentViewSet(viewsets.ModelViewSet):
class CommentViewSet(ModelViewSet):
queryset = Comment.objects.all()
serializer_class = CommentSerializer

Expand Down
28 changes: 23 additions & 5 deletions rest_framework_json_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
from django.db.models.manager import Manager
if django.VERSION < (1, 9):
from django.db.models.fields.related import (
ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor,
ForeignRelatedObjectsDescriptor as ReverseManyToOneDescriptor,
ManyRelatedObjectsDescriptor as ManyToManyDescriptor,
ReverseSingleRelatedObjectDescriptor as ForwardManyToOneDescriptor,
SingleRelatedObjectDescriptor as ReverseOneToOneDescriptor,
)
else:
from django.db.models.fields.related_descriptors import (
ForwardManyToOneDescriptor,
ManyToManyDescriptor,
ReverseManyToOneDescriptor,
ReverseOneToOneDescriptor,
)
from rest_framework import generics, viewsets
from rest_framework.response import Response
Expand All @@ -32,7 +36,7 @@

class ModelViewSet(viewsets.ModelViewSet):
def get_queryset(self, *args, **kwargs):
qs = super().get_queryset(*args, **kwargs)
qs = super(ModelViewSet, self).get_queryset(*args, **kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, nice! Thanks for this one.

included_resources = get_included_resources(self.request)

for included in included_resources:
Expand All @@ -44,16 +48,30 @@ def get_queryset(self, *args, **kwargs):
break
field = getattr(level_model, level)
field_class = field.__class__
if not (

is_forward_relation = (
issubclass(field_class, ForwardManyToOneDescriptor)
or issubclass(field_class, ManyToManyDescriptor)
):
)
is_reverse_relation = (
issubclass(field_class, ReverseManyToOneDescriptor)
or issubclass(field_class, ReverseOneToOneDescriptor)
)
if not (is_forward_relation or is_reverse_relation):
break

if level == levels[-1]:
included_model = field
else:
level_model = field.get_queryset().model
if django.VERSION < (1, 9):
model_field = field.related
else:
model_field = field.field

if is_forward_relation:
level_model = model_field.related_model
else:
level_model = model_field.model

if included_model is not None:
qs = qs.prefetch_related(included.replace('.', '__'))
Expand Down