Skip to content

Commit 4e9b43c

Browse files
committed
tests and improvements for related fields
1 parent 4a70020 commit 4e9b43c

File tree

2 files changed

+62
-57
lines changed

2 files changed

+62
-57
lines changed

example/tests/test_openapi.py

+51-50
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from django.test import RequestFactory, override_settings
66
from rest_framework.request import Request
77

8+
from rest_framework_json_api.management.commands.generateschema import Command
89
from rest_framework_json_api.schemas.openapi import AutoSchema, SchemaGenerator
9-
from rest_framework_json_api.views import ModelViewSet
1010

11-
from example import models, serializers, views
11+
from example import views
12+
from example.tests import TestBase
1213

1314

1415
def create_request(path):
@@ -17,12 +18,6 @@ def create_request(path):
1718
return request
1819

1920

20-
def create_view(view_cls, method, request):
21-
generator = SchemaGenerator()
22-
view = generator.create_view(view_cls.as_view(), method, request)
23-
return view
24-
25-
2621
def create_view_with_kw(view_cls, method, request, initkwargs):
2722
generator = SchemaGenerator()
2823
view = generator.create_view(view_cls.as_view(initkwargs), method, request)
@@ -132,45 +127,51 @@ def test_schema_construction():
132127
assert 'components' in schema
133128

134129

135-
# TODO: figure these out
136-
def test_schema_related():
137-
class AuthorBioViewSet(ModelViewSet):
138-
queryset = models.AuthorBio.objects.all()
139-
serializer_class = serializers.AuthorBioSerializer
140-
141-
patterns = [
142-
url(r'^authors/(?P<pk>[^/.]+)/(?P<related_field>\w+)/$',
143-
views.AuthorViewSet.as_view({'get': 'retrieve_related'}),
144-
name='author-related'),
145-
url(r'^bios/(?P<pk>[^/.]+)/$',
146-
AuthorBioViewSet,
147-
name='author-bio')
148-
]
149-
generator = SchemaGenerator(patterns=patterns)
150-
151-
request = create_request('/authors/123/bio/')
152-
schema = generator.get_schema(request=request)
153-
# TODO: finish this test
154-
print(schema)
155-
156-
# def test_retrieve_relationships():
157-
# path = '/authors/{id}/relationships/bio/'
158-
# method = 'GET'
159-
#
160-
# view = create_view_with_kw(
161-
# views.AuthorViewSet,
162-
# method,
163-
# create_request(path),
164-
# {'get': 'retrieve_related'}
165-
# )
166-
# inspector = AutoSchema()
167-
# inspector.view = view
168-
#
169-
# operation = inspector.get_operation(path, method)
170-
# assert 'responses' in operation
171-
# assert '200' in operation['responses']
172-
# resp = operation['responses']['200']['content']
173-
# data = resp['application/vnd.api+json']['schema']['properties']['data']
174-
# assert data['type'] == 'object'
175-
# assert data['required'] == ['type', 'id']
176-
# assert data['properties']['type'] == {'$ref': '#/components/schemas/type'}
130+
def test_generateschema_command():
131+
command = Command()
132+
assert command.get_generator_class() == SchemaGenerator
133+
134+
135+
class TestSchemaRelatedField(TestBase):
136+
def test_schema_related_serializers(self):
137+
"""
138+
Confirm that paths are generated for related fields. For example:
139+
url path '/authors/{pk}/{related_field>}/' generates:
140+
/authors/{id}/comments/
141+
/authors/{id}/entries/
142+
/authors/{id}/first_entry/
143+
and confirm that the schema for the related field is properly rendered
144+
"""
145+
generator = SchemaGenerator()
146+
request = create_request('/')
147+
schema = generator.get_schema(request=request)
148+
assert '/authors/{id}/comments/' in schema['paths']
149+
assert '/authors/{id}/entries/' in schema['paths']
150+
assert '/authors/{id}/first_entry/' in schema['paths']
151+
first_get = schema['paths']['/authors/{id}/first_entry/']['get']['responses']['200']
152+
first_schema = first_get['content']['application/vnd.api+json']['schema']
153+
first_props = first_schema['properties']['data']['properties']['attributes']['properties']
154+
assert 'headline' in first_props
155+
assert first_props['headline'] == {'type': 'string', 'maxLength': 255}
156+
157+
# def test_retrieve_relationships(self):
158+
# path = '/authors/{id}/relationships/bio/'
159+
# method = 'GET'
160+
#
161+
# view = create_view_with_kw(
162+
# views.AuthorViewSet,
163+
# method,
164+
# create_request(path),
165+
# {'get': 'retrieve_related'}
166+
# )
167+
# inspector = AutoSchema()
168+
# inspector.view = view
169+
#
170+
# operation = inspector.get_operation(path, method)
171+
# assert 'responses' in operation
172+
# assert '200' in operation['responses']
173+
# resp = operation['responses']['200']['content']
174+
# data = resp['application/vnd.api+json']['schema']['properties']['data']
175+
# assert data['type'] == 'object'
176+
# assert data['required'] == ['type', 'id']
177+
# assert data['properties']['type'] == {'$ref': '#/components/schemas/type'}

rest_framework_json_api/schemas/openapi.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,13 @@ def get_paths(self, request=None):
312312
# instead of doing it here.
313313
expanded_endpoints = []
314314
for path, method, view in view_endpoints:
315+
action = view.action if hasattr(view, 'action') else None
315316
if isinstance(view, RelationshipView):
316317
expanded_endpoints += self._expand_relationships(path, method, view)
317-
elif view.action == 'retrieve_related':
318+
elif action == 'retrieve_related':
318319
expanded_endpoints += self._expand_related(path, method, view, view_endpoints)
319320
else:
320-
expanded_endpoints.append((path, method, view, view.action))
321+
expanded_endpoints.append((path, method, view, action))
321322

322323
for path, method, view, action in expanded_endpoints:
323324
if not self.has_view_permissions(path, method, view):
@@ -379,12 +380,15 @@ def _expand_related(self, path, method, view, view_endpoints):
379380
"""
380381
result = []
381382
serializer = view.get_serializer()
383+
# It's not obvious if it's allowed to have both included_ and related_ serializers,
384+
# so just merge both dicts.
385+
serializers = {}
386+
if hasattr(serializer, 'included_serializers'):
387+
serializers = {**serializers, **serializer.included_serializers}
382388
if hasattr(serializer, 'related_serializers'):
383-
related_fields = [fs for fs in serializer.related_serializers.items()]
384-
elif hasattr(serializer, 'included_serializers'):
385-
related_fields = [fs for fs in serializer.included_serializers.items()]
386-
else:
387-
related_fields = []
389+
serializers = {**serializers, **serializer.related_serializers}
390+
related_fields = [fs for fs in serializers.items()]
391+
388392
for field, related_serializer in related_fields:
389393
related_view = self._find_related_view(view_endpoints, related_serializer, view)
390394
if related_view:

0 commit comments

Comments
 (0)