diff --git a/CHANGELOG.md b/CHANGELOG.md index 32799366..c62b5383 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ any parts of the framework not mentioned in the documentation should generally b OrderViewSet.as_view({'get': 'retrieve_related'}), name='order-related'), ``` +* Ensure default `included_resources` are considered when calculating prefetches. ### Deprecated diff --git a/example/tests/test_performance.py b/example/tests/test_performance.py index e42afada..cae11ed4 100644 --- a/example/tests/test_performance.py +++ b/example/tests/test_performance.py @@ -1,7 +1,7 @@ from django.utils import timezone from rest_framework.test import APITestCase -from example.factories import CommentFactory +from example.factories import CommentFactory, EntryFactory from example.models import Author, Blog, Comment, Entry @@ -36,6 +36,7 @@ def setUp(self): ) self.comment = Comment.objects.create(entry=self.first_entry) CommentFactory.create_batch(50) + EntryFactory.create_batch(50) def test_query_count_no_includes(self): """We expect a simple list view to issue only two queries. @@ -49,7 +50,7 @@ def test_query_count_no_includes(self): self.assertEqual(len(response.data["results"]), 25) def test_query_count_include_author(self): - """We expect a list view with an include have three queries: + """We expect a list view with an include have five queries: 1. Primary resource COUNT query 2. Primary resource SELECT @@ -70,3 +71,16 @@ def test_query_select_related_entry(self): with self.assertNumQueries(2): response = self.client.get("/comments?include=writer&page[size]=25") self.assertEqual(len(response.data["results"]), 25) + + def test_query_prefetch_uses_included_resources(self): + """We expect a list view with `included_resources` to have three queries: + + 1. Primary resource COUNT query + 2. Primary resource SELECT + 3. Comments prefetched + """ + with self.assertNumQueries(3): + response = self.client.get( + "/entries?fields[entries]=comments&page[size]=25" + ) + self.assertEqual(len(response.data["results"]), 25) diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py index a73a5d47..5ca773d0 100644 --- a/rest_framework_json_api/serializers.py +++ b/rest_framework_json_api/serializers.py @@ -137,7 +137,7 @@ def validate_path(serializer_class, field_path, path): validate_path(this_included_serializer, new_included_field_path, path) if request and view: - included_resources = get_included_resources(request) + included_resources = get_included_resources(request, self) for included_field_name in included_resources: included_field_path = included_field_name.split(".") if "related_field" in view.kwargs: diff --git a/rest_framework_json_api/views.py b/rest_framework_json_api/views.py index 84ec509e..8369cec9 100644 --- a/rest_framework_json_api/views.py +++ b/rest_framework_json_api/views.py @@ -63,7 +63,9 @@ def get_prefetch_related(self, include): def get_queryset(self, *args, **kwargs): qs = super(PreloadIncludesMixin, self).get_queryset(*args, **kwargs) - included_resources = get_included_resources(self.request) + included_resources = get_included_resources( + self.request, self.get_serializer_class() + ) for included in included_resources + ["__all__"]: select_related = self.get_select_related(included) @@ -82,7 +84,9 @@ def get_queryset(self, *args, **kwargs): """ This mixin adds automatic prefetching for OneToOne and ManyToMany fields. """ qs = super(AutoPrefetchMixin, self).get_queryset(*args, **kwargs) - included_resources = get_included_resources(self.request) + included_resources = get_included_resources( + self.request, self.get_serializer_class() + ) for included in included_resources + ["__all__"]: # If include was not defined, trying to resolve it automatically