From 1d84c80d8c5b5e1d5d13db814daa3ed356c198bd Mon Sep 17 00:00:00 2001 From: Aidan Lister Date: Thu, 20 Jul 2017 22:29:33 +1000 Subject: [PATCH] Add a helper for specifying a prefetch_for_related attribute on your viewset to help with prefetching includes. --- rest_framework_json_api/views.py | 40 +++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/rest_framework_json_api/views.py b/rest_framework_json_api/views.py index b8e81e6a..c99ddad4 100644 --- a/rest_framework_json_api/views.py +++ b/rest_framework_json_api/views.py @@ -8,7 +8,6 @@ from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.serializers import Serializer - from rest_framework_json_api.exceptions import Conflict from rest_framework_json_api.serializers import ResourceIdentifierObjectSerializer from rest_framework_json_api.utils import ( @@ -39,9 +38,40 @@ ) -class ModelViewSet(viewsets.ModelViewSet): +class PrefetchForIncludesHelperMixin(object): + def get_queryset(self): + """ This viewset provides a helper attribute to prefetch related models + based on the include specified in the URL. + + __all__ can be used to specify a prefetch which should be done regardless of the include + + @example + # When MyViewSet is called with ?include=author it will prefetch author and authorbio + class MyViewSet(viewsets.ModelViewSet): + queryset = Book.objects.all() + prefetch_for_includes = { + '__all__': [], + 'author': ['author', 'author__authorbio'] + 'category.section': ['category'] + } + """ + qs = super(PrefetchForIncludesHelperMixin, self).get_queryset() + if not hasattr(self, 'prefetch_for_includes'): + return qs + + includes = self.request.GET.get('include', '').split(',') + for inc in includes + ['__all__']: + prefetches = self.prefetch_for_includes.get(inc) + if prefetches: + qs = qs.prefetch_related(*prefetches) + + return qs + + +class AutoPrefetchMixin(object): def get_queryset(self, *args, **kwargs): - qs = super(ModelViewSet, self).get_queryset(*args, **kwargs) + """ This mixin adds automatic prefetching for OneToOne and ManyToMany fields. """ + qs = super(AutoPrefetchMixin, self).get_queryset(*args, **kwargs) included_resources = get_included_resources(self.request) for included in included_resources: @@ -84,6 +114,10 @@ def get_queryset(self, *args, **kwargs): return qs +class ModelViewSet(AutoPrefetchMixin, PrefetchForIncludesHelperMixin, viewsets.ModelViewSet): + pass + + class RelationshipView(generics.GenericAPIView): serializer_class = ResourceIdentifierObjectSerializer self_link_view_name = None