diff --git a/AUTHORS b/AUTHORS index b440df81..d4c03b2c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -33,4 +33,5 @@ Sergey Kolomenkin <https://kolomenkin.com> Stas S. <stas@nerd.ro> Tim Selman <timcbaoth@gmail.com> Tom Glowka <glowka.tom@gmail.com> +Ulrich Schuster <ulrich.schuster@mailworks.org> Yaniv Peer <yanivpeer@gmail.com> diff --git a/CHANGELOG.md b/CHANGELOG.md index 12b3e805..4613aa18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ any parts of the framework not mentioned in the documentation should generally b * Ability for the user to select `included_serializers` to apply when using `BrowsableAPI`, based on available `included_serializers` defined for the current endpoint. +### Fixed + +* Allow users to overwrite a view's `get_serializer_class()` method when using [related urls](https://django-rest-framework-json-api.readthedocs.io/en/stable/usage.html#related-urls) + ## [4.0.0] - 2020-10-31 diff --git a/example/serializers.py b/example/serializers.py index 9dc84a4a..566f39d5 100644 --- a/example/serializers.py +++ b/example/serializers.py @@ -261,6 +261,14 @@ def get_first_entry(self, obj): return obj.entries.first() +class AuthorListSerializer(AuthorSerializer): + pass + + +class AuthorDetailSerializer(AuthorSerializer): + pass + + class WriterSerializer(serializers.ModelSerializer): included_serializers = { 'bio': AuthorBioSerializer diff --git a/example/tests/snapshots/snap_test_openapi.py b/example/tests/snapshots/snap_test_openapi.py index ec8da388..aca41d70 100644 --- a/example/tests/snapshots/snap_test_openapi.py +++ b/example/tests/snapshots/snap_test_openapi.py @@ -65,7 +65,7 @@ "properties": { "data": { "items": { - "$ref": "#/components/schemas/Author" + "$ref": "#/components/schemas/AuthorList" }, "type": "array" }, @@ -171,7 +171,7 @@ "schema": { "properties": { "data": { - "$ref": "#/components/schemas/Author" + "$ref": "#/components/schemas/AuthorDetail" }, "included": { "items": { diff --git a/example/tests/test_views.py b/example/tests/test_views.py index 9cd493f7..25eeca89 100644 --- a/example/tests/test_views.py +++ b/example/tests/test_views.py @@ -367,16 +367,16 @@ def test_get_related_instance_model_field(self): got = view.get_related_instance() self.assertEqual(got, self.author.id) - def test_get_serializer_class(self): + def test_get_related_serializer_class(self): kwargs = {'pk': self.author.id, 'related_field': 'bio'} view = self._get_view(kwargs) - got = view.get_serializer_class() + got = view.get_related_serializer_class() self.assertEqual(got, AuthorBioSerializer) - def test_get_serializer_class_many(self): + def test_get_related_serializer_class_many(self): kwargs = {'pk': self.author.id, 'related_field': 'entries'} view = self._get_view(kwargs) - got = view.get_serializer_class() + got = view.get_related_serializer_class() self.assertEqual(got, EntrySerializer) def test_get_serializer_comes_from_included_serializers(self): @@ -384,15 +384,15 @@ def test_get_serializer_comes_from_included_serializers(self): view = self._get_view(kwargs) related_serializers = view.serializer_class.related_serializers delattr(view.serializer_class, 'related_serializers') - got = view.get_serializer_class() + got = view.get_related_serializer_class() self.assertEqual(got, AuthorTypeSerializer) view.serializer_class.related_serializers = related_serializers - def test_get_serializer_class_raises_error(self): + def test_get_related_serializer_class_raises_error(self): kwargs = {'pk': self.author.id, 'related_field': 'unknown'} view = self._get_view(kwargs) - self.assertRaises(NotFound, view.get_serializer_class) + self.assertRaises(NotFound, view.get_related_serializer_class) def test_retrieve_related_single_reverse_lookup(self): url = reverse('author-related', kwargs={'pk': self.author.pk, 'related_field': 'bio'}) diff --git a/example/views.py b/example/views.py index 8c80d145..99a54193 100644 --- a/example/views.py +++ b/example/views.py @@ -15,6 +15,8 @@ from example.models import Author, Blog, Comment, Company, Entry, Project, ProjectType from example.serializers import ( + AuthorDetailSerializer, + AuthorListSerializer, AuthorSerializer, BlogDRFSerializer, BlogSerializer, @@ -185,7 +187,16 @@ class NoFiltersetEntryViewSet(EntryViewSet): class AuthorViewSet(ModelViewSet): queryset = Author.objects.all() - serializer_class = AuthorSerializer + serializer_classes = { + "list": AuthorListSerializer, + "retrieve": AuthorDetailSerializer} + serializer_class = AuthorSerializer # fallback + + def get_serializer_class(self): + try: + return self.serializer_classes.get(self.action, self.serializer_class) + except AttributeError: + return self.serializer_class class CommentViewSet(ModelViewSet): diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 2443575f..b99de91a 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -52,7 +52,10 @@ def get_resource_name(context, expand_polymorphic_types=False): resource_name = getattr(view, 'resource_name') except AttributeError: try: - serializer = view.get_serializer_class() + if 'kwargs' in context and 'related_field' in context['kwargs']: + serializer = view.get_related_serializer_class() + else: + serializer = view.get_serializer_class() if expand_polymorphic_types and issubclass(serializer, PolymorphicModelSerializer): return serializer.get_polymorphic_types() else: diff --git a/rest_framework_json_api/views.py b/rest_framework_json_api/views.py index 2e0b22ef..7c874e7a 100644 --- a/rest_framework_json_api/views.py +++ b/rest_framework_json_api/views.py @@ -144,10 +144,15 @@ def retrieve_related(self, request, *args, **kwargs): if isinstance(instance, Iterable): serializer_kwargs['many'] = True - serializer = self.get_serializer(instance, **serializer_kwargs) + serializer = self.get_related_serializer(instance, **serializer_kwargs) return Response(serializer.data) - def get_serializer_class(self): + def get_related_serializer(self, instance, **kwargs): + serializer_class = self.get_related_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + return serializer_class(instance, **kwargs) + + def get_related_serializer_class(self): parent_serializer_class = super(RelatedMixin, self).get_serializer_class() if 'related_field' in self.kwargs: @@ -179,7 +184,8 @@ def get_related_field_name(self): def get_related_instance(self): parent_obj = self.get_object() - parent_serializer = self.serializer_class(parent_obj) + parent_serializer_class = self.get_serializer_class() + parent_serializer = parent_serializer_class(parent_obj) field_name = self.get_related_field_name() field = parent_serializer.fields.get(field_name, None)