Skip to content

Commit ae3c12f

Browse files
committed
Add nested prefetching & ondemand logic
1 parent d756fd7 commit ae3c12f

File tree

5 files changed

+253
-57
lines changed

5 files changed

+253
-57
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ any parts of the framework not mentioned in the documentation should generally b
1313
### Changed
1414

1515
* Moved resolving of `included_serialzers` and `related_serializers` classes to serializer's meta class.
16+
* `AutoPrefetchMixin` updated to be more clever about how relationships are prefetched, with recursion all the way down.
17+
* Expensive reverse relations are now automatically excluded from queries that don't explicitly name them in sparsefieldsets. Set `INCLUDE_EXPENSVE_FIELDS` to revert to old behaviour.
1618
* Removed `PreloadIncludesMixin`, as the logic did not work when nesting includes, and the laborious effort needed in its manual config was unnecessary.
1719

1820
### Deprecated

rest_framework_json_api/serializers.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,15 @@
2323
from rest_framework_json_api.exceptions import Conflict
2424
from rest_framework_json_api.relations import ResourceRelatedField
2525
from rest_framework_json_api.utils import (
26+
get_expensive_relational_fields,
2627
get_included_resources,
2728
get_resource_type_from_instance,
2829
get_resource_type_from_model,
2930
get_resource_type_from_serializer,
3031
)
3132

33+
from .settings import json_api_settings
34+
3235

3336
class ResourceIdentifierObjectSerializer(BaseSerializer):
3437
default_error_messages = {
@@ -153,6 +156,43 @@ def validate_path(serializer_class, field_path, path):
153156
super(IncludedResourcesValidationMixin, self).__init__(*args, **kwargs)
154157

155158

159+
class OnDemandFieldsMixin:
160+
"""
161+
Automatically certain fields from the serializer that have been deemed expensive.
162+
In order to see these fields, the client must explcitly request them.
163+
"""
164+
165+
def __init__(self, *args, **kwargs):
166+
super().__init__(*args, **kwargs)
167+
168+
# Pop any fields off the serializer that shouldn't come through.
169+
for field in self.get_excluded_ondemand_fields():
170+
self.fields.pop(field, None)
171+
172+
def get_excluded_ondemand_fields(self) -> list[str]:
173+
"""
174+
Determine which fields should be popped off if not explicitly asked for.
175+
Will not nominate any fields that have been designated as `demanded_fields` in context.
176+
Ondemand fields are determined in like so:
177+
- Fields that we automatically determine to be expensive, and thus automatically remove
178+
from the default offering. Currently such fields are M2Ms and reverse FKs.
179+
"""
180+
if json_api_settings.INCLUDE_EXPENSVE_FIELDS:
181+
return set()
182+
183+
# If we've instantiated the serializer ourselves, we'll have fed `demanded_fields` into its context.
184+
# If it's happened as part of drf render internals, then we have a fallback where the view
185+
# has provided the entire sparsefields context for us to pick through.
186+
if 'demanded_fields' in self.context:
187+
demanded_fields = set(self.context.get('demanded_fields'))
188+
else:
189+
resource_name = get_resource_type_from_serializer(type(self))
190+
demanded_fields = set(self.context.get('all_sparsefields', {}).get(resource_name, []))
191+
192+
# We only want to exclude those ondemand fields that haven't been explicitly requested.
193+
return set(get_expensive_relational_fields(type(self))) - set(demanded_fields)
194+
195+
156196
class LazySerializersDict(Mapping):
157197
"""
158198
A dictionary of serializers which lazily import dotted class path and self.
@@ -207,6 +247,7 @@ def __new__(cls, name, bases, attrs):
207247
# If user imports serializer from here we can catch class definition and check
208248
# nested serializers for depricated use.
209249
class Serializer(
250+
OnDemandFieldsMixin,
210251
IncludedResourcesValidationMixin,
211252
SparseFieldsetsMixin,
212253
Serializer,
@@ -230,6 +271,7 @@ class Serializer(
230271

231272

232273
class HyperlinkedModelSerializer(
274+
OnDemandFieldsMixin,
233275
IncludedResourcesValidationMixin,
234276
SparseFieldsetsMixin,
235277
HyperlinkedModelSerializer,
@@ -250,6 +292,7 @@ class HyperlinkedModelSerializer(
250292

251293

252294
class ModelSerializer(
295+
OnDemandFieldsMixin,
253296
IncludedResourcesValidationMixin,
254297
SparseFieldsetsMixin,
255298
ModelSerializer,

rest_framework_json_api/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"FORMAT_RELATED_LINKS": False,
1616
"PLURALIZE_TYPES": False,
1717
"UNIFORM_EXCEPTIONS": False,
18+
"INCLUDE_EXPENSVE_FIELDS": False,
1819
}
1920

2021

rest_framework_json_api/utils.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
import inspect
2+
import logging
23
import operator
34
import warnings
45
from collections import OrderedDict
56

67
import inflection
78
from django.conf import settings
8-
from django.db.models import Manager
9+
from django.db.models import Manager, Prefetch
910
from django.db.models.fields.related_descriptors import (
11+
ForwardManyToOneDescriptor,
1012
ManyToManyDescriptor,
1113
ReverseManyToOneDescriptor,
14+
ReverseOneToOneDescriptor,
1215
)
16+
from django.db.models.query import QuerySet
1317
from django.http import Http404
1418
from django.utils import encoding
1519
from django.utils.translation import gettext_lazy as _
1620
from rest_framework import exceptions
1721
from rest_framework.exceptions import APIException
22+
from rest_framework.relations import RelatedField
23+
from rest_framework.request import Request
24+
25+
from rest_framework_json_api.serializers import ModelSerializer, ValidationError
1826

1927
from .settings import json_api_settings
2028

29+
logger = logging.getLogger(__name__)
30+
2131
# Generic relation descriptor from django.contrib.contenttypes.
2232
if "django.contrib.contenttypes" not in settings.INSTALLED_APPS: # pragma: no cover
2333
# Target application does not use contenttypes. Importing would cause errors.
@@ -472,3 +482,144 @@ def format_errors(data):
472482
if len(data) > 1 and isinstance(data, list):
473483
data.sort(key=lambda x: x.get("source", {}).get("pointer", ""))
474484
return {"errors": data}
485+
486+
487+
def get_expensive_relational_fields(serializer_class: ModelSerializer) -> list[str]:
488+
"""
489+
We define 'expensive' as relational fields on the serializer that don't correspond to a
490+
forward relation on the model.
491+
"""
492+
return [
493+
field
494+
for field in getattr(serializer_class, 'included_serializers', {})
495+
if not isinstance(getattr(serializer_class.Meta.model, field, None), ForwardManyToOneDescriptor)
496+
]
497+
498+
499+
def get_cheap_relational_fields(serializer_class: ModelSerializer) -> list[str]:
500+
"""
501+
We define 'cheap' as relational fields on the serializer that _do_ correspond to a
502+
forward relation on the model.
503+
"""
504+
return [
505+
field
506+
for field in getattr(serializer_class, 'included_serializers', {})
507+
if isinstance(getattr(serializer_class.Meta.model, field, None), ForwardManyToOneDescriptor)
508+
]
509+
510+
511+
def get_queryset_for_field(field: RelatedField) -> QuerySet:
512+
model_field_descriptor = getattr(field.parent.Meta.model, field.field_name)
513+
# NOTE: Important to check in this order, as some of these classes are ancestors of one
514+
# another (ie `ManyToManyDescriptor` subclasses `ReverseManyToOneDescriptor`)
515+
if isinstance(model_field_descriptor, ForwardManyToOneDescriptor):
516+
if (qs := field.queryset) is None:
517+
qs = model_field_descriptor.field.related_model._default_manager
518+
elif isinstance(model_field_descriptor, ManyToManyDescriptor):
519+
qs = field.child_relation.queryset
520+
elif isinstance(model_field_descriptor, ReverseManyToOneDescriptor):
521+
if (qs := field.child_relation.queryset) is None:
522+
qs = model_field_descriptor.field.model._default_manager
523+
elif isinstance(model_field_descriptor, ReverseOneToOneDescriptor):
524+
qs = model_field_descriptor.get_queryset()
525+
526+
# Note: We call `.all()` before returning, as `_default_manager` may on occasion return a Manager
527+
# instance rather than a QuerySet, and we strictly want to be working with the latter.
528+
# (_default_manager is being used both direclty by us here, and by drf behind the scenes)
529+
# See: https://github.com/encode/django-rest-framework/blame/master/rest_framework/utils/field_mapping.py#L243
530+
return qs.all()
531+
532+
533+
def add_nested_prefetches_to_qs(
534+
serializer_class: ModelSerializer,
535+
qs: QuerySet,
536+
request: Request,
537+
sparsefields: dict[str, list[str]],
538+
includes: dict, # TODO: Define typing as recursive once supported.
539+
select_related: str = '',
540+
) -> QuerySet:
541+
"""
542+
Prefetch all required data onto the supplied queryset, calling this method recursively for child
543+
serializers where needed.
544+
There is some added built-in optimisation here, attempting to opt for select_related calls over
545+
prefetches where possible -- it's only possible if the child serializers are interested
546+
exclusively in select_relating also. This is controlled with the `select_related` param.
547+
If `select_related` comes through, will attempt to instead build further onto this and return
548+
a dundered list of strings for the caller to use in a select_related call. If that fails,
549+
returns a qs as normal.
550+
"""
551+
# Determine fields that'll be returned by this serializer.
552+
resource_name = get_resource_type_from_serializer(serializer_class)
553+
logger.debug(f'ADDING NESTED PREFETCHES FOR: {resource_name}')
554+
dummy_serializer = serializer_class(context={'request': request, 'demanded_fields': sparsefields.get(resource_name, [])})
555+
requested_fields = dummy_serializer.fields.keys()
556+
557+
# Ensure any requested includes are in the fields list, else error loudly!
558+
if not includes.keys() <= requested_fields:
559+
errors = {f'{resource_name}.{field}': 'Field marked as include but not requested for serialization.' for field in includes.keys() - requested_fields}
560+
raise ValidationError(errors)
561+
562+
included_serializers = get_included_serializers(serializer_class)
563+
564+
# Iterate over all expensive relations and prefetch_related where needed.
565+
for field in get_expensive_relational_fields(serializer_class):
566+
if field in requested_fields:
567+
logger.debug(f'EXPENSIVE_FIELD: {field}')
568+
select_related = '' # wipe, cannot be used. :(
569+
if not hasattr(qs.model, field):
570+
# We might fall into here if, for example, there's an expensive
571+
# SerializerMethodResourceRelatedField defined.
572+
continue
573+
if field in includes:
574+
logger.debug('- PREFETCHING DEEP')
575+
# Prefetch and recurse.
576+
child_serializer_class = included_serializers[field]
577+
prefetch_qs = add_nested_prefetches_to_qs(
578+
child_serializer_class,
579+
get_queryset_for_field(dummy_serializer.fields[field]),
580+
request=request,
581+
sparsefields=sparsefields,
582+
includes=includes[field],
583+
)
584+
qs = qs.prefetch_related(Prefetch(field, prefetch_qs))
585+
else:
586+
logger.debug('- PREFETCHING SHALLOW')
587+
# Prefetch "shallowly"; we only care about ids.
588+
qs = qs.prefetch_related(field) # TODO: Still use ResourceRelatedField.qs if present!
589+
590+
# Iterate over all cheap (forward) relations and select_related (or prefetch) where needed.
591+
new_select_related = [select_related]
592+
for field in get_cheap_relational_fields(serializer_class):
593+
if field in requested_fields:
594+
logger.debug(f'CHEAP_FIELD: {field}')
595+
if field in includes:
596+
logger.debug('- present in includes')
597+
# Recurse and see if we get a prefetch qs back, or a select_related string.
598+
child_serializer_class = included_serializers[field]
599+
prefetch_qs_or_select_related_str = add_nested_prefetches_to_qs(
600+
child_serializer_class,
601+
get_queryset_for_field(dummy_serializer.fields[field]),
602+
request=request,
603+
sparsefields=sparsefields,
604+
includes=includes[field],
605+
select_related=field,
606+
)
607+
if isinstance(prefetch_qs_or_select_related_str, list):
608+
logger.debug(f'SELECTING RELATED: {prefetch_qs_or_select_related_str}')
609+
# Prefetch has come back as a list of (dundered) strings.
610+
# We append onto existing select_related string, to potentially pass back up
611+
# and also feed it directly into a select_related call in case the former
612+
# falls through.
613+
if select_related:
614+
for sr in prefetch_qs_or_select_related_str:
615+
new_select_related.append(f'{select_related}__{sr}')
616+
qs = qs.select_related(*prefetch_qs_or_select_related_str)
617+
else:
618+
# Select related option fell through, we need to do a prefetch. :(
619+
logger.debug(f'PREFETCHING RELATED: {field}')
620+
select_related = ''
621+
qs = qs.prefetch_related(Prefetch(field, prefetch_qs_or_select_related_str))
622+
623+
if select_related:
624+
return new_select_related
625+
return qs

0 commit comments

Comments
 (0)