|
1 | 1 | import inspect
|
| 2 | +import logging |
2 | 3 | import operator
|
3 | 4 | import warnings
|
4 | 5 | from collections import OrderedDict
|
5 | 6 |
|
6 | 7 | import inflection
|
7 | 8 | from django.conf import settings
|
8 |
| -from django.db.models import Manager |
| 9 | +from django.db.models import Manager, Prefetch |
9 | 10 | from django.db.models.fields.related_descriptors import (
|
| 11 | + ForwardManyToOneDescriptor, |
10 | 12 | ManyToManyDescriptor,
|
11 | 13 | ReverseManyToOneDescriptor,
|
| 14 | + ReverseOneToOneDescriptor, |
12 | 15 | )
|
| 16 | +from django.db.models.query import QuerySet |
13 | 17 | from django.http import Http404
|
14 | 18 | from django.utils import encoding
|
15 | 19 | from django.utils.translation import gettext_lazy as _
|
16 | 20 | from rest_framework import exceptions
|
17 | 21 | from rest_framework.exceptions import APIException
|
| 22 | +from rest_framework.relations import RelatedField |
| 23 | +from rest_framework.request import Request |
| 24 | + |
| 25 | +from rest_framework_json_api.serializers import ModelSerializer, ValidationError |
18 | 26 |
|
19 | 27 | from .settings import json_api_settings
|
20 | 28 |
|
| 29 | +logger = logging.getLogger(__name__) |
| 30 | + |
21 | 31 | # Generic relation descriptor from django.contrib.contenttypes.
|
22 | 32 | if "django.contrib.contenttypes" not in settings.INSTALLED_APPS: # pragma: no cover
|
23 | 33 | # Target application does not use contenttypes. Importing would cause errors.
|
@@ -472,3 +482,144 @@ def format_errors(data):
|
472 | 482 | if len(data) > 1 and isinstance(data, list):
|
473 | 483 | data.sort(key=lambda x: x.get("source", {}).get("pointer", ""))
|
474 | 484 | return {"errors": data}
|
| 485 | + |
| 486 | + |
| 487 | +def get_expensive_relational_fields(serializer_class: ModelSerializer) -> list[str]: |
| 488 | + """ |
| 489 | + We define 'expensive' as relational fields on the serializer that don't correspond to a |
| 490 | + forward relation on the model. |
| 491 | + """ |
| 492 | + return [ |
| 493 | + field |
| 494 | + for field in getattr(serializer_class, 'included_serializers', {}) |
| 495 | + if not isinstance(getattr(serializer_class.Meta.model, field, None), ForwardManyToOneDescriptor) |
| 496 | + ] |
| 497 | + |
| 498 | + |
| 499 | +def get_cheap_relational_fields(serializer_class: ModelSerializer) -> list[str]: |
| 500 | + """ |
| 501 | + We define 'cheap' as relational fields on the serializer that _do_ correspond to a |
| 502 | + forward relation on the model. |
| 503 | + """ |
| 504 | + return [ |
| 505 | + field |
| 506 | + for field in getattr(serializer_class, 'included_serializers', {}) |
| 507 | + if isinstance(getattr(serializer_class.Meta.model, field, None), ForwardManyToOneDescriptor) |
| 508 | + ] |
| 509 | + |
| 510 | + |
| 511 | +def get_queryset_for_field(field: RelatedField) -> QuerySet: |
| 512 | + model_field_descriptor = getattr(field.parent.Meta.model, field.field_name) |
| 513 | + # NOTE: Important to check in this order, as some of these classes are ancestors of one |
| 514 | + # another (ie `ManyToManyDescriptor` subclasses `ReverseManyToOneDescriptor`) |
| 515 | + if isinstance(model_field_descriptor, ForwardManyToOneDescriptor): |
| 516 | + if (qs := field.queryset) is None: |
| 517 | + qs = model_field_descriptor.field.related_model._default_manager |
| 518 | + elif isinstance(model_field_descriptor, ManyToManyDescriptor): |
| 519 | + qs = field.child_relation.queryset |
| 520 | + elif isinstance(model_field_descriptor, ReverseManyToOneDescriptor): |
| 521 | + if (qs := field.child_relation.queryset) is None: |
| 522 | + qs = model_field_descriptor.field.model._default_manager |
| 523 | + elif isinstance(model_field_descriptor, ReverseOneToOneDescriptor): |
| 524 | + qs = model_field_descriptor.get_queryset() |
| 525 | + |
| 526 | + # Note: We call `.all()` before returning, as `_default_manager` may on occasion return a Manager |
| 527 | + # instance rather than a QuerySet, and we strictly want to be working with the latter. |
| 528 | + # (_default_manager is being used both direclty by us here, and by drf behind the scenes) |
| 529 | + # See: https://github.com/encode/django-rest-framework/blame/master/rest_framework/utils/field_mapping.py#L243 |
| 530 | + return qs.all() |
| 531 | + |
| 532 | + |
| 533 | +def add_nested_prefetches_to_qs( |
| 534 | + serializer_class: ModelSerializer, |
| 535 | + qs: QuerySet, |
| 536 | + request: Request, |
| 537 | + sparsefields: dict[str, list[str]], |
| 538 | + includes: dict, # TODO: Define typing as recursive once supported. |
| 539 | + select_related: str = '', |
| 540 | +) -> QuerySet: |
| 541 | + """ |
| 542 | + Prefetch all required data onto the supplied queryset, calling this method recursively for child |
| 543 | + serializers where needed. |
| 544 | + There is some added built-in optimisation here, attempting to opt for select_related calls over |
| 545 | + prefetches where possible -- it's only possible if the child serializers are interested |
| 546 | + exclusively in select_relating also. This is controlled with the `select_related` param. |
| 547 | + If `select_related` comes through, will attempt to instead build further onto this and return |
| 548 | + a dundered list of strings for the caller to use in a select_related call. If that fails, |
| 549 | + returns a qs as normal. |
| 550 | + """ |
| 551 | + # Determine fields that'll be returned by this serializer. |
| 552 | + resource_name = get_resource_type_from_serializer(serializer_class) |
| 553 | + logger.debug(f'ADDING NESTED PREFETCHES FOR: {resource_name}') |
| 554 | + dummy_serializer = serializer_class(context={'request': request, 'demanded_fields': sparsefields.get(resource_name, [])}) |
| 555 | + requested_fields = dummy_serializer.fields.keys() |
| 556 | + |
| 557 | + # Ensure any requested includes are in the fields list, else error loudly! |
| 558 | + if not includes.keys() <= requested_fields: |
| 559 | + errors = {f'{resource_name}.{field}': 'Field marked as include but not requested for serialization.' for field in includes.keys() - requested_fields} |
| 560 | + raise ValidationError(errors) |
| 561 | + |
| 562 | + included_serializers = get_included_serializers(serializer_class) |
| 563 | + |
| 564 | + # Iterate over all expensive relations and prefetch_related where needed. |
| 565 | + for field in get_expensive_relational_fields(serializer_class): |
| 566 | + if field in requested_fields: |
| 567 | + logger.debug(f'EXPENSIVE_FIELD: {field}') |
| 568 | + select_related = '' # wipe, cannot be used. :( |
| 569 | + if not hasattr(qs.model, field): |
| 570 | + # We might fall into here if, for example, there's an expensive |
| 571 | + # SerializerMethodResourceRelatedField defined. |
| 572 | + continue |
| 573 | + if field in includes: |
| 574 | + logger.debug('- PREFETCHING DEEP') |
| 575 | + # Prefetch and recurse. |
| 576 | + child_serializer_class = included_serializers[field] |
| 577 | + prefetch_qs = add_nested_prefetches_to_qs( |
| 578 | + child_serializer_class, |
| 579 | + get_queryset_for_field(dummy_serializer.fields[field]), |
| 580 | + request=request, |
| 581 | + sparsefields=sparsefields, |
| 582 | + includes=includes[field], |
| 583 | + ) |
| 584 | + qs = qs.prefetch_related(Prefetch(field, prefetch_qs)) |
| 585 | + else: |
| 586 | + logger.debug('- PREFETCHING SHALLOW') |
| 587 | + # Prefetch "shallowly"; we only care about ids. |
| 588 | + qs = qs.prefetch_related(field) # TODO: Still use ResourceRelatedField.qs if present! |
| 589 | + |
| 590 | + # Iterate over all cheap (forward) relations and select_related (or prefetch) where needed. |
| 591 | + new_select_related = [select_related] |
| 592 | + for field in get_cheap_relational_fields(serializer_class): |
| 593 | + if field in requested_fields: |
| 594 | + logger.debug(f'CHEAP_FIELD: {field}') |
| 595 | + if field in includes: |
| 596 | + logger.debug('- present in includes') |
| 597 | + # Recurse and see if we get a prefetch qs back, or a select_related string. |
| 598 | + child_serializer_class = included_serializers[field] |
| 599 | + prefetch_qs_or_select_related_str = add_nested_prefetches_to_qs( |
| 600 | + child_serializer_class, |
| 601 | + get_queryset_for_field(dummy_serializer.fields[field]), |
| 602 | + request=request, |
| 603 | + sparsefields=sparsefields, |
| 604 | + includes=includes[field], |
| 605 | + select_related=field, |
| 606 | + ) |
| 607 | + if isinstance(prefetch_qs_or_select_related_str, list): |
| 608 | + logger.debug(f'SELECTING RELATED: {prefetch_qs_or_select_related_str}') |
| 609 | + # Prefetch has come back as a list of (dundered) strings. |
| 610 | + # We append onto existing select_related string, to potentially pass back up |
| 611 | + # and also feed it directly into a select_related call in case the former |
| 612 | + # falls through. |
| 613 | + if select_related: |
| 614 | + for sr in prefetch_qs_or_select_related_str: |
| 615 | + new_select_related.append(f'{select_related}__{sr}') |
| 616 | + qs = qs.select_related(*prefetch_qs_or_select_related_str) |
| 617 | + else: |
| 618 | + # Select related option fell through, we need to do a prefetch. :( |
| 619 | + logger.debug(f'PREFETCHING RELATED: {field}') |
| 620 | + select_related = '' |
| 621 | + qs = qs.prefetch_related(Prefetch(field, prefetch_qs_or_select_related_str)) |
| 622 | + |
| 623 | + if select_related: |
| 624 | + return new_select_related |
| 625 | + return qs |
0 commit comments