Skip to content

Commit 1d84c80

Browse files
committed
Add a helper for specifying a prefetch_for_related attribute on your viewset to help with prefetching includes.
1 parent 1659d8c commit 1d84c80

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

rest_framework_json_api/views.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from rest_framework.response import Response
99
from rest_framework.reverse import reverse
1010
from rest_framework.serializers import Serializer
11-
1211
from rest_framework_json_api.exceptions import Conflict
1312
from rest_framework_json_api.serializers import ResourceIdentifierObjectSerializer
1413
from rest_framework_json_api.utils import (
@@ -39,9 +38,40 @@
3938
)
4039

4140

42-
class ModelViewSet(viewsets.ModelViewSet):
41+
class PrefetchForIncludesHelperMixin(object):
42+
def get_queryset(self):
43+
""" This viewset provides a helper attribute to prefetch related models
44+
based on the include specified in the URL.
45+
46+
__all__ can be used to specify a prefetch which should be done regardless of the include
47+
48+
@example
49+
# When MyViewSet is called with ?include=author it will prefetch author and authorbio
50+
class MyViewSet(viewsets.ModelViewSet):
51+
queryset = Book.objects.all()
52+
prefetch_for_includes = {
53+
'__all__': [],
54+
'author': ['author', 'author__authorbio']
55+
'category.section': ['category']
56+
}
57+
"""
58+
qs = super(PrefetchForIncludesHelperMixin, self).get_queryset()
59+
if not hasattr(self, 'prefetch_for_includes'):
60+
return qs
61+
62+
includes = self.request.GET.get('include', '').split(',')
63+
for inc in includes + ['__all__']:
64+
prefetches = self.prefetch_for_includes.get(inc)
65+
if prefetches:
66+
qs = qs.prefetch_related(*prefetches)
67+
68+
return qs
69+
70+
71+
class AutoPrefetchMixin(object):
4372
def get_queryset(self, *args, **kwargs):
44-
qs = super(ModelViewSet, self).get_queryset(*args, **kwargs)
73+
""" This mixin adds automatic prefetching for OneToOne and ManyToMany fields. """
74+
qs = super(AutoPrefetchMixin, self).get_queryset(*args, **kwargs)
4575
included_resources = get_included_resources(self.request)
4676

4777
for included in included_resources:
@@ -84,6 +114,10 @@ def get_queryset(self, *args, **kwargs):
84114
return qs
85115

86116

117+
class ModelViewSet(AutoPrefetchMixin, PrefetchForIncludesHelperMixin, viewsets.ModelViewSet):
118+
pass
119+
120+
87121
class RelationshipView(generics.GenericAPIView):
88122
serializer_class = ResourceIdentifierObjectSerializer
89123
self_link_view_name = None

0 commit comments

Comments
 (0)