|
1 | 1 | import warnings
|
2 | 2 | from urllib.parse import urljoin
|
3 | 3 |
|
| 4 | +from django.db.models.fields import related_descriptors as rd |
4 | 5 | from django.utils.module_loading import import_string as import_class_from_dotted_path
|
5 | 6 | from rest_framework.fields import empty
|
6 | 7 | from rest_framework.relations import ManyRelatedField
|
7 | 8 | from rest_framework.schemas import openapi as drf_openapi
|
8 | 9 | from rest_framework.schemas.utils import is_list_view
|
9 | 10 |
|
10 | 11 | from rest_framework_json_api import serializers
|
| 12 | +from rest_framework_json_api.views import RelationshipView |
11 | 13 |
|
12 | 14 |
|
13 | 15 | class SchemaGenerator(drf_openapi.SchemaGenerator):
|
@@ -300,7 +302,9 @@ def get_schema(self, request=None, public=False):
|
300 | 302 | #: - 'action' copy of current view.action (list/fetch) as this gets reset for each request.
|
301 | 303 | expanded_endpoints = []
|
302 | 304 | for path, method, view in view_endpoints:
|
303 |
| - if hasattr(view, 'action') and view.action == 'retrieve_related': |
| 305 | + if isinstance(view, RelationshipView): |
| 306 | + expanded_endpoints += self._expand_relationships(path, method, view) |
| 307 | + elif hasattr(view, 'action') and view.action == 'retrieve_related': |
304 | 308 | expanded_endpoints += self._expand_related(path, method, view, view_endpoints)
|
305 | 309 | else:
|
306 | 310 | expanded_endpoints.append((path, method, view, getattr(view, 'action', None)))
|
@@ -346,6 +350,28 @@ def get_schema(self, request=None, public=False):
|
346 | 350 |
|
347 | 351 | return schema
|
348 | 352 |
|
| 353 | + def _expand_relationships(self, path, method, view): |
| 354 | + """ |
| 355 | + Expand path containing .../{id}/relationships/{related_field} into list of related fields. |
| 356 | + :return:list[tuple(path, method, view, action)] |
| 357 | + """ |
| 358 | + queryset = view.get_queryset() |
| 359 | + if not queryset.model: |
| 360 | + return [(path, method, view, getattr(view, 'action', '')), ] |
| 361 | + result = [] |
| 362 | + # TODO: what about serializer-only (non-model) fields? |
| 363 | + # Shouldn't this be iterating over serializer fields rather than model fields? |
| 364 | + # Look at parent view's serializer to get the list of fields. |
| 365 | + # OR maybe like _expand_related? |
| 366 | + m = queryset.model |
| 367 | + for field in [f for f in dir(m) if not f.startswith('_')]: |
| 368 | + attr = getattr(m, field) |
| 369 | + if isinstance(attr, (rd.ReverseManyToOneDescriptor, rd.ForwardOneToOneDescriptor)): |
| 370 | + action = 'rels' if isinstance(attr, rd.ReverseManyToOneDescriptor) else 'rel' |
| 371 | + result.append((path.replace('{related_field}', field), method, view, action)) |
| 372 | + |
| 373 | + return result |
| 374 | + |
349 | 375 | def _expand_related(self, path, method, view, view_endpoints):
|
350 | 376 | """
|
351 | 377 | Expand path containing .../{id}/{related_field} into list of related fields
|
|
0 commit comments