Skip to content

Allow users to overwrite get_serializer_class while using related urls #860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ Sergey Kolomenkin <https://kolomenkin.com>
Stas S. <[email protected]>
Tim Selman <[email protected]>
Tom Glowka <[email protected]>
Ulrich Schuster <[email protected]>
Yaniv Peer <[email protected]>
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ any parts of the framework not mentioned in the documentation should generally b

* Ability for the user to select `included_serializers` to apply when using `BrowsableAPI`, based on available `included_serializers` defined for the current endpoint.

### Fixed

* Allow users to overwrite a view's `get_serializer_class()` method when using [related urls](https://django-rest-framework-json-api.readthedocs.io/en/stable/usage.html#related-urls)


## [4.0.0] - 2020-10-31

Expand Down
8 changes: 8 additions & 0 deletions example/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ def get_first_entry(self, obj):
return obj.entries.first()


class AuthorListSerializer(AuthorSerializer):
pass


class AuthorDetailSerializer(AuthorSerializer):
pass


class WriterSerializer(serializers.ModelSerializer):
included_serializers = {
'bio': AuthorBioSerializer
Expand Down
4 changes: 2 additions & 2 deletions example/tests/snapshots/snap_test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"properties": {
"data": {
"items": {
"$ref": "#/components/schemas/Author"
"$ref": "#/components/schemas/AuthorList"
},
"type": "array"
},
Expand Down Expand Up @@ -171,7 +171,7 @@
"schema": {
"properties": {
"data": {
"$ref": "#/components/schemas/Author"
"$ref": "#/components/schemas/AuthorDetail"
},
"included": {
"items": {
Expand Down
14 changes: 7 additions & 7 deletions example/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,32 +367,32 @@ def test_get_related_instance_model_field(self):
got = view.get_related_instance()
self.assertEqual(got, self.author.id)

def test_get_serializer_class(self):
def test_get_related_serializer_class(self):
kwargs = {'pk': self.author.id, 'related_field': 'bio'}
view = self._get_view(kwargs)
got = view.get_serializer_class()
got = view.get_related_serializer_class()
self.assertEqual(got, AuthorBioSerializer)

def test_get_serializer_class_many(self):
def test_get_related_serializer_class_many(self):
kwargs = {'pk': self.author.id, 'related_field': 'entries'}
view = self._get_view(kwargs)
got = view.get_serializer_class()
got = view.get_related_serializer_class()
self.assertEqual(got, EntrySerializer)

def test_get_serializer_comes_from_included_serializers(self):
kwargs = {'pk': self.author.id, 'related_field': 'type'}
view = self._get_view(kwargs)
related_serializers = view.serializer_class.related_serializers
delattr(view.serializer_class, 'related_serializers')
got = view.get_serializer_class()
got = view.get_related_serializer_class()
self.assertEqual(got, AuthorTypeSerializer)

view.serializer_class.related_serializers = related_serializers

def test_get_serializer_class_raises_error(self):
def test_get_related_serializer_class_raises_error(self):
kwargs = {'pk': self.author.id, 'related_field': 'unknown'}
view = self._get_view(kwargs)
self.assertRaises(NotFound, view.get_serializer_class)
self.assertRaises(NotFound, view.get_related_serializer_class)

def test_retrieve_related_single_reverse_lookup(self):
url = reverse('author-related', kwargs={'pk': self.author.pk, 'related_field': 'bio'})
Expand Down
13 changes: 12 additions & 1 deletion example/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from example.models import Author, Blog, Comment, Company, Entry, Project, ProjectType
from example.serializers import (
AuthorDetailSerializer,
AuthorListSerializer,
AuthorSerializer,
BlogDRFSerializer,
BlogSerializer,
Expand Down Expand Up @@ -185,7 +187,16 @@ class NoFiltersetEntryViewSet(EntryViewSet):

class AuthorViewSet(ModelViewSet):
queryset = Author.objects.all()
serializer_class = AuthorSerializer
serializer_classes = {
"list": AuthorListSerializer,
"retrieve": AuthorDetailSerializer}
serializer_class = AuthorSerializer # fallback

def get_serializer_class(self):
try:
return self.serializer_classes.get(self.action, self.serializer_class)
except AttributeError:
return self.serializer_class


class CommentViewSet(ModelViewSet):
Expand Down
5 changes: 4 additions & 1 deletion rest_framework_json_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def get_resource_name(context, expand_polymorphic_types=False):
resource_name = getattr(view, 'resource_name')
except AttributeError:
try:
serializer = view.get_serializer_class()
if 'kwargs' in context and 'related_field' in context['kwargs']:
serializer = view.get_related_serializer_class()
else:
serializer = view.get_serializer_class()
if expand_polymorphic_types and issubclass(serializer, PolymorphicModelSerializer):
return serializer.get_polymorphic_types()
else:
Expand Down
12 changes: 9 additions & 3 deletions rest_framework_json_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,15 @@ def retrieve_related(self, request, *args, **kwargs):
if isinstance(instance, Iterable):
serializer_kwargs['many'] = True

serializer = self.get_serializer(instance, **serializer_kwargs)
serializer = self.get_related_serializer(instance, **serializer_kwargs)
return Response(serializer.data)

def get_serializer_class(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removing this exposes the upstream DRF get_serializer_class(). It seems this needs to remain for compatibility. @sliverc Do you agree?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My whole point in making the change here is to do exactly this: expose the upstream method, so that it can be overwritten. get_serializer_class() is a public method on the base class. As such, it is part of DRF's public API - the DRF documentation explicitly mentions that the method can be overwritten. In my opinion, an extension library should preserve the upstream API.

The current version of DJA does break the path to the upstream method. There might be cases where DJA users rely on this bug in some way. The change here would break their code. I'm just a first-time user of DJA without any idea about how it is being used in other projects, what the effects might be, and who might be affected.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgive me for being dense but how is get_serializer_class() in DJA not able to be overridden by you? You've completely removed the function from DJA; not just added get_related_serializer().

This extension library routinely overrides the upstream DRF because it is extending it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do Full understand that you don't want to run the risk of breaking someone's application. And given that I am fairly new to DJA, I might not fully understand its inner workings, either.

I came across the issue the present PR is trying to solve, when I had to overwrite get_serializer_class() on a custom viewset, in exactly the same way I now added as part of the AuthorViewSet in the test example. I investigated the issue and discovered that when processing a GET request for a related instance, RelatedMixin.get_related_instance() directly accessed the self.serializer_class field instead of getting the class via its accessor method. Consequently, the upstream get_serializer_class() would not be called, either, which would be my overwritten version. Hence my remark that DJA effectively prevents get_serializer_class() to be overwritten on a custom viewset, which I consider to be a bug.

My first attempt to fix the issue was to have RelatedMixin.get_related_instance call the upstream method instead of accessing the field directly. But this breaks things badly. During two exchanges with @sliverc (see the comments on the issue here) showed that the functionality of RelatedMixin.get_serializer_class is actually not about getting the serializer class, but getting the class of related serializers. I followed the hint of @sliverc to rename the method from get_serializer_class() to get_related_serializer_class(), which better captures its functionality. This way, RelatedMixin.get_related_instance can call get_serializer_class() of the parent class (potentially overwritten), while all functionality that has to do with getting serializers for related views is now handled by aptly named get_related_* methods.

To make this work, I need to do some other adjustments along the call chain. In particular, retrieve_related(), which is the entry point for a GET request on a related resource, now calls the renamed get_related_serializer_class(). One more change, also pointed out by @sliverc, was in utils.get_resource_name(): Here, I had to make the distinction if the serializer of the present viewset is needed, or if the serializer of a related viewset should be returned. The distinction between the two cases depends on the related_field kwarg, similar to the condition in get_serializer_class()(not get_related_serializer_class().

I hope this helps to provide some understanding of what I am trying to accomplish. I was very much focussed on the functionality of getting related resources; because of my limited understanding of DJA, I am not able to tell if other parts of its functionality are affected by the changes introduced in RelatedMixin. Here, I relied on the tests, which are all passing. If there is something else I can do to clear things up or to provide better explanation, please let me know. I am also available for a quick online call, if that can help to better assess the changes I introduced.

Copy link
Member

@sliverc sliverc Nov 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try to shed some light on this... RelatedMixin mistakenly shadowed get_serializer_class to get related serializer. Shadowing happened as RelatedMixin only works when used with a DRF generic view which provides a get_serializer_class method
For RelatedMixin to be able to get parent serializer it then simply used self.serializer_class which is not DRF conform. However this only happened when related urls have been configured as otherwise this if is always false and there is no difference to the method not being overwritten in RelatedMixin.

If someone overwrote get_serializer_class in its view without using related urls it would just work as RelatedMixin did not do anything anyway in such a case then passing it on to super. If someone tried to overwrite get_serializer_class with related urls this user would run into exactly the same problem as @uliSchuster has.

Therefore I don't see that this change affects existing users.

def get_related_serializer(self, instance, **kwargs):
serializer_class = self.get_related_serializer_class()
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(instance, **kwargs)

def get_related_serializer_class(self):
parent_serializer_class = super(RelatedMixin, self).get_serializer_class()

if 'related_field' in self.kwargs:
Expand Down Expand Up @@ -179,7 +184,8 @@ def get_related_field_name(self):

def get_related_instance(self):
parent_obj = self.get_object()
parent_serializer = self.serializer_class(parent_obj)
parent_serializer_class = self.get_serializer_class()
parent_serializer = parent_serializer_class(parent_obj)
field_name = self.get_related_field_name()
field = parent_serializer.fields.get(field_name, None)

Expand Down