|
8 | 8 | from django.db.models import Manager
|
9 | 9 | from django.utils import encoding, six
|
10 | 10 | from rest_framework import relations, renderers
|
| 11 | +from rest_framework.fields import SkipField, get_attribute |
| 12 | +from rest_framework.relations import PKOnlyObject |
11 | 13 | from rest_framework.serializers import BaseSerializer, ListSerializer, Serializer
|
12 | 14 | from rest_framework.settings import api_settings
|
13 | 15 |
|
@@ -297,34 +299,20 @@ def extract_relationships(cls, fields, resource, resource_instance):
|
297 | 299 | return utils._format_object(data)
|
298 | 300 |
|
299 | 301 | @classmethod
|
300 |
| - def extract_relation_instance(cls, field_name, field, resource_instance, serializer): |
| 302 | + def extract_relation_instance(cls, field, resource_instance): |
301 | 303 | """
|
302 | 304 | Determines what instance represents given relation and extracts it.
|
303 | 305 |
|
304 |
| - Relation instance is determined by given field_name or source configured on |
305 |
| - field. As fallback is a serializer method called with name of field's source. |
| 306 | + Relation instance is determined exactly same way as it determined |
| 307 | + in parent serializer |
306 | 308 | """
|
307 |
| - relation_instance = None |
308 |
| - |
309 | 309 | try:
|
310 |
| - relation_instance = getattr(resource_instance, field_name) |
311 |
| - except AttributeError: |
312 |
| - try: |
313 |
| - # For ManyRelatedFields if `related_name` is not set |
314 |
| - # we need to access `foo_set` from `source` |
315 |
| - relation_instance = getattr(resource_instance, field.child_relation.source) |
316 |
| - except AttributeError: |
317 |
| - if hasattr(serializer, field.source): |
318 |
| - serializer_method = getattr(serializer, field.source) |
319 |
| - relation_instance = serializer_method(resource_instance) |
320 |
| - else: |
321 |
| - # case when source is a simple remap on resource_instance |
322 |
| - try: |
323 |
| - relation_instance = getattr(resource_instance, field.source) |
324 |
| - except AttributeError: |
325 |
| - pass |
326 |
| - |
327 |
| - return relation_instance |
| 310 | + res = field.get_attribute(resource_instance) |
| 311 | + if isinstance(res, PKOnlyObject): |
| 312 | + return get_attribute(resource_instance, field.source_attrs) |
| 313 | + return res |
| 314 | + except SkipField: |
| 315 | + return None |
328 | 316 |
|
329 | 317 | @classmethod
|
330 | 318 | def extract_included(cls, fields, resource, resource_instance, included_resources,
|
@@ -363,7 +351,7 @@ def extract_included(cls, fields, resource, resource_instance, included_resource
|
363 | 351 | continue
|
364 | 352 |
|
365 | 353 | relation_instance = cls.extract_relation_instance(
|
366 |
| - field_name, field, resource_instance, current_serializer |
| 354 | + field, resource_instance |
367 | 355 | )
|
368 | 356 | if isinstance(relation_instance, Manager):
|
369 | 357 | relation_instance = relation_instance.all()
|
|
0 commit comments