|
8 | 8 | from rest_framework.response import Response
|
9 | 9 | from rest_framework.reverse import reverse
|
10 | 10 | from rest_framework.serializers import Serializer
|
11 |
| - |
12 | 11 | from rest_framework_json_api.exceptions import Conflict
|
13 | 12 | from rest_framework_json_api.serializers import ResourceIdentifierObjectSerializer
|
14 | 13 | from rest_framework_json_api.utils import (
|
|
39 | 38 | )
|
40 | 39 |
|
41 | 40 |
|
42 |
| -class ModelViewSet(viewsets.ModelViewSet): |
| 41 | +class PrefetchForIncludesHelperMixin(object): |
| 42 | + def get_queryset(self): |
| 43 | + """ This viewset provides a helper attribute to prefetch related models |
| 44 | + based on the include specified in the URL. |
| 45 | +
|
| 46 | + __all__ can be used to specify a prefetch which should be done regardless of the include |
| 47 | +
|
| 48 | + @example |
| 49 | + # When MyViewSet is called with ?include=author it will prefetch author and authorbio |
| 50 | + class MyViewSet(viewsets.ModelViewSet): |
| 51 | + queryset = Book.objects.all() |
| 52 | + prefetch_for_includes = { |
| 53 | + '__all__': [], |
| 54 | + 'author': ['author', 'author__authorbio'] |
| 55 | + 'category.section': ['category'] |
| 56 | + } |
| 57 | + """ |
| 58 | + qs = super(PrefetchForIncludesHelperMixin, self).get_queryset() |
| 59 | + if not hasattr(self, 'prefetch_for_includes'): |
| 60 | + return qs |
| 61 | + |
| 62 | + includes = self.request.GET.get('include', '').split(',') |
| 63 | + for inc in includes + ['__all__']: |
| 64 | + prefetches = self.prefetch_for_includes.get(inc) |
| 65 | + if prefetches: |
| 66 | + qs = qs.prefetch_related(*prefetches) |
| 67 | + |
| 68 | + return qs |
| 69 | + |
| 70 | + |
| 71 | +class AutoPrefetchMixin(object): |
43 | 72 | def get_queryset(self, *args, **kwargs):
|
44 |
| - qs = super(ModelViewSet, self).get_queryset(*args, **kwargs) |
| 73 | + """ This mixin adds automatic prefetching for OneToOne and ManyToMany fields. """ |
| 74 | + qs = super(AutoPrefetchMixin, self).get_queryset(*args, **kwargs) |
45 | 75 | included_resources = get_included_resources(self.request)
|
46 | 76 |
|
47 | 77 | for included in included_resources:
|
@@ -84,6 +114,10 @@ def get_queryset(self, *args, **kwargs):
|
84 | 114 | return qs
|
85 | 115 |
|
86 | 116 |
|
| 117 | +class ModelViewSet(AutoPrefetchMixin, PrefetchForIncludesHelperMixin, viewsets.ModelViewSet): |
| 118 | + pass |
| 119 | + |
| 120 | + |
87 | 121 | class RelationshipView(generics.GenericAPIView):
|
88 | 122 | serializer_class = ResourceIdentifierObjectSerializer
|
89 | 123 | self_link_view_name = None
|
|
0 commit comments