diff --git a/CHANGELOG.md b/CHANGELOG.md index b816f47d..c78a1b12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ any parts of the framework not mentioned in the documentation should generally b * Added support for Django 3.2. +### Fixed + +* Allow `get_serializer_class` to be overwritten when using related urls without defining `serializer_class` fallback + ## [4.1.0] - 2021-03-08 ### Added diff --git a/example/tests/test_views.py b/example/tests/test_views.py index 4b494b52..16b51622 100644 --- a/example/tests/test_views.py +++ b/example/tests/test_views.py @@ -417,12 +417,11 @@ def test_get_related_serializer_class_many(self): def test_get_serializer_comes_from_included_serializers(self): kwargs = {"pk": self.author.id, "related_field": "type"} view = self._get_view(kwargs) - related_serializers = view.serializer_class.related_serializers - delattr(view.serializer_class, "related_serializers") + related_serializers = view.get_serializer_class().related_serializers + delattr(view.get_serializer_class(), "related_serializers") got = view.get_related_serializer_class() self.assertEqual(got, AuthorTypeSerializer) - - view.serializer_class.related_serializers = related_serializers + view.get_serializer_class().related_serializers = related_serializers def test_get_related_serializer_class_raises_error(self): kwargs = {"pk": self.author.id, "related_field": "unknown"} diff --git a/example/views.py b/example/views.py index 65bcb301..6a1b15a6 100644 --- a/example/views.py +++ b/example/views.py @@ -208,17 +208,15 @@ class NoFiltersetEntryViewSet(EntryViewSet): class AuthorViewSet(ModelViewSet): queryset = Author.objects.all() - serializer_classes = { - "list": AuthorListSerializer, - "retrieve": AuthorDetailSerializer, - } - serializer_class = AuthorSerializer # fallback def get_serializer_class(self): - try: - return self.serializer_classes.get(self.action, self.serializer_class) - except AttributeError: - return self.serializer_class + serializer_classes = { + "list": AuthorListSerializer, + "retrieve": AuthorDetailSerializer, + } + + action = getattr(self, "action", "") + return serializer_classes.get(action, AuthorSerializer) class CommentViewSet(ModelViewSet): diff --git a/rest_framework_json_api/views.py b/rest_framework_json_api/views.py index 2f061cbb..3df27d1f 100644 --- a/rest_framework_json_api/views.py +++ b/rest_framework_json_api/views.py @@ -154,7 +154,7 @@ def get_related_serializer(self, instance, **kwargs): return serializer_class(instance, **kwargs) def get_related_serializer_class(self): - parent_serializer_class = super(RelatedMixin, self).get_serializer_class() + parent_serializer_class = self.get_serializer_class() if "related_field" in self.kwargs: field_name = self.kwargs["related_field"]