diff --git a/.gitignore b/.gitignore
index 3177afc7..fe958047 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,6 +30,8 @@ pip-delete-this-directory.txt
# Tox
.tox/
+.cache/
+.python-version
# VirtualEnv
.venv/
diff --git a/docs/usage.md b/docs/usage.md
index 27caee0c..4a919fc1 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -375,7 +375,7 @@ class LineItemViewSet(viewsets.ModelViewSet):
### RelationshipView
`rest_framework_json_api.views.RelationshipView` is used to build
-relationship views (see the
+relationship views (see the
[JSON API spec](http://jsonapi.org/format/#fetching-relationships)).
The `self` link on a relationship object should point to the corresponding
relationship view.
@@ -423,6 +423,63 @@ field_name_mapping = {
```
+### Working with polymorphic resources
+
+#### Extraction of the polymorphic type
+
+This package can defer the resolution of the type of polymorphic models instances to retrieve the appropriate type.
+However, most models are not polymorphic and for performance reasons this is only done if the underlying model is a subclass of a polymorphic model.
+
+Polymorphic ancestors must be defined on settings like this:
+
+```python
+JSON_API_POLYMORPHIC_ANCESTORS = (
+ 'polymorphic.models.PolymorphicModel',
+)
+```
+
+#### Writing polymorphic resources
+
+A polymorphic endpoint can be setup if associated with a polymorphic serializer.
+A polymorphic serializer take care of (de)serializing the correct instances types and can be defined like this:
+
+```python
+class ProjectSerializer(serializers.PolymorphicModelSerializer):
+ polymorphic_serializers = [ArtProjectSerializer, ResearchProjectSerializer]
+
+ class Meta:
+ model = models.Project
+```
+
+It must inherit from `serializers.PolymorphicModelSerializer` and define the `polymorphic_serializers` list.
+This attribute defines the accepted resource types.
+
+
+Polymorphic relations can also be handled with `relations.PolymorphicResourceRelatedField` like this:
+
+```python
+class CompanySerializer(serializers.ModelSerializer):
+ current_project = relations.PolymorphicResourceRelatedField(
+ ProjectSerializer, queryset=models.Project.objects.all())
+ future_projects = relations.PolymorphicResourceRelatedField(
+ ProjectSerializer, queryset=models.Project.objects.all(), many=True)
+
+ class Meta:
+ model = models.Company
+```
+
+They must be explicitely declared with the `polymorphic_serializer` (first positional argument) correctly defined.
+It must be a subclass of `serializers.PolymorphicModelSerializer`.
+
+
+ Note:
+ Polymorphic resources are not compatible with
+
+ resource_name
+
+ defined on the view.
+
+
### Meta
You may add metadata to the rendered json in two different ways: `meta_fields` and `get_root_meta`.
diff --git a/example/factories/__init__.py b/example/factories/__init__.py
index 0119f925..db74cde3 100644
--- a/example/factories/__init__.py
+++ b/example/factories/__init__.py
@@ -2,21 +2,23 @@
import factory
from faker import Factory as FakerFactory
-from example.models import Blog, Author, AuthorBio, Entry, Comment
+from example import models
+
faker = FakerFactory.create()
faker.seed(983843)
+
class BlogFactory(factory.django.DjangoModelFactory):
class Meta:
- model = Blog
+ model = models.Blog
name = factory.LazyAttribute(lambda x: faker.name())
class AuthorFactory(factory.django.DjangoModelFactory):
class Meta:
- model = Author
+ model = models.Author
name = factory.LazyAttribute(lambda x: faker.name())
email = factory.LazyAttribute(lambda x: faker.email())
@@ -25,7 +27,7 @@ class Meta:
class AuthorBioFactory(factory.django.DjangoModelFactory):
class Meta:
- model = AuthorBio
+ model = models.AuthorBio
author = factory.SubFactory(AuthorFactory)
body = factory.LazyAttribute(lambda x: faker.text())
@@ -33,7 +35,7 @@ class Meta:
class EntryFactory(factory.django.DjangoModelFactory):
class Meta:
- model = Entry
+ model = models.Entry
headline = factory.LazyAttribute(lambda x: faker.sentence(nb_words=4))
body_text = factory.LazyAttribute(lambda x: faker.text())
@@ -52,9 +54,40 @@ def authors(self, create, extracted, **kwargs):
class CommentFactory(factory.django.DjangoModelFactory):
class Meta:
- model = Comment
+ model = models.Comment
entry = factory.SubFactory(EntryFactory)
body = factory.LazyAttribute(lambda x: faker.text())
author = factory.SubFactory(AuthorFactory)
+
+class ArtProjectFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = models.ArtProject
+
+ topic = factory.LazyAttribute(lambda x: faker.catch_phrase())
+ artist = factory.LazyAttribute(lambda x: faker.name())
+
+
+class ResearchProjectFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = models.ResearchProject
+
+ topic = factory.LazyAttribute(lambda x: faker.catch_phrase())
+ supervisor = factory.LazyAttribute(lambda x: faker.name())
+
+
+class CompanyFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = models.Company
+
+ name = factory.LazyAttribute(lambda x: faker.company())
+ current_project = factory.SubFactory(ArtProjectFactory)
+
+ @factory.post_generation
+ def future_projects(self, create, extracted, **kwargs):
+ if not create:
+ return
+ if extracted:
+ for project in extracted:
+ self.future_projects.add(project)
diff --git a/example/migrations/0002_auto_20160513_0857.py b/example/migrations/0002_auto_20160513_0857.py
new file mode 100644
index 00000000..2471ea36
--- /dev/null
+++ b/example/migrations/0002_auto_20160513_0857.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2016-05-13 08:57
+from __future__ import unicode_literals
+from distutils.version import LooseVersion
+
+from django.db import migrations, models
+import django.db.models.deletion
+import django
+
+
+class Migration(migrations.Migration):
+
+ # TODO: Must be removed as soon as Django 1.7 support is dropped
+ if django.get_version() < LooseVersion('1.8'):
+ dependencies = [
+ ('contenttypes', '0001_initial'),
+ ('example', '0001_initial'),
+ ]
+ else:
+ dependencies = [
+ ('contenttypes', '0002_remove_content_type_name'),
+ ('example', '0001_initial'),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='Company',
+ fields=[
+ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('name', models.CharField(max_length=100)),
+ ],
+ ),
+ migrations.CreateModel(
+ name='Project',
+ fields=[
+ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('topic', models.CharField(max_length=30)),
+ ],
+ options={
+ 'abstract': False,
+ },
+ ),
+ migrations.CreateModel(
+ name='ArtProject',
+ fields=[
+ ('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')),
+ ('artist', models.CharField(max_length=30)),
+ ],
+ options={
+ 'abstract': False,
+ },
+ bases=('example.project',),
+ ),
+ migrations.CreateModel(
+ name='ResearchProject',
+ fields=[
+ ('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')),
+ ('supervisor', models.CharField(max_length=30)),
+ ],
+ options={
+ 'abstract': False,
+ },
+ bases=('example.project',),
+ ),
+ migrations.AddField(
+ model_name='project',
+ name='polymorphic_ctype',
+ field=models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_example.project_set+', to='contenttypes.ContentType'),
+ ),
+ migrations.AddField(
+ model_name='company',
+ name='current_project',
+ field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='companies', to='example.Project'),
+ ),
+ migrations.AddField(
+ model_name='company',
+ name='future_projects',
+ field=models.ManyToManyField(to='example.Project'),
+ ),
+ ]
diff --git a/example/models.py b/example/models.py
index 7895722a..6bbaaf1b 100644
--- a/example/models.py
+++ b/example/models.py
@@ -3,6 +3,7 @@
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
+from polymorphic.models import PolymorphicModel
class BaseModel(models.Model):
@@ -72,3 +73,24 @@ class Comment(BaseModel):
def __str__(self):
return self.body
+
+class Project(PolymorphicModel):
+ topic = models.CharField(max_length=30)
+
+
+class ArtProject(Project):
+ artist = models.CharField(max_length=30)
+
+
+class ResearchProject(Project):
+ supervisor = models.CharField(max_length=30)
+
+
+@python_2_unicode_compatible
+class Company(models.Model):
+ name = models.CharField(max_length=100)
+ current_project = models.ForeignKey(Project, related_name='companies')
+ future_projects = models.ManyToManyField(Project)
+
+ def __str__(self):
+ return self.name
diff --git a/example/serializers.py b/example/serializers.py
index e259a10b..09119d1c 100644
--- a/example/serializers.py
+++ b/example/serializers.py
@@ -1,6 +1,6 @@
from datetime import datetime
from rest_framework_json_api import serializers, relations
-from example.models import Blog, Entry, Author, AuthorBio, Comment
+from example import models
class BlogSerializer(serializers.ModelSerializer):
@@ -12,11 +12,11 @@ def get_copyright(self, resource):
def get_root_meta(self, resource, many):
return {
- 'api_docs': '/docs/api/blogs'
+ 'api_docs': '/docs/api/blogs'
}
class Meta:
- model = Blog
+ model = models.Blog
fields = ('name', )
meta_fields = ('copyright',)
@@ -38,27 +38,27 @@ def __init__(self, *args, **kwargs):
}
body_format = serializers.SerializerMethodField()
- # many related from model
+ # Many related from model
comments = relations.ResourceRelatedField(
- source='comment_set', many=True, read_only=True)
- # many related from serializer
+ source='comment_set', many=True, read_only=True)
+ # Many related from serializer
suggested = relations.SerializerMethodResourceRelatedField(
- source='get_suggested', model=Entry, many=True, read_only=True)
- # single related from serializer
+ source='get_suggested', model=models.Entry, many=True, read_only=True)
+ # Single related from serializer
featured = relations.SerializerMethodResourceRelatedField(
- source='get_featured', model=Entry, read_only=True)
+ source='get_featured', model=models.Entry, read_only=True)
def get_suggested(self, obj):
- return Entry.objects.exclude(pk=obj.pk)
+ return models.Entry.objects.exclude(pk=obj.pk)
def get_featured(self, obj):
- return Entry.objects.exclude(pk=obj.pk).first()
+ return models.Entry.objects.exclude(pk=obj.pk).first()
def get_body_format(self, obj):
return 'text'
class Meta:
- model = Entry
+ model = models.Entry
fields = ('blog', 'headline', 'body_text', 'pub_date', 'mod_date',
'authors', 'comments', 'featured', 'suggested',)
meta_fields = ('body_format',)
@@ -67,7 +67,7 @@ class Meta:
class AuthorBioSerializer(serializers.ModelSerializer):
class Meta:
- model = AuthorBio
+ model = models.AuthorBio
fields = ('author', 'body',)
@@ -77,7 +77,7 @@ class AuthorSerializer(serializers.ModelSerializer):
}
class Meta:
- model = Author
+ model = models.Author
fields = ('name', 'email', 'bio')
@@ -88,6 +88,41 @@ class CommentSerializer(serializers.ModelSerializer):
}
class Meta:
- model = Comment
+ model = models.Comment
exclude = ('created_at', 'modified_at',)
# fields = ('entry', 'body', 'author',)
+
+
+class ArtProjectSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = models.ArtProject
+ exclude = ('polymorphic_ctype',)
+
+
+class ResearchProjectSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = models.ResearchProject
+ exclude = ('polymorphic_ctype',)
+
+
+class ProjectSerializer(serializers.PolymorphicModelSerializer):
+ polymorphic_serializers = [ArtProjectSerializer, ResearchProjectSerializer]
+
+ class Meta:
+ model = models.Project
+ exclude = ('polymorphic_ctype',)
+
+
+class CompanySerializer(serializers.ModelSerializer):
+ current_project = relations.PolymorphicResourceRelatedField(
+ ProjectSerializer, queryset=models.Project.objects.all())
+ future_projects = relations.PolymorphicResourceRelatedField(
+ ProjectSerializer, queryset=models.Project.objects.all(), many=True)
+
+ included_serializers = {
+ 'current_project': ProjectSerializer,
+ 'future_projects': ProjectSerializer,
+ }
+
+ class Meta:
+ model = models.Company
diff --git a/example/settings/dev.py b/example/settings/dev.py
index b4b435ca..5a59ba90 100644
--- a/example/settings/dev.py
+++ b/example/settings/dev.py
@@ -23,6 +23,7 @@
'django.contrib.auth',
'django.contrib.admin',
'rest_framework',
+ 'polymorphic',
'example',
]
diff --git a/example/settings/test.py b/example/settings/test.py
index 5bb3f45d..d0157138 100644
--- a/example/settings/test.py
+++ b/example/settings/test.py
@@ -15,3 +15,6 @@
REST_FRAMEWORK.update({
'PAGE_SIZE': 1,
})
+JSON_API_POLYMORPHIC_ANCESTORS = (
+ 'polymorphic.models.PolymorphicModel',
+)
diff --git a/example/tests/conftest.py b/example/tests/conftest.py
index 8a96cfdb..cb059f81 100644
--- a/example/tests/conftest.py
+++ b/example/tests/conftest.py
@@ -1,13 +1,16 @@
import pytest
from pytest_factoryboy import register
-from example.factories import BlogFactory, AuthorFactory, AuthorBioFactory, EntryFactory, CommentFactory
+from example import factories
-register(BlogFactory)
-register(AuthorFactory)
-register(AuthorBioFactory)
-register(EntryFactory)
-register(CommentFactory)
+register(factories.BlogFactory)
+register(factories.AuthorFactory)
+register(factories.AuthorBioFactory)
+register(factories.EntryFactory)
+register(factories.CommentFactory)
+register(factories.ArtProjectFactory)
+register(factories.ResearchProjectFactory)
+register(factories.CompanyFactory)
@pytest.fixture
@@ -29,3 +32,13 @@ def multiple_entries(blog_factory, author_factory, entry_factory, comment_factor
comment_factory(entry=entries[1])
return entries
+
+@pytest.fixture
+def single_company(art_project_factory, research_project_factory, company_factory):
+ company = company_factory(future_projects=(research_project_factory(), art_project_factory()))
+ return company
+
+
+@pytest.fixture
+def single_art_project(art_project_factory):
+ return art_project_factory()
diff --git a/example/tests/integration/test_polymorphism.py b/example/tests/integration/test_polymorphism.py
new file mode 100644
index 00000000..5b7fbb7b
--- /dev/null
+++ b/example/tests/integration/test_polymorphism.py
@@ -0,0 +1,136 @@
+import pytest
+import random
+import json
+from django.core.urlresolvers import reverse
+
+from example.tests.utils import load_json
+
+pytestmark = pytest.mark.django_db
+
+
+def test_polymorphism_on_detail(single_art_project, client):
+ response = client.get(reverse("project-detail", kwargs={'pk': single_art_project.pk}))
+ content = load_json(response.content)
+ assert content["data"]["type"] == "artProjects"
+
+
+def test_polymorphism_on_detail_relations(single_company, client):
+ response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk}))
+ content = load_json(response.content)
+ assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects"
+ assert [rel["type"] for rel in content["data"]["relationships"]["futureProjects"]["data"]] == [
+ "researchProjects", "artProjects"]
+
+
+def test_polymorphism_on_included_relations(single_company, client):
+ response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk}) +
+ '?include=current_project,future_projects')
+ content = load_json(response.content)
+ assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects"
+ assert [rel["type"] for rel in content["data"]["relationships"]["futureProjects"]["data"]] == [
+ "researchProjects", "artProjects"]
+ assert [x.get('type') for x in content.get('included')] == [
+ 'artProjects', 'artProjects', 'researchProjects'], 'Detail included types are incorrect'
+ # Ensure that the child fields are present.
+ assert content.get('included')[0].get('attributes').get('artist') is not None
+ assert content.get('included')[1].get('attributes').get('artist') is not None
+ assert content.get('included')[2].get('attributes').get('supervisor') is not None
+
+
+def test_polymorphism_on_polymorphic_model_detail_patch(single_art_project, client):
+ url = reverse("project-detail", kwargs={'pk': single_art_project.pk})
+ response = client.get(url)
+ content = load_json(response.content)
+ test_topic = 'test-{}'.format(random.randint(0, 999999))
+ test_artist = 'test-{}'.format(random.randint(0, 999999))
+ content['data']['attributes']['topic'] = test_topic
+ content['data']['attributes']['artist'] = test_artist
+ response = client.patch(url, data=json.dumps(content), content_type='application/vnd.api+json')
+ new_content = load_json(response.content)
+ assert new_content["data"]["type"] == "artProjects"
+ assert new_content['data']['attributes']['topic'] == test_topic
+ assert new_content['data']['attributes']['artist'] == test_artist
+
+
+def test_polymorphism_on_polymorphic_model_list_post(client):
+ test_topic = 'New test topic {}'.format(random.randint(0, 999999))
+ test_artist = 'test-{}'.format(random.randint(0, 999999))
+ url = reverse('project-list')
+ data = {
+ 'data': {
+ 'type': 'artProjects',
+ 'attributes': {
+ 'topic': test_topic,
+ 'artist': test_artist
+ }
+ }
+ }
+ response = client.post(url, data=json.dumps(data), content_type='application/vnd.api+json')
+ content = load_json(response.content)
+ assert content['data']['id'] is not None
+ assert content["data"]["type"] == "artProjects"
+ assert content['data']['attributes']['topic'] == test_topic
+ assert content['data']['attributes']['artist'] == test_artist
+
+
+def test_invalid_type_on_polymorphic_model(client):
+ test_topic = 'New test topic {}'.format(random.randint(0, 999999))
+ test_artist = 'test-{}'.format(random.randint(0, 999999))
+ url = reverse('project-list')
+ data = {
+ 'data': {
+ 'type': 'invalidProjects',
+ 'attributes': {
+ 'topic': test_topic,
+ 'artist': test_artist
+ }
+ }
+ }
+ response = client.post(url, data=json.dumps(data), content_type='application/vnd.api+json')
+ assert response.status_code == 409
+ content = load_json(response.content)
+ assert len(content["errors"]) is 1
+ assert content["errors"][0]["status"] == "409"
+ assert content["errors"][0]["detail"] == \
+ "The resource object's type (invalidProjects) is not the type that constitute the " \
+ "collection represented by the endpoint (one of [researchProjects, artProjects])."
+
+
+def test_polymorphism_relations_update(single_company, research_project_factory, client):
+ response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk}))
+ content = load_json(response.content)
+ assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects"
+
+ research_project = research_project_factory()
+ content["data"]["relationships"]["currentProject"]["data"] = {
+ "type": "researchProjects",
+ "id": research_project.pk
+ }
+ response = client.put(reverse("company-detail", kwargs={'pk': single_company.pk}),
+ data=json.dumps(content), content_type='application/vnd.api+json')
+ assert response.status_code == 200
+ content = load_json(response.content)
+ assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "researchProjects"
+ assert int(content["data"]["relationships"]["currentProject"]["data"]["id"]) == \
+ research_project.pk
+
+
+def test_invalid_type_on_polymorphic_relation(single_company, research_project_factory, client):
+ response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk}))
+ content = load_json(response.content)
+ assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects"
+
+ research_project = research_project_factory()
+ content["data"]["relationships"]["currentProject"]["data"] = {
+ "type": "invalidProjects",
+ "id": research_project.pk
+ }
+ response = client.put(reverse("company-detail", kwargs={'pk': single_company.pk}),
+ data=json.dumps(content), content_type='application/vnd.api+json')
+ assert response.status_code == 409
+ content = load_json(response.content)
+ assert len(content["errors"]) is 1
+ assert content["errors"][0]["status"] == "409"
+ assert content["errors"][0]["detail"] == \
+ "Incorrect relation type. Expected one of [researchProjects, artProjects], " \
+ "received invalidProjects."
diff --git a/example/urls.py b/example/urls.py
index f48135c7..4443960f 100644
--- a/example/urls.py
+++ b/example/urls.py
@@ -1,7 +1,8 @@
from django.conf.urls import include, url
from rest_framework import routers
-from example.views import BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet
+from example.views import (
+ BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet, CompanyViewset, ProjectViewset)
router = routers.DefaultRouter(trailing_slash=False)
@@ -9,6 +10,8 @@
router.register(r'entries', EntryViewSet)
router.register(r'authors', AuthorViewSet)
router.register(r'comments', CommentViewSet)
+router.register(r'companies', CompanyViewset)
+router.register(r'projects', ProjectViewset)
urlpatterns = [
url(r'^', include(router.urls)),
diff --git a/example/urls_test.py b/example/urls_test.py
index 0f8ed73b..21f29fd1 100644
--- a/example/urls_test.py
+++ b/example/urls_test.py
@@ -1,8 +1,9 @@
from django.conf.urls import include, url
from rest_framework import routers
-from example.views import BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet, EntryRelationshipView, BlogRelationshipView, \
- CommentRelationshipView, AuthorRelationshipView
+from example.views import (
+ BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet, CompanyViewset, ProjectViewset,
+ EntryRelationshipView, BlogRelationshipView, CommentRelationshipView, AuthorRelationshipView)
from .api.resources.identity import Identity, GenericIdentity
router = routers.DefaultRouter(trailing_slash=False)
@@ -11,6 +12,8 @@
router.register(r'entries', EntryViewSet)
router.register(r'authors', AuthorViewSet)
router.register(r'comments', CommentViewSet)
+router.register(r'companies', CompanyViewset)
+router.register(r'projects', ProjectViewset)
# for the old tests
router.register(r'identities', Identity)
@@ -36,4 +39,3 @@
AuthorRelationshipView.as_view(),
name='author-relationships'),
]
-
diff --git a/example/views.py b/example/views.py
index 988cda66..e32db8c0 100644
--- a/example/views.py
+++ b/example/views.py
@@ -6,9 +6,10 @@
import rest_framework_json_api.parsers
import rest_framework_json_api.renderers
from rest_framework_json_api.views import RelationshipView
-from example.models import Blog, Entry, Author, Comment
+from example.models import Blog, Entry, Author, Comment, Company, Project
from example.serializers import (
- BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer)
+ BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer, CompanySerializer,
+ ProjectSerializer)
from rest_framework_json_api.utils import format_drf_errors
@@ -72,6 +73,16 @@ class CommentViewSet(viewsets.ModelViewSet):
serializer_class = CommentSerializer
+class CompanyViewset(viewsets.ModelViewSet):
+ queryset = Company.objects.all()
+ serializer_class = CompanySerializer
+
+
+class ProjectViewset(viewsets.ModelViewSet):
+ queryset = Project.objects.all()
+ serializer_class = ProjectSerializer
+
+
class EntryRelationshipView(RelationshipView):
queryset = Entry.objects.all()
diff --git a/requirements-development.txt b/requirements-development.txt
index 6aa243bd..b5e25321 100644
--- a/requirements-development.txt
+++ b/requirements-development.txt
@@ -3,4 +3,5 @@ pytest==2.8.2
pytest-django
pytest-factoryboy
fake-factory
+django-polymorphic
tox
diff --git a/rest_framework_json_api/parsers.py b/rest_framework_json_api/parsers.py
index 30b9ad0e..f1cd6abe 100644
--- a/rest_framework_json_api/parsers.py
+++ b/rest_framework_json_api/parsers.py
@@ -1,6 +1,7 @@
"""
Parsers
"""
+import six
from rest_framework import parsers
from rest_framework.exceptions import ParseError
@@ -29,7 +30,8 @@ class JSONParser(parsers.JSONParser):
@staticmethod
def parse_attributes(data):
- return utils.format_keys(data.get('attributes'), 'underscore') if data.get('attributes') else dict()
+ return utils.format_keys(
+ data.get('attributes'), 'underscore') if data.get('attributes') else dict()
@staticmethod
def parse_relationships(data):
@@ -50,39 +52,52 @@ def parse(self, stream, media_type=None, parser_context=None):
"""
Parses the incoming bytestream as JSON and returns the resulting data
"""
- result = super(JSONParser, self).parse(stream, media_type=media_type, parser_context=parser_context)
+ result = super(JSONParser, self).parse(
+ stream, media_type=media_type, parser_context=parser_context)
data = result.get('data')
if data:
from rest_framework_json_api.views import RelationshipView
if isinstance(parser_context['view'], RelationshipView):
- # We skip parsing the object as JSONAPI Resource Identifier Object and not a regular Resource Object
+ # We skip parsing the object as JSONAPI Resource Identifier Object is not a
+ # regular Resource Object
if isinstance(data, list):
for resource_identifier_object in data:
- if not (resource_identifier_object.get('id') and resource_identifier_object.get('type')):
- raise ParseError(
- 'Received data contains one or more malformed JSONAPI Resource Identifier Object(s)'
- )
+ if not (resource_identifier_object.get('id') and
+ resource_identifier_object.get('type')):
+ raise ParseError('Received data contains one or more malformed '
+ 'JSONAPI Resource Identifier Object(s)')
elif not (data.get('id') and data.get('type')):
- raise ParseError('Received data is not a valid JSONAPI Resource Identifier Object')
+ raise ParseError('Received data is not a valid '
+ 'JSONAPI Resource Identifier Object')
return data
request = parser_context.get('request')
# Check for inconsistencies
- resource_name = utils.get_resource_name(parser_context)
- if data.get('type') != resource_name and request.method in ('PUT', 'POST', 'PATCH'):
- raise exceptions.Conflict(
- "The resource object's type ({data_type}) is not the type "
- "that constitute the collection represented by the endpoint ({resource_type}).".format(
- data_type=data.get('type'),
- resource_type=resource_name
- )
- )
+ if request.method in ('PUT', 'POST', 'PATCH'):
+ resource_name = utils.get_resource_name(
+ parser_context, expand_polymorphic_types=True)
+ if isinstance(resource_name, six.string_types):
+ if data.get('type') != resource_name:
+ raise exceptions.Conflict(
+ "The resource object's type ({data_type}) is not the type that "
+ "constitute the collection represented by the endpoint "
+ "({resource_type}).".format(
+ data_type=data.get('type'),
+ resource_type=resource_name))
+ else:
+ if data.get('type') not in resource_name:
+ raise exceptions.Conflict(
+ "The resource object's type ({data_type}) is not the type that "
+ "constitute the collection represented by the endpoint "
+ "(one of [{resource_types}]).".format(
+ data_type=data.get('type'),
+ resource_types=", ".join(resource_name)))
# Construct the return data
- parsed_data = {'id': data.get('id')}
+ parsed_data = {'id': data.get('id'), 'type': data.get('type')}
parsed_data.update(self.parse_attributes(data))
parsed_data.update(self.parse_relationships(data))
return parsed_data
diff --git a/rest_framework_json_api/relations.py b/rest_framework_json_api/relations.py
index 0e6594d5..471217ea 100644
--- a/rest_framework_json_api/relations.py
+++ b/rest_framework_json_api/relations.py
@@ -6,12 +6,13 @@
from django.db.models.query import QuerySet
from rest_framework_json_api.exceptions import Conflict
-from rest_framework_json_api.utils import Hyperlink, \
+from rest_framework_json_api.utils import POLYMORPHIC_ANCESTORS, Hyperlink, \
get_resource_type_from_queryset, get_resource_type_from_instance, \
get_included_serializers, get_resource_type_from_serializer
class ResourceRelatedField(PrimaryKeyRelatedField):
+ _skip_polymorphic_optimization = True
self_link_view_name = None
related_link_view_name = None
related_link_lookup_field = 'pk'
@@ -47,6 +48,12 @@ def __init__(self, self_link_view_name=None, related_link_view_name=None, **kwar
super(ResourceRelatedField, self).__init__(**kwargs)
+ # Determine if relation is polymorphic
+ self.is_polymorphic = False
+ model = model or getattr(self.get_queryset(), 'model', None)
+ if model and issubclass(model, POLYMORPHIC_ANCESTORS):
+ self.is_polymorphic = True
+
def use_pk_only_optimization(self):
# We need the real object to determine its type...
return False
@@ -129,7 +136,8 @@ def to_internal_value(self, data):
self.fail('missing_id')
if data['type'] != expected_relation_type:
- self.conflict('incorrect_relation_type', relation_type=expected_relation_type, received_type=data['type'])
+ self.conflict('incorrect_relation_type', relation_type=expected_relation_type,
+ received_type=data['type'])
return super(ResourceRelatedField, self).to_internal_value(data['id'])
@@ -144,7 +152,8 @@ def to_representation(self, value):
resource_type = None
root = getattr(self.parent, 'parent', self.parent)
field_name = self.field_name if self.field_name else self.parent.field_name
- if getattr(root, 'included_serializers', None) is not None:
+ if getattr(root, 'included_serializers', None) is not None and \
+ self._skip_polymorphic_optimization:
includes = get_included_serializers(root)
if field_name in includes.keys():
resource_type = get_resource_type_from_serializer(includes[field_name])
@@ -169,6 +178,42 @@ def choices(self):
])
+class PolymorphicResourceRelatedField(ResourceRelatedField):
+
+ _skip_polymorphic_optimization = False
+ default_error_messages = dict(ResourceRelatedField.default_error_messages, **{
+ 'incorrect_relation_type': _('Incorrect relation type. Expected one of [{relation_type}], '
+ 'received {received_type}.'),
+ })
+
+ def __init__(self, polymorphic_serializer, *args, **kwargs):
+ self.polymorphic_serializer = polymorphic_serializer
+ super(PolymorphicResourceRelatedField, self).__init__(*args, **kwargs)
+
+ def to_internal_value(self, data):
+ if isinstance(data, six.text_type):
+ try:
+ data = json.loads(data)
+ except ValueError:
+ # show a useful error if they send a `pk` instead of resource object
+ self.fail('incorrect_type', data_type=type(data).__name__)
+ if not isinstance(data, dict):
+ self.fail('incorrect_type', data_type=type(data).__name__)
+
+ if 'type' not in data:
+ self.fail('missing_type')
+
+ if 'id' not in data:
+ self.fail('missing_id')
+
+ expected_relation_types = self.polymorphic_serializer.get_polymorphic_types()
+
+ if data['type'] not in expected_relation_types:
+ self.conflict('incorrect_relation_type', relation_type=", ".join(
+ expected_relation_types), received_type=data['type'])
+
+ return super(ResourceRelatedField, self).to_internal_value(data['id'])
+
class SerializerMethodResourceRelatedField(ResourceRelatedField):
"""
diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py
index 1c66c927..16531fb0 100644
--- a/rest_framework_json_api/renderers.py
+++ b/rest_framework_json_api/renderers.py
@@ -289,8 +289,6 @@ def extract_included(fields, resource, resource_instance, included_resources):
relation_type = utils.get_resource_type_from_serializer(serializer)
relation_queryset = list(relation_instance_or_manager.all())
- # Get the serializer fields
- serializer_fields = utils.get_serializer_fields(serializer)
if serializer_data:
for position in range(len(serializer_data)):
serializer_resource = serializer_data[position]
@@ -299,6 +297,7 @@ def extract_included(fields, resource, resource_instance, included_resources):
relation_type or
utils.get_resource_type_from_instance(nested_resource_instance)
)
+ serializer_fields = utils.get_serializer_fields(serializer.__class__(nested_resource_instance, context=serializer.context))
included_data.append(
JSONRenderer.build_json_resource_obj(
serializer_fields, serializer_resource, nested_resource_instance, resource_type
@@ -360,6 +359,9 @@ def extract_root_meta(serializer, resource):
@staticmethod
def build_json_resource_obj(fields, resource, resource_instance, resource_name):
+ # Determine type from the instance if the underlying model is polymorphic
+ if isinstance(resource_instance, utils.POLYMORPHIC_ANCESTORS):
+ resource_name = utils.get_resource_type_from_instance(resource_instance)
resource_data = [
('type', resource_name),
('id', encoding.force_text(resource_instance.pk) if resource_instance else None),
diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py
index 953c4437..a5457e1c 100644
--- a/rest_framework_json_api/serializers.py
+++ b/rest_framework_json_api/serializers.py
@@ -1,8 +1,11 @@
+from django.db.models.query import QuerySet
from django.utils.translation import ugettext_lazy as _
+from django.utils import six
from rest_framework.exceptions import ParseError
from rest_framework.serializers import *
from rest_framework_json_api.relations import ResourceRelatedField
+from rest_framework_json_api.exceptions import Conflict
from rest_framework_json_api.utils import (
get_resource_type_from_model, get_resource_type_from_instance,
get_resource_type_from_serializer, get_included_serializers)
@@ -10,7 +13,8 @@
class ResourceIdentifierObjectSerializer(BaseSerializer):
default_error_messages = {
- 'incorrect_model_type': _('Incorrect model type. Expected {model_type}, received {received_type}.'),
+ 'incorrect_model_type': _('Incorrect model type. Expected {model_type}, '
+ 'received {received_type}.'),
'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
}
@@ -20,7 +24,8 @@ class ResourceIdentifierObjectSerializer(BaseSerializer):
def __init__(self, *args, **kwargs):
self.model_class = kwargs.pop('model_class', self.model_class)
if 'instance' not in kwargs and not self.model_class:
- raise RuntimeError('ResourceIdentifierObjectsSerializer must be initialized with a model class.')
+ raise RuntimeError(
+ 'ResourceIdentifierObjectsSerializer must be initialized with a model class.')
super(ResourceIdentifierObjectSerializer, self).__init__(*args, **kwargs)
def to_representation(self, instance):
@@ -31,7 +36,8 @@ def to_representation(self, instance):
def to_internal_value(self, data):
if data['type'] != get_resource_type_from_model(self.model_class):
- self.fail('incorrect_model_type', model_type=self.model_class, received_type=data['type'])
+ self.fail(
+ 'incorrect_model_type', model_type=self.model_class, received_type=data['type'])
pk = data['id']
try:
return self.model_class.objects.get(pk=pk)
@@ -47,15 +53,18 @@ def __init__(self, *args, **kwargs):
request = context.get('request') if context else None
if request:
- sparse_fieldset_query_param = 'fields[{}]'.format(get_resource_type_from_serializer(self))
+ sparse_fieldset_query_param = 'fields[{}]'.format(
+ get_resource_type_from_serializer(self))
try:
- param_name = next(key for key in request.query_params if sparse_fieldset_query_param in key)
+ param_name = next(
+ key for key in request.query_params if sparse_fieldset_query_param in key)
except StopIteration:
pass
else:
fieldset = request.query_params.get(param_name).split(',')
- # iterate over a *copy* of self.fields' underlying OrderedDict, because we may modify the
- # original during the iteration. self.fields is a `rest_framework.utils.serializer_helpers.BindingDict`
+ # Iterate over a *copy* of self.fields' underlying OrderedDict, because we may
+ # modify the original during the iteration.
+ # self.fields is a `rest_framework.utils.serializer_helpers.BindingDict`
for field_name, field in self.fields.fields.copy().items():
if field_name == api_settings.URL_FIELD_NAME: # leave self link there
continue
@@ -101,7 +110,8 @@ def validate_path(serializer_class, field_path, path):
super(IncludedResourcesValidationMixin, self).__init__(*args, **kwargs)
-class HyperlinkedModelSerializer(IncludedResourcesValidationMixin, SparseFieldsetsMixin, HyperlinkedModelSerializer):
+class HyperlinkedModelSerializer(IncludedResourcesValidationMixin, SparseFieldsetsMixin,
+ HyperlinkedModelSerializer):
"""
A type of `ModelSerializer` that uses hyperlinked relationships instead
of primary key relationships. Specifically:
@@ -152,3 +162,134 @@ def get_field_names(self, declared_fields, info):
declared[field_name] = field
fields = super(ModelSerializer, self).get_field_names(declared, info)
return list(fields) + list(getattr(self.Meta, 'meta_fields', list()))
+
+
+class PolymorphicSerializerMetaclass(SerializerMetaclass):
+ """
+ This metaclass ensures that the `polymorphic_serializers` is correctly defined on a
+ `PolymorphicSerializer` class and make a cache of model/serializer/type mappings.
+ """
+
+ def __new__(cls, name, bases, attrs):
+ new_class = super(PolymorphicSerializerMetaclass, cls).__new__(cls, name, bases, attrs)
+
+ # Ensure initialization is only performed for subclasses of PolymorphicModelSerializer
+ # (excluding PolymorphicModelSerializer class itself).
+ parents = [b for b in bases if isinstance(b, PolymorphicSerializerMetaclass)]
+ if not parents:
+ return new_class
+
+ polymorphic_serializers = getattr(new_class, 'polymorphic_serializers', None)
+ if not polymorphic_serializers:
+ raise NotImplementedError(
+ "A PolymorphicModelSerializer must define a `polymorphic_serializers` attribute.")
+ serializer_to_model = {
+ serializer: serializer.Meta.model for serializer in polymorphic_serializers}
+ model_to_serializer = {
+ serializer.Meta.model: serializer for serializer in polymorphic_serializers}
+ type_to_serializer = {
+ get_resource_type_from_serializer(serializer): serializer for
+ serializer in polymorphic_serializers}
+ setattr(new_class, '_poly_serializer_model_map', serializer_to_model)
+ setattr(new_class, '_poly_model_serializer_map', model_to_serializer)
+ setattr(new_class, '_poly_type_serializer_map', type_to_serializer)
+ return new_class
+
+
+@six.add_metaclass(PolymorphicSerializerMetaclass)
+class PolymorphicModelSerializer(ModelSerializer):
+ """
+ A serializer for polymorphic models.
+ Useful for "lazy" parent models. Leaves should be represented with a regular serializer.
+ """
+ def get_fields(self):
+ """
+ Return an exhaustive list of the polymorphic serializer fields.
+ """
+ if self.instance is not None:
+ if not isinstance(self.instance, QuerySet):
+ serializer_class = self.get_polymorphic_serializer_for_instance(self.instance)
+ return serializer_class(self.instance, context=self.context).get_fields()
+ else:
+ raise Exception("Cannot get fields from a polymorphic serializer given a queryset")
+ return super(PolymorphicModelSerializer, self).get_fields()
+
+ @classmethod
+ def get_polymorphic_serializer_for_instance(cls, instance):
+ """
+ Return the polymorphic serializer associated with the given instance/model.
+ Raise `NotImplementedError` if no serializer is found for the given model. This usually
+ means that a serializer is missing in the class's `polymorphic_serializers` attribute.
+ """
+ try:
+ return cls._poly_model_serializer_map[instance._meta.model]
+ except KeyError:
+ raise NotImplementedError(
+ "No polymorphic serializer has been found for model {}".format(
+ instance._meta.model.__name__))
+
+ @classmethod
+ def get_polymorphic_model_for_serializer(cls, serializer):
+ """
+ Return the polymorphic model associated with the given serializer.
+ Raise `NotImplementedError` if no model is found for the given serializer. This usually
+ means that a serializer is missing in the class's `polymorphic_serializers` attribute.
+ """
+ try:
+ return cls._poly_serializer_model_map[serializer]
+ except KeyError:
+ raise NotImplementedError(
+ "No polymorphic model has been found for serializer {}".format(serializer.__name__))
+
+ @classmethod
+ def get_polymorphic_serializer_for_type(cls, obj_type):
+ """
+ Return the polymorphic serializer associated with the given type.
+ Raise `NotImplementedError` if no serializer is found for the given type. This usually
+ means that a serializer is missing in the class's `polymorphic_serializers` attribute.
+ """
+ try:
+ return cls._poly_type_serializer_map[obj_type]
+ except KeyError:
+ raise NotImplementedError(
+ "No polymorphic serializer has been found for type {}".format(obj_type))
+
+ @classmethod
+ def get_polymorphic_model_for_type(cls, obj_type):
+ """
+ Return the polymorphic model associated with the given type.
+ Raise `NotImplementedError` if no model is found for the given type. This usually
+ means that a serializer is missing in the class's `polymorphic_serializers` attribute.
+ """
+ return cls.get_polymorphic_model_for_serializer(
+ cls.get_polymorphic_serializer_for_type(obj_type))
+
+ @classmethod
+ def get_polymorphic_types(cls):
+ """
+ Return the list of accepted types.
+ """
+ return cls._poly_type_serializer_map.keys()
+
+ def to_representation(self, instance):
+ """
+ Retrieve the appropriate polymorphic serializer and use this to handle representation.
+ """
+ serializer_class = self.get_polymorphic_serializer_for_instance(instance)
+ return serializer_class(instance, context=self.context).to_representation(instance)
+
+ def to_internal_value(self, data):
+ """
+ Ensure that the given type is one of the expected polymorphic types, then retrieve the
+ appropriate polymorphic serializer and use this to handle internal value.
+ """
+ received_type = data.get('type')
+ expected_types = self.get_polymorphic_types()
+ if received_type not in expected_types:
+ raise Conflict(
+ 'Incorrect relation type. Expected on of [{expected_types}], '
+ 'received {received_type}.'.format(
+ expected_types=', '.join(expected_types), received_type=received_type))
+ serializer_class = self.get_polymorphic_serializer_for_type(received_type)
+ self.__class__ = serializer_class
+ return serializer_class(data, context=self.context).to_internal_value(data)
diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py
index 261640c6..95e036ca 100644
--- a/rest_framework_json_api/utils.py
+++ b/rest_framework_json_api/utils.py
@@ -25,11 +25,17 @@
except ImportError:
HyperlinkedRouterField = type(None)
+POLYMORPHIC_ANCESTORS = ()
+for ancestor in getattr(settings, 'JSON_API_POLYMORPHIC_ANCESTORS', ()):
+ ancestor_class = import_class_from_dotted_path(ancestor)
+ POLYMORPHIC_ANCESTORS += (ancestor_class,)
-def get_resource_name(context):
+
+def get_resource_name(context, expand_polymorphic_types=False):
"""
Return the name of a resource.
"""
+ from . import serializers
view = context.get('view')
# Sanity check to make sure we have a view.
@@ -51,7 +57,11 @@ def get_resource_name(context):
except AttributeError:
try:
serializer = view.get_serializer_class()
- return get_resource_type_from_serializer(serializer)
+ if issubclass(serializer, serializers.PolymorphicModelSerializer) and \
+ expand_polymorphic_types:
+ return serializer.get_polymorphic_types()
+ else:
+ return get_resource_type_from_serializer(serializer)
except AttributeError:
try:
resource_name = get_resource_type_from_model(view.model)
@@ -86,6 +96,7 @@ def get_serializer_fields(serializer):
pass
return fields
+
def format_keys(obj, format_type=None):
"""
Takes either a dict or list and returns it with camelized keys only if
@@ -141,12 +152,15 @@ def format_value(value, format_type=None):
def format_relation_name(value, format_type=None):
- warnings.warn("The 'format_relation_name' function has been renamed 'format_resource_type' and the settings are now 'JSON_API_FORMAT_TYPES' and 'JSON_API_PLURALIZE_TYPES'")
+ warnings.warn(
+ "The 'format_relation_name' function has been renamed 'format_resource_type' and "
+ "the settings are now 'JSON_API_FORMAT_TYPES' and 'JSON_API_PLURALIZE_TYPES'")
if format_type is None:
format_type = getattr(settings, 'JSON_API_FORMAT_RELATION_KEYS', None)
pluralize = getattr(settings, 'JSON_API_PLURALIZE_RELATION_TYPE', None)
return format_resource_type(value, format_type, pluralize)
+
def format_resource_type(value, format_type=None, pluralize=None):
if format_type is None:
format_type = getattr(settings, 'JSON_API_FORMAT_TYPES', False)