Skip to content

Commit 960b258

Browse files
committed
Adds the following features:
- support for post and patch request on polymorphic model endpoints. - makes polymorphic serializers give child fields instead of its own.
1 parent 7809f75 commit 960b258

File tree

6 files changed

+164
-16
lines changed

6 files changed

+164
-16
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# -*- coding: utf-8 -*-
2+
# Generated by Django 1.9.6 on 2016-05-13 08:57
3+
from __future__ import unicode_literals
4+
5+
from django.db import migrations, models
6+
import django.db.models.deletion
7+
8+
9+
class Migration(migrations.Migration):
10+
11+
dependencies = [
12+
('contenttypes', '0002_remove_content_type_name'),
13+
('example', '0001_initial'),
14+
]
15+
16+
operations = [
17+
migrations.CreateModel(
18+
name='Company',
19+
fields=[
20+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
21+
('name', models.CharField(max_length=100)),
22+
],
23+
),
24+
migrations.CreateModel(
25+
name='Project',
26+
fields=[
27+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
28+
('topic', models.CharField(max_length=30)),
29+
],
30+
options={
31+
'abstract': False,
32+
},
33+
),
34+
migrations.CreateModel(
35+
name='ArtProject',
36+
fields=[
37+
('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')),
38+
('artist', models.CharField(max_length=30)),
39+
],
40+
options={
41+
'abstract': False,
42+
},
43+
bases=('example.project',),
44+
),
45+
migrations.CreateModel(
46+
name='ResearchProject',
47+
fields=[
48+
('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')),
49+
('supervisor', models.CharField(max_length=30)),
50+
],
51+
options={
52+
'abstract': False,
53+
},
54+
bases=('example.project',),
55+
),
56+
migrations.AddField(
57+
model_name='project',
58+
name='polymorphic_ctype',
59+
field=models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_example.project_set+', to='contenttypes.ContentType'),
60+
),
61+
migrations.AddField(
62+
model_name='company',
63+
name='current_project',
64+
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='companies', to='example.Project'),
65+
),
66+
migrations.AddField(
67+
model_name='company',
68+
name='future_projects',
69+
field=models.ManyToManyField(to='example.Project'),
70+
),
71+
]

example/serializers.py

+42-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from datetime import datetime
2-
from rest_framework_json_api import serializers, relations
2+
from django.db.models.query import QuerySet
3+
from rest_framework.utils.serializer_helpers import BindingDict
4+
from rest_framework_json_api import serializers, relations, utils
35
from example import models
46

57

@@ -43,13 +45,13 @@ def __init__(self, *args, **kwargs):
4345
source='comment_set', many=True, read_only=True)
4446
# many related from serializer
4547
suggested = relations.SerializerMethodResourceRelatedField(
46-
source='get_suggested', model=Entry, many=True, read_only=True)
48+
source='get_suggested', model=models.Entry, many=True, read_only=True)
4749
# single related from serializer
4850
featured = relations.SerializerMethodResourceRelatedField(
49-
source='get_featured', model=Entry, read_only=True)
51+
source='get_featured', model=models.Entry, read_only=True)
5052

5153
def get_suggested(self, obj):
52-
return models.Entry.objects.exclude(pk=obj.pk).first()
54+
return models.Entry.objects.exclude(pk=obj.pk)
5355

5456
def get_featured(self, obj):
5557
return models.Entry.objects.exclude(pk=obj.pk).first()
@@ -107,19 +109,48 @@ class Meta:
107109

108110
class ProjectSerializer(serializers.ModelSerializer):
109111

112+
polymorphic_serializers = [
113+
{'model': models.ArtProject, 'serializer': ArtProjectSerializer},
114+
{'model': models.ResearchProject, 'serializer': ResearchProjectSerializer},
115+
]
116+
110117
class Meta:
111118
model = models.Project
112119
exclude = ('polymorphic_ctype',)
113120

121+
def _get_actual_serializer_from_instance(self, instance):
122+
for info in self.polymorphic_serializers:
123+
if isinstance(instance, info.get('model')):
124+
actual_serializer = info.get('serializer')
125+
return actual_serializer(instance, context=self.context)
126+
127+
@property
128+
def fields(self):
129+
_fields = BindingDict(self)
130+
for key, value in self.get_fields().items():
131+
_fields[key] = value
132+
return _fields
133+
134+
def get_fields(self):
135+
if self.instance is not None:
136+
if not isinstance(self.instance, QuerySet):
137+
return self._get_actual_serializer_from_instance(self.instance).get_fields()
138+
else:
139+
raise Exception("Cannot get fields from a polymorphic serializer given a queryset")
140+
return super(ProjectSerializer, self).get_fields()
141+
114142
def to_representation(self, instance):
115143
# Handle polymorphism
116-
if isinstance(instance, models.ArtProject):
117-
return ArtProjectSerializer(
118-
instance, context=self.context).to_representation(instance)
119-
elif isinstance(instance, models.ResearchProject):
120-
return ResearchProjectSerializer(
121-
instance, context=self.context).to_representation(instance)
122-
return super(ProjectSerializer, self).to_representation(instance)
144+
return self._get_actual_serializer_from_instance(instance).to_representation(instance)
145+
146+
def to_internal_value(self, data):
147+
data_type = data.get('type')
148+
for info in self.polymorphic_serializers:
149+
actual_serializer = info['serializer']
150+
if data_type == utils.get_resource_type_from_serializer(actual_serializer):
151+
self.__class__ = actual_serializer
152+
return actual_serializer(data, context=self.context).to_internal_value(data)
153+
raise Exception("Could not deserialize")
123154

124155

125156
class CompanySerializer(serializers.ModelSerializer):

example/tests/integration/test_polymorphism.py

+40
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import pytest
2+
import random
3+
import json
24
from django.core.urlresolvers import reverse
35

46
from example.tests.utils import load_json
@@ -29,3 +31,41 @@ def test_polymorphism_on_included_relations(single_company, client):
2931
"researchProjects", "artProjects"]
3032
assert [x.get('type') for x in content.get('included')] == ['artProjects', 'artProjects', 'researchProjects'], \
3133
'Detail included types are incorrect'
34+
# Ensure that the child fields are present.
35+
assert content.get('included')[0].get('attributes').get('artist') is not None
36+
assert content.get('included')[1].get('attributes').get('artist') is not None
37+
assert content.get('included')[2].get('attributes').get('supervisor') is not None
38+
39+
def test_polymorphism_on_polymorphic_model_detail_patch(single_art_project, client):
40+
url = reverse("project-detail", kwargs={'pk': single_art_project.pk})
41+
response = client.get(url)
42+
content = load_json(response.content)
43+
test_topic = 'test-{}'.format(random.randint(0, 999999))
44+
test_artist = 'test-{}'.format(random.randint(0, 999999))
45+
content['data']['attributes']['topic'] = test_topic
46+
content['data']['attributes']['artist'] = test_artist
47+
response = client.patch(url, data=json.dumps(content), content_type='application/vnd.api+json')
48+
new_content = load_json(response.content)
49+
assert new_content["data"]["type"] == "artProjects"
50+
assert new_content['data']['attributes']['topic'] == test_topic
51+
assert new_content['data']['attributes']['artist'] == test_artist
52+
53+
def test_polymorphism_on_polymorphic_model_list_post(client):
54+
test_topic = 'New test topic {}'.format(random.randint(0, 999999))
55+
test_artist = 'test-{}'.format(random.randint(0, 999999))
56+
url = reverse('project-list')
57+
data = {
58+
'data': {
59+
'type': 'artProjects',
60+
'attributes': {
61+
'topic': test_topic,
62+
'artist': test_artist
63+
}
64+
}
65+
}
66+
response = client.post(url, data=json.dumps(data), content_type='application/vnd.api+json')
67+
content = load_json(response.content)
68+
assert content['data']['id'] is not None
69+
assert content["data"]["type"] == "artProjects"
70+
assert content['data']['attributes']['topic'] == test_topic
71+
assert content['data']['attributes']['artist'] == test_artist

rest_framework_json_api/parsers.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Parsers
33
"""
4+
import six
45
from rest_framework import parsers
56
from rest_framework.exceptions import ParseError
67

@@ -72,7 +73,11 @@ def parse(self, stream, media_type=None, parser_context=None):
7273

7374
# Check for inconsistencies
7475
resource_name = utils.get_resource_name(parser_context)
75-
if data.get('type') != resource_name and request.method in ('PUT', 'POST', 'PATCH'):
76+
if isinstance(resource_name, six.string_types):
77+
doesnt_match = data.get('type') != resource_name
78+
else:
79+
doesnt_match = data.get('type') not in resource_name
80+
if doesnt_match and request.method in ('PUT', 'POST', 'PATCH'):
7681
raise exceptions.Conflict(
7782
"The resource object's type ({data_type}) is not the type "
7883
"that constitute the collection represented by the endpoint ({resource_type}).".format(
@@ -82,7 +87,7 @@ def parse(self, stream, media_type=None, parser_context=None):
8287
)
8388

8489
# Construct the return data
85-
parsed_data = {'id': data.get('id')}
90+
parsed_data = {'id': data.get('id'), 'type': data.get('type')}
8691
parsed_data.update(self.parse_attributes(data))
8792
parsed_data.update(self.parse_relationships(data))
8893
return parsed_data

rest_framework_json_api/renderers.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,6 @@ def extract_included(fields, resource, resource_instance, included_resources):
289289
relation_type = utils.get_resource_type_from_serializer(serializer)
290290
relation_queryset = list(relation_instance_or_manager.all())
291291

292-
# Get the serializer fields
293-
serializer_fields = utils.get_serializer_fields(serializer)
294292
if serializer_data:
295293
for position in range(len(serializer_data)):
296294
serializer_resource = serializer_data[position]
@@ -299,6 +297,7 @@ def extract_included(fields, resource, resource_instance, included_resources):
299297
relation_type or
300298
utils.get_resource_type_from_instance(nested_resource_instance)
301299
)
300+
serializer_fields = utils.get_serializer_fields(serializer.__class__(nested_resource_instance, context=serializer.context))
302301
included_data.append(
303302
JSONRenderer.build_json_resource_obj(
304303
serializer_fields, serializer_resource, nested_resource_instance, resource_type

rest_framework_json_api/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ def get_resource_type_from_manager(manager):
231231

232232

233233
def get_resource_type_from_serializer(serializer):
234-
if hasattr(serializer.Meta, 'resource_name'):
234+
if hasattr(serializer, 'polymorphic_serializers'):
235+
return [get_resource_type_from_serializer(s['serializer']) for s in serializer.polymorphic_serializers]
236+
elif hasattr(serializer.Meta, 'resource_name'):
235237
return serializer.Meta.resource_name
236238
else:
237239
return get_resource_type_from_model(serializer.Meta.model)

0 commit comments

Comments
 (0)