diff --git a/.gitignore b/.gitignore index 3177afc7..fe958047 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,8 @@ pip-delete-this-directory.txt # Tox .tox/ +.cache/ +.python-version # VirtualEnv .venv/ diff --git a/docs/usage.md b/docs/usage.md index 27caee0c..4a919fc1 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -375,7 +375,7 @@ class LineItemViewSet(viewsets.ModelViewSet): ### RelationshipView `rest_framework_json_api.views.RelationshipView` is used to build -relationship views (see the +relationship views (see the [JSON API spec](http://jsonapi.org/format/#fetching-relationships)). The `self` link on a relationship object should point to the corresponding relationship view. @@ -423,6 +423,63 @@ field_name_mapping = { ``` +### Working with polymorphic resources + +#### Extraction of the polymorphic type + +This package can defer the resolution of the type of polymorphic models instances to retrieve the appropriate type. +However, most models are not polymorphic and for performance reasons this is only done if the underlying model is a subclass of a polymorphic model. + +Polymorphic ancestors must be defined on settings like this: + +```python +JSON_API_POLYMORPHIC_ANCESTORS = ( + 'polymorphic.models.PolymorphicModel', +) +``` + +#### Writing polymorphic resources + +A polymorphic endpoint can be setup if associated with a polymorphic serializer. +A polymorphic serializer take care of (de)serializing the correct instances types and can be defined like this: + +```python +class ProjectSerializer(serializers.PolymorphicModelSerializer): + polymorphic_serializers = [ArtProjectSerializer, ResearchProjectSerializer] + + class Meta: + model = models.Project +``` + +It must inherit from `serializers.PolymorphicModelSerializer` and define the `polymorphic_serializers` list. +This attribute defines the accepted resource types. + + +Polymorphic relations can also be handled with `relations.PolymorphicResourceRelatedField` like this: + +```python +class CompanySerializer(serializers.ModelSerializer): + current_project = relations.PolymorphicResourceRelatedField( + ProjectSerializer, queryset=models.Project.objects.all()) + future_projects = relations.PolymorphicResourceRelatedField( + ProjectSerializer, queryset=models.Project.objects.all(), many=True) + + class Meta: + model = models.Company +``` + +They must be explicitely declared with the `polymorphic_serializer` (first positional argument) correctly defined. +It must be a subclass of `serializers.PolymorphicModelSerializer`. + +
+ Note: + Polymorphic resources are not compatible with + + resource_name + + defined on the view. +
+ ### Meta You may add metadata to the rendered json in two different ways: `meta_fields` and `get_root_meta`. diff --git a/example/factories/__init__.py b/example/factories/__init__.py index 0119f925..db74cde3 100644 --- a/example/factories/__init__.py +++ b/example/factories/__init__.py @@ -2,21 +2,23 @@ import factory from faker import Factory as FakerFactory -from example.models import Blog, Author, AuthorBio, Entry, Comment +from example import models + faker = FakerFactory.create() faker.seed(983843) + class BlogFactory(factory.django.DjangoModelFactory): class Meta: - model = Blog + model = models.Blog name = factory.LazyAttribute(lambda x: faker.name()) class AuthorFactory(factory.django.DjangoModelFactory): class Meta: - model = Author + model = models.Author name = factory.LazyAttribute(lambda x: faker.name()) email = factory.LazyAttribute(lambda x: faker.email()) @@ -25,7 +27,7 @@ class Meta: class AuthorBioFactory(factory.django.DjangoModelFactory): class Meta: - model = AuthorBio + model = models.AuthorBio author = factory.SubFactory(AuthorFactory) body = factory.LazyAttribute(lambda x: faker.text()) @@ -33,7 +35,7 @@ class Meta: class EntryFactory(factory.django.DjangoModelFactory): class Meta: - model = Entry + model = models.Entry headline = factory.LazyAttribute(lambda x: faker.sentence(nb_words=4)) body_text = factory.LazyAttribute(lambda x: faker.text()) @@ -52,9 +54,40 @@ def authors(self, create, extracted, **kwargs): class CommentFactory(factory.django.DjangoModelFactory): class Meta: - model = Comment + model = models.Comment entry = factory.SubFactory(EntryFactory) body = factory.LazyAttribute(lambda x: faker.text()) author = factory.SubFactory(AuthorFactory) + +class ArtProjectFactory(factory.django.DjangoModelFactory): + class Meta: + model = models.ArtProject + + topic = factory.LazyAttribute(lambda x: faker.catch_phrase()) + artist = factory.LazyAttribute(lambda x: faker.name()) + + +class ResearchProjectFactory(factory.django.DjangoModelFactory): + class Meta: + model = models.ResearchProject + + topic = factory.LazyAttribute(lambda x: faker.catch_phrase()) + supervisor = factory.LazyAttribute(lambda x: faker.name()) + + +class CompanyFactory(factory.django.DjangoModelFactory): + class Meta: + model = models.Company + + name = factory.LazyAttribute(lambda x: faker.company()) + current_project = factory.SubFactory(ArtProjectFactory) + + @factory.post_generation + def future_projects(self, create, extracted, **kwargs): + if not create: + return + if extracted: + for project in extracted: + self.future_projects.add(project) diff --git a/example/migrations/0002_auto_20160513_0857.py b/example/migrations/0002_auto_20160513_0857.py new file mode 100644 index 00000000..2471ea36 --- /dev/null +++ b/example/migrations/0002_auto_20160513_0857.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.9.6 on 2016-05-13 08:57 +from __future__ import unicode_literals +from distutils.version import LooseVersion + +from django.db import migrations, models +import django.db.models.deletion +import django + + +class Migration(migrations.Migration): + + # TODO: Must be removed as soon as Django 1.7 support is dropped + if django.get_version() < LooseVersion('1.8'): + dependencies = [ + ('contenttypes', '0001_initial'), + ('example', '0001_initial'), + ] + else: + dependencies = [ + ('contenttypes', '0002_remove_content_type_name'), + ('example', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Company', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ], + ), + migrations.CreateModel( + name='Project', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('topic', models.CharField(max_length=30)), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='ArtProject', + fields=[ + ('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')), + ('artist', models.CharField(max_length=30)), + ], + options={ + 'abstract': False, + }, + bases=('example.project',), + ), + migrations.CreateModel( + name='ResearchProject', + fields=[ + ('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')), + ('supervisor', models.CharField(max_length=30)), + ], + options={ + 'abstract': False, + }, + bases=('example.project',), + ), + migrations.AddField( + model_name='project', + name='polymorphic_ctype', + field=models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_example.project_set+', to='contenttypes.ContentType'), + ), + migrations.AddField( + model_name='company', + name='current_project', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='companies', to='example.Project'), + ), + migrations.AddField( + model_name='company', + name='future_projects', + field=models.ManyToManyField(to='example.Project'), + ), + ] diff --git a/example/models.py b/example/models.py index 7895722a..6bbaaf1b 100644 --- a/example/models.py +++ b/example/models.py @@ -3,6 +3,7 @@ from django.db import models from django.utils.encoding import python_2_unicode_compatible +from polymorphic.models import PolymorphicModel class BaseModel(models.Model): @@ -72,3 +73,24 @@ class Comment(BaseModel): def __str__(self): return self.body + +class Project(PolymorphicModel): + topic = models.CharField(max_length=30) + + +class ArtProject(Project): + artist = models.CharField(max_length=30) + + +class ResearchProject(Project): + supervisor = models.CharField(max_length=30) + + +@python_2_unicode_compatible +class Company(models.Model): + name = models.CharField(max_length=100) + current_project = models.ForeignKey(Project, related_name='companies') + future_projects = models.ManyToManyField(Project) + + def __str__(self): + return self.name diff --git a/example/serializers.py b/example/serializers.py index e259a10b..09119d1c 100644 --- a/example/serializers.py +++ b/example/serializers.py @@ -1,6 +1,6 @@ from datetime import datetime from rest_framework_json_api import serializers, relations -from example.models import Blog, Entry, Author, AuthorBio, Comment +from example import models class BlogSerializer(serializers.ModelSerializer): @@ -12,11 +12,11 @@ def get_copyright(self, resource): def get_root_meta(self, resource, many): return { - 'api_docs': '/docs/api/blogs' + 'api_docs': '/docs/api/blogs' } class Meta: - model = Blog + model = models.Blog fields = ('name', ) meta_fields = ('copyright',) @@ -38,27 +38,27 @@ def __init__(self, *args, **kwargs): } body_format = serializers.SerializerMethodField() - # many related from model + # Many related from model comments = relations.ResourceRelatedField( - source='comment_set', many=True, read_only=True) - # many related from serializer + source='comment_set', many=True, read_only=True) + # Many related from serializer suggested = relations.SerializerMethodResourceRelatedField( - source='get_suggested', model=Entry, many=True, read_only=True) - # single related from serializer + source='get_suggested', model=models.Entry, many=True, read_only=True) + # Single related from serializer featured = relations.SerializerMethodResourceRelatedField( - source='get_featured', model=Entry, read_only=True) + source='get_featured', model=models.Entry, read_only=True) def get_suggested(self, obj): - return Entry.objects.exclude(pk=obj.pk) + return models.Entry.objects.exclude(pk=obj.pk) def get_featured(self, obj): - return Entry.objects.exclude(pk=obj.pk).first() + return models.Entry.objects.exclude(pk=obj.pk).first() def get_body_format(self, obj): return 'text' class Meta: - model = Entry + model = models.Entry fields = ('blog', 'headline', 'body_text', 'pub_date', 'mod_date', 'authors', 'comments', 'featured', 'suggested',) meta_fields = ('body_format',) @@ -67,7 +67,7 @@ class Meta: class AuthorBioSerializer(serializers.ModelSerializer): class Meta: - model = AuthorBio + model = models.AuthorBio fields = ('author', 'body',) @@ -77,7 +77,7 @@ class AuthorSerializer(serializers.ModelSerializer): } class Meta: - model = Author + model = models.Author fields = ('name', 'email', 'bio') @@ -88,6 +88,41 @@ class CommentSerializer(serializers.ModelSerializer): } class Meta: - model = Comment + model = models.Comment exclude = ('created_at', 'modified_at',) # fields = ('entry', 'body', 'author',) + + +class ArtProjectSerializer(serializers.ModelSerializer): + class Meta: + model = models.ArtProject + exclude = ('polymorphic_ctype',) + + +class ResearchProjectSerializer(serializers.ModelSerializer): + class Meta: + model = models.ResearchProject + exclude = ('polymorphic_ctype',) + + +class ProjectSerializer(serializers.PolymorphicModelSerializer): + polymorphic_serializers = [ArtProjectSerializer, ResearchProjectSerializer] + + class Meta: + model = models.Project + exclude = ('polymorphic_ctype',) + + +class CompanySerializer(serializers.ModelSerializer): + current_project = relations.PolymorphicResourceRelatedField( + ProjectSerializer, queryset=models.Project.objects.all()) + future_projects = relations.PolymorphicResourceRelatedField( + ProjectSerializer, queryset=models.Project.objects.all(), many=True) + + included_serializers = { + 'current_project': ProjectSerializer, + 'future_projects': ProjectSerializer, + } + + class Meta: + model = models.Company diff --git a/example/settings/dev.py b/example/settings/dev.py index b4b435ca..5a59ba90 100644 --- a/example/settings/dev.py +++ b/example/settings/dev.py @@ -23,6 +23,7 @@ 'django.contrib.auth', 'django.contrib.admin', 'rest_framework', + 'polymorphic', 'example', ] diff --git a/example/settings/test.py b/example/settings/test.py index 5bb3f45d..d0157138 100644 --- a/example/settings/test.py +++ b/example/settings/test.py @@ -15,3 +15,6 @@ REST_FRAMEWORK.update({ 'PAGE_SIZE': 1, }) +JSON_API_POLYMORPHIC_ANCESTORS = ( + 'polymorphic.models.PolymorphicModel', +) diff --git a/example/tests/conftest.py b/example/tests/conftest.py index 8a96cfdb..cb059f81 100644 --- a/example/tests/conftest.py +++ b/example/tests/conftest.py @@ -1,13 +1,16 @@ import pytest from pytest_factoryboy import register -from example.factories import BlogFactory, AuthorFactory, AuthorBioFactory, EntryFactory, CommentFactory +from example import factories -register(BlogFactory) -register(AuthorFactory) -register(AuthorBioFactory) -register(EntryFactory) -register(CommentFactory) +register(factories.BlogFactory) +register(factories.AuthorFactory) +register(factories.AuthorBioFactory) +register(factories.EntryFactory) +register(factories.CommentFactory) +register(factories.ArtProjectFactory) +register(factories.ResearchProjectFactory) +register(factories.CompanyFactory) @pytest.fixture @@ -29,3 +32,13 @@ def multiple_entries(blog_factory, author_factory, entry_factory, comment_factor comment_factory(entry=entries[1]) return entries + +@pytest.fixture +def single_company(art_project_factory, research_project_factory, company_factory): + company = company_factory(future_projects=(research_project_factory(), art_project_factory())) + return company + + +@pytest.fixture +def single_art_project(art_project_factory): + return art_project_factory() diff --git a/example/tests/integration/test_polymorphism.py b/example/tests/integration/test_polymorphism.py new file mode 100644 index 00000000..5b7fbb7b --- /dev/null +++ b/example/tests/integration/test_polymorphism.py @@ -0,0 +1,136 @@ +import pytest +import random +import json +from django.core.urlresolvers import reverse + +from example.tests.utils import load_json + +pytestmark = pytest.mark.django_db + + +def test_polymorphism_on_detail(single_art_project, client): + response = client.get(reverse("project-detail", kwargs={'pk': single_art_project.pk})) + content = load_json(response.content) + assert content["data"]["type"] == "artProjects" + + +def test_polymorphism_on_detail_relations(single_company, client): + response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk})) + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects" + assert [rel["type"] for rel in content["data"]["relationships"]["futureProjects"]["data"]] == [ + "researchProjects", "artProjects"] + + +def test_polymorphism_on_included_relations(single_company, client): + response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk}) + + '?include=current_project,future_projects') + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects" + assert [rel["type"] for rel in content["data"]["relationships"]["futureProjects"]["data"]] == [ + "researchProjects", "artProjects"] + assert [x.get('type') for x in content.get('included')] == [ + 'artProjects', 'artProjects', 'researchProjects'], 'Detail included types are incorrect' + # Ensure that the child fields are present. + assert content.get('included')[0].get('attributes').get('artist') is not None + assert content.get('included')[1].get('attributes').get('artist') is not None + assert content.get('included')[2].get('attributes').get('supervisor') is not None + + +def test_polymorphism_on_polymorphic_model_detail_patch(single_art_project, client): + url = reverse("project-detail", kwargs={'pk': single_art_project.pk}) + response = client.get(url) + content = load_json(response.content) + test_topic = 'test-{}'.format(random.randint(0, 999999)) + test_artist = 'test-{}'.format(random.randint(0, 999999)) + content['data']['attributes']['topic'] = test_topic + content['data']['attributes']['artist'] = test_artist + response = client.patch(url, data=json.dumps(content), content_type='application/vnd.api+json') + new_content = load_json(response.content) + assert new_content["data"]["type"] == "artProjects" + assert new_content['data']['attributes']['topic'] == test_topic + assert new_content['data']['attributes']['artist'] == test_artist + + +def test_polymorphism_on_polymorphic_model_list_post(client): + test_topic = 'New test topic {}'.format(random.randint(0, 999999)) + test_artist = 'test-{}'.format(random.randint(0, 999999)) + url = reverse('project-list') + data = { + 'data': { + 'type': 'artProjects', + 'attributes': { + 'topic': test_topic, + 'artist': test_artist + } + } + } + response = client.post(url, data=json.dumps(data), content_type='application/vnd.api+json') + content = load_json(response.content) + assert content['data']['id'] is not None + assert content["data"]["type"] == "artProjects" + assert content['data']['attributes']['topic'] == test_topic + assert content['data']['attributes']['artist'] == test_artist + + +def test_invalid_type_on_polymorphic_model(client): + test_topic = 'New test topic {}'.format(random.randint(0, 999999)) + test_artist = 'test-{}'.format(random.randint(0, 999999)) + url = reverse('project-list') + data = { + 'data': { + 'type': 'invalidProjects', + 'attributes': { + 'topic': test_topic, + 'artist': test_artist + } + } + } + response = client.post(url, data=json.dumps(data), content_type='application/vnd.api+json') + assert response.status_code == 409 + content = load_json(response.content) + assert len(content["errors"]) is 1 + assert content["errors"][0]["status"] == "409" + assert content["errors"][0]["detail"] == \ + "The resource object's type (invalidProjects) is not the type that constitute the " \ + "collection represented by the endpoint (one of [researchProjects, artProjects])." + + +def test_polymorphism_relations_update(single_company, research_project_factory, client): + response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk})) + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects" + + research_project = research_project_factory() + content["data"]["relationships"]["currentProject"]["data"] = { + "type": "researchProjects", + "id": research_project.pk + } + response = client.put(reverse("company-detail", kwargs={'pk': single_company.pk}), + data=json.dumps(content), content_type='application/vnd.api+json') + assert response.status_code == 200 + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "researchProjects" + assert int(content["data"]["relationships"]["currentProject"]["data"]["id"]) == \ + research_project.pk + + +def test_invalid_type_on_polymorphic_relation(single_company, research_project_factory, client): + response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk})) + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects" + + research_project = research_project_factory() + content["data"]["relationships"]["currentProject"]["data"] = { + "type": "invalidProjects", + "id": research_project.pk + } + response = client.put(reverse("company-detail", kwargs={'pk': single_company.pk}), + data=json.dumps(content), content_type='application/vnd.api+json') + assert response.status_code == 409 + content = load_json(response.content) + assert len(content["errors"]) is 1 + assert content["errors"][0]["status"] == "409" + assert content["errors"][0]["detail"] == \ + "Incorrect relation type. Expected one of [researchProjects, artProjects], " \ + "received invalidProjects." diff --git a/example/urls.py b/example/urls.py index f48135c7..4443960f 100644 --- a/example/urls.py +++ b/example/urls.py @@ -1,7 +1,8 @@ from django.conf.urls import include, url from rest_framework import routers -from example.views import BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet +from example.views import ( + BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet, CompanyViewset, ProjectViewset) router = routers.DefaultRouter(trailing_slash=False) @@ -9,6 +10,8 @@ router.register(r'entries', EntryViewSet) router.register(r'authors', AuthorViewSet) router.register(r'comments', CommentViewSet) +router.register(r'companies', CompanyViewset) +router.register(r'projects', ProjectViewset) urlpatterns = [ url(r'^', include(router.urls)), diff --git a/example/urls_test.py b/example/urls_test.py index 0f8ed73b..21f29fd1 100644 --- a/example/urls_test.py +++ b/example/urls_test.py @@ -1,8 +1,9 @@ from django.conf.urls import include, url from rest_framework import routers -from example.views import BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet, EntryRelationshipView, BlogRelationshipView, \ - CommentRelationshipView, AuthorRelationshipView +from example.views import ( + BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet, CompanyViewset, ProjectViewset, + EntryRelationshipView, BlogRelationshipView, CommentRelationshipView, AuthorRelationshipView) from .api.resources.identity import Identity, GenericIdentity router = routers.DefaultRouter(trailing_slash=False) @@ -11,6 +12,8 @@ router.register(r'entries', EntryViewSet) router.register(r'authors', AuthorViewSet) router.register(r'comments', CommentViewSet) +router.register(r'companies', CompanyViewset) +router.register(r'projects', ProjectViewset) # for the old tests router.register(r'identities', Identity) @@ -36,4 +39,3 @@ AuthorRelationshipView.as_view(), name='author-relationships'), ] - diff --git a/example/views.py b/example/views.py index 988cda66..e32db8c0 100644 --- a/example/views.py +++ b/example/views.py @@ -6,9 +6,10 @@ import rest_framework_json_api.parsers import rest_framework_json_api.renderers from rest_framework_json_api.views import RelationshipView -from example.models import Blog, Entry, Author, Comment +from example.models import Blog, Entry, Author, Comment, Company, Project from example.serializers import ( - BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer) + BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer, CompanySerializer, + ProjectSerializer) from rest_framework_json_api.utils import format_drf_errors @@ -72,6 +73,16 @@ class CommentViewSet(viewsets.ModelViewSet): serializer_class = CommentSerializer +class CompanyViewset(viewsets.ModelViewSet): + queryset = Company.objects.all() + serializer_class = CompanySerializer + + +class ProjectViewset(viewsets.ModelViewSet): + queryset = Project.objects.all() + serializer_class = ProjectSerializer + + class EntryRelationshipView(RelationshipView): queryset = Entry.objects.all() diff --git a/requirements-development.txt b/requirements-development.txt index 6aa243bd..b5e25321 100644 --- a/requirements-development.txt +++ b/requirements-development.txt @@ -3,4 +3,5 @@ pytest==2.8.2 pytest-django pytest-factoryboy fake-factory +django-polymorphic tox diff --git a/rest_framework_json_api/parsers.py b/rest_framework_json_api/parsers.py index 30b9ad0e..f1cd6abe 100644 --- a/rest_framework_json_api/parsers.py +++ b/rest_framework_json_api/parsers.py @@ -1,6 +1,7 @@ """ Parsers """ +import six from rest_framework import parsers from rest_framework.exceptions import ParseError @@ -29,7 +30,8 @@ class JSONParser(parsers.JSONParser): @staticmethod def parse_attributes(data): - return utils.format_keys(data.get('attributes'), 'underscore') if data.get('attributes') else dict() + return utils.format_keys( + data.get('attributes'), 'underscore') if data.get('attributes') else dict() @staticmethod def parse_relationships(data): @@ -50,39 +52,52 @@ def parse(self, stream, media_type=None, parser_context=None): """ Parses the incoming bytestream as JSON and returns the resulting data """ - result = super(JSONParser, self).parse(stream, media_type=media_type, parser_context=parser_context) + result = super(JSONParser, self).parse( + stream, media_type=media_type, parser_context=parser_context) data = result.get('data') if data: from rest_framework_json_api.views import RelationshipView if isinstance(parser_context['view'], RelationshipView): - # We skip parsing the object as JSONAPI Resource Identifier Object and not a regular Resource Object + # We skip parsing the object as JSONAPI Resource Identifier Object is not a + # regular Resource Object if isinstance(data, list): for resource_identifier_object in data: - if not (resource_identifier_object.get('id') and resource_identifier_object.get('type')): - raise ParseError( - 'Received data contains one or more malformed JSONAPI Resource Identifier Object(s)' - ) + if not (resource_identifier_object.get('id') and + resource_identifier_object.get('type')): + raise ParseError('Received data contains one or more malformed ' + 'JSONAPI Resource Identifier Object(s)') elif not (data.get('id') and data.get('type')): - raise ParseError('Received data is not a valid JSONAPI Resource Identifier Object') + raise ParseError('Received data is not a valid ' + 'JSONAPI Resource Identifier Object') return data request = parser_context.get('request') # Check for inconsistencies - resource_name = utils.get_resource_name(parser_context) - if data.get('type') != resource_name and request.method in ('PUT', 'POST', 'PATCH'): - raise exceptions.Conflict( - "The resource object's type ({data_type}) is not the type " - "that constitute the collection represented by the endpoint ({resource_type}).".format( - data_type=data.get('type'), - resource_type=resource_name - ) - ) + if request.method in ('PUT', 'POST', 'PATCH'): + resource_name = utils.get_resource_name( + parser_context, expand_polymorphic_types=True) + if isinstance(resource_name, six.string_types): + if data.get('type') != resource_name: + raise exceptions.Conflict( + "The resource object's type ({data_type}) is not the type that " + "constitute the collection represented by the endpoint " + "({resource_type}).".format( + data_type=data.get('type'), + resource_type=resource_name)) + else: + if data.get('type') not in resource_name: + raise exceptions.Conflict( + "The resource object's type ({data_type}) is not the type that " + "constitute the collection represented by the endpoint " + "(one of [{resource_types}]).".format( + data_type=data.get('type'), + resource_types=", ".join(resource_name))) # Construct the return data - parsed_data = {'id': data.get('id')} + parsed_data = {'id': data.get('id'), 'type': data.get('type')} parsed_data.update(self.parse_attributes(data)) parsed_data.update(self.parse_relationships(data)) return parsed_data diff --git a/rest_framework_json_api/relations.py b/rest_framework_json_api/relations.py index 0e6594d5..471217ea 100644 --- a/rest_framework_json_api/relations.py +++ b/rest_framework_json_api/relations.py @@ -6,12 +6,13 @@ from django.db.models.query import QuerySet from rest_framework_json_api.exceptions import Conflict -from rest_framework_json_api.utils import Hyperlink, \ +from rest_framework_json_api.utils import POLYMORPHIC_ANCESTORS, Hyperlink, \ get_resource_type_from_queryset, get_resource_type_from_instance, \ get_included_serializers, get_resource_type_from_serializer class ResourceRelatedField(PrimaryKeyRelatedField): + _skip_polymorphic_optimization = True self_link_view_name = None related_link_view_name = None related_link_lookup_field = 'pk' @@ -47,6 +48,12 @@ def __init__(self, self_link_view_name=None, related_link_view_name=None, **kwar super(ResourceRelatedField, self).__init__(**kwargs) + # Determine if relation is polymorphic + self.is_polymorphic = False + model = model or getattr(self.get_queryset(), 'model', None) + if model and issubclass(model, POLYMORPHIC_ANCESTORS): + self.is_polymorphic = True + def use_pk_only_optimization(self): # We need the real object to determine its type... return False @@ -129,7 +136,8 @@ def to_internal_value(self, data): self.fail('missing_id') if data['type'] != expected_relation_type: - self.conflict('incorrect_relation_type', relation_type=expected_relation_type, received_type=data['type']) + self.conflict('incorrect_relation_type', relation_type=expected_relation_type, + received_type=data['type']) return super(ResourceRelatedField, self).to_internal_value(data['id']) @@ -144,7 +152,8 @@ def to_representation(self, value): resource_type = None root = getattr(self.parent, 'parent', self.parent) field_name = self.field_name if self.field_name else self.parent.field_name - if getattr(root, 'included_serializers', None) is not None: + if getattr(root, 'included_serializers', None) is not None and \ + self._skip_polymorphic_optimization: includes = get_included_serializers(root) if field_name in includes.keys(): resource_type = get_resource_type_from_serializer(includes[field_name]) @@ -169,6 +178,42 @@ def choices(self): ]) +class PolymorphicResourceRelatedField(ResourceRelatedField): + + _skip_polymorphic_optimization = False + default_error_messages = dict(ResourceRelatedField.default_error_messages, **{ + 'incorrect_relation_type': _('Incorrect relation type. Expected one of [{relation_type}], ' + 'received {received_type}.'), + }) + + def __init__(self, polymorphic_serializer, *args, **kwargs): + self.polymorphic_serializer = polymorphic_serializer + super(PolymorphicResourceRelatedField, self).__init__(*args, **kwargs) + + def to_internal_value(self, data): + if isinstance(data, six.text_type): + try: + data = json.loads(data) + except ValueError: + # show a useful error if they send a `pk` instead of resource object + self.fail('incorrect_type', data_type=type(data).__name__) + if not isinstance(data, dict): + self.fail('incorrect_type', data_type=type(data).__name__) + + if 'type' not in data: + self.fail('missing_type') + + if 'id' not in data: + self.fail('missing_id') + + expected_relation_types = self.polymorphic_serializer.get_polymorphic_types() + + if data['type'] not in expected_relation_types: + self.conflict('incorrect_relation_type', relation_type=", ".join( + expected_relation_types), received_type=data['type']) + + return super(ResourceRelatedField, self).to_internal_value(data['id']) + class SerializerMethodResourceRelatedField(ResourceRelatedField): """ diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 1c66c927..16531fb0 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -289,8 +289,6 @@ def extract_included(fields, resource, resource_instance, included_resources): relation_type = utils.get_resource_type_from_serializer(serializer) relation_queryset = list(relation_instance_or_manager.all()) - # Get the serializer fields - serializer_fields = utils.get_serializer_fields(serializer) if serializer_data: for position in range(len(serializer_data)): serializer_resource = serializer_data[position] @@ -299,6 +297,7 @@ def extract_included(fields, resource, resource_instance, included_resources): relation_type or utils.get_resource_type_from_instance(nested_resource_instance) ) + serializer_fields = utils.get_serializer_fields(serializer.__class__(nested_resource_instance, context=serializer.context)) included_data.append( JSONRenderer.build_json_resource_obj( serializer_fields, serializer_resource, nested_resource_instance, resource_type @@ -360,6 +359,9 @@ def extract_root_meta(serializer, resource): @staticmethod def build_json_resource_obj(fields, resource, resource_instance, resource_name): + # Determine type from the instance if the underlying model is polymorphic + if isinstance(resource_instance, utils.POLYMORPHIC_ANCESTORS): + resource_name = utils.get_resource_type_from_instance(resource_instance) resource_data = [ ('type', resource_name), ('id', encoding.force_text(resource_instance.pk) if resource_instance else None), diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py index 953c4437..a5457e1c 100644 --- a/rest_framework_json_api/serializers.py +++ b/rest_framework_json_api/serializers.py @@ -1,8 +1,11 @@ +from django.db.models.query import QuerySet from django.utils.translation import ugettext_lazy as _ +from django.utils import six from rest_framework.exceptions import ParseError from rest_framework.serializers import * from rest_framework_json_api.relations import ResourceRelatedField +from rest_framework_json_api.exceptions import Conflict from rest_framework_json_api.utils import ( get_resource_type_from_model, get_resource_type_from_instance, get_resource_type_from_serializer, get_included_serializers) @@ -10,7 +13,8 @@ class ResourceIdentifierObjectSerializer(BaseSerializer): default_error_messages = { - 'incorrect_model_type': _('Incorrect model type. Expected {model_type}, received {received_type}.'), + 'incorrect_model_type': _('Incorrect model type. Expected {model_type}, ' + 'received {received_type}.'), 'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'), 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), } @@ -20,7 +24,8 @@ class ResourceIdentifierObjectSerializer(BaseSerializer): def __init__(self, *args, **kwargs): self.model_class = kwargs.pop('model_class', self.model_class) if 'instance' not in kwargs and not self.model_class: - raise RuntimeError('ResourceIdentifierObjectsSerializer must be initialized with a model class.') + raise RuntimeError( + 'ResourceIdentifierObjectsSerializer must be initialized with a model class.') super(ResourceIdentifierObjectSerializer, self).__init__(*args, **kwargs) def to_representation(self, instance): @@ -31,7 +36,8 @@ def to_representation(self, instance): def to_internal_value(self, data): if data['type'] != get_resource_type_from_model(self.model_class): - self.fail('incorrect_model_type', model_type=self.model_class, received_type=data['type']) + self.fail( + 'incorrect_model_type', model_type=self.model_class, received_type=data['type']) pk = data['id'] try: return self.model_class.objects.get(pk=pk) @@ -47,15 +53,18 @@ def __init__(self, *args, **kwargs): request = context.get('request') if context else None if request: - sparse_fieldset_query_param = 'fields[{}]'.format(get_resource_type_from_serializer(self)) + sparse_fieldset_query_param = 'fields[{}]'.format( + get_resource_type_from_serializer(self)) try: - param_name = next(key for key in request.query_params if sparse_fieldset_query_param in key) + param_name = next( + key for key in request.query_params if sparse_fieldset_query_param in key) except StopIteration: pass else: fieldset = request.query_params.get(param_name).split(',') - # iterate over a *copy* of self.fields' underlying OrderedDict, because we may modify the - # original during the iteration. self.fields is a `rest_framework.utils.serializer_helpers.BindingDict` + # Iterate over a *copy* of self.fields' underlying OrderedDict, because we may + # modify the original during the iteration. + # self.fields is a `rest_framework.utils.serializer_helpers.BindingDict` for field_name, field in self.fields.fields.copy().items(): if field_name == api_settings.URL_FIELD_NAME: # leave self link there continue @@ -101,7 +110,8 @@ def validate_path(serializer_class, field_path, path): super(IncludedResourcesValidationMixin, self).__init__(*args, **kwargs) -class HyperlinkedModelSerializer(IncludedResourcesValidationMixin, SparseFieldsetsMixin, HyperlinkedModelSerializer): +class HyperlinkedModelSerializer(IncludedResourcesValidationMixin, SparseFieldsetsMixin, + HyperlinkedModelSerializer): """ A type of `ModelSerializer` that uses hyperlinked relationships instead of primary key relationships. Specifically: @@ -152,3 +162,134 @@ def get_field_names(self, declared_fields, info): declared[field_name] = field fields = super(ModelSerializer, self).get_field_names(declared, info) return list(fields) + list(getattr(self.Meta, 'meta_fields', list())) + + +class PolymorphicSerializerMetaclass(SerializerMetaclass): + """ + This metaclass ensures that the `polymorphic_serializers` is correctly defined on a + `PolymorphicSerializer` class and make a cache of model/serializer/type mappings. + """ + + def __new__(cls, name, bases, attrs): + new_class = super(PolymorphicSerializerMetaclass, cls).__new__(cls, name, bases, attrs) + + # Ensure initialization is only performed for subclasses of PolymorphicModelSerializer + # (excluding PolymorphicModelSerializer class itself). + parents = [b for b in bases if isinstance(b, PolymorphicSerializerMetaclass)] + if not parents: + return new_class + + polymorphic_serializers = getattr(new_class, 'polymorphic_serializers', None) + if not polymorphic_serializers: + raise NotImplementedError( + "A PolymorphicModelSerializer must define a `polymorphic_serializers` attribute.") + serializer_to_model = { + serializer: serializer.Meta.model for serializer in polymorphic_serializers} + model_to_serializer = { + serializer.Meta.model: serializer for serializer in polymorphic_serializers} + type_to_serializer = { + get_resource_type_from_serializer(serializer): serializer for + serializer in polymorphic_serializers} + setattr(new_class, '_poly_serializer_model_map', serializer_to_model) + setattr(new_class, '_poly_model_serializer_map', model_to_serializer) + setattr(new_class, '_poly_type_serializer_map', type_to_serializer) + return new_class + + +@six.add_metaclass(PolymorphicSerializerMetaclass) +class PolymorphicModelSerializer(ModelSerializer): + """ + A serializer for polymorphic models. + Useful for "lazy" parent models. Leaves should be represented with a regular serializer. + """ + def get_fields(self): + """ + Return an exhaustive list of the polymorphic serializer fields. + """ + if self.instance is not None: + if not isinstance(self.instance, QuerySet): + serializer_class = self.get_polymorphic_serializer_for_instance(self.instance) + return serializer_class(self.instance, context=self.context).get_fields() + else: + raise Exception("Cannot get fields from a polymorphic serializer given a queryset") + return super(PolymorphicModelSerializer, self).get_fields() + + @classmethod + def get_polymorphic_serializer_for_instance(cls, instance): + """ + Return the polymorphic serializer associated with the given instance/model. + Raise `NotImplementedError` if no serializer is found for the given model. This usually + means that a serializer is missing in the class's `polymorphic_serializers` attribute. + """ + try: + return cls._poly_model_serializer_map[instance._meta.model] + except KeyError: + raise NotImplementedError( + "No polymorphic serializer has been found for model {}".format( + instance._meta.model.__name__)) + + @classmethod + def get_polymorphic_model_for_serializer(cls, serializer): + """ + Return the polymorphic model associated with the given serializer. + Raise `NotImplementedError` if no model is found for the given serializer. This usually + means that a serializer is missing in the class's `polymorphic_serializers` attribute. + """ + try: + return cls._poly_serializer_model_map[serializer] + except KeyError: + raise NotImplementedError( + "No polymorphic model has been found for serializer {}".format(serializer.__name__)) + + @classmethod + def get_polymorphic_serializer_for_type(cls, obj_type): + """ + Return the polymorphic serializer associated with the given type. + Raise `NotImplementedError` if no serializer is found for the given type. This usually + means that a serializer is missing in the class's `polymorphic_serializers` attribute. + """ + try: + return cls._poly_type_serializer_map[obj_type] + except KeyError: + raise NotImplementedError( + "No polymorphic serializer has been found for type {}".format(obj_type)) + + @classmethod + def get_polymorphic_model_for_type(cls, obj_type): + """ + Return the polymorphic model associated with the given type. + Raise `NotImplementedError` if no model is found for the given type. This usually + means that a serializer is missing in the class's `polymorphic_serializers` attribute. + """ + return cls.get_polymorphic_model_for_serializer( + cls.get_polymorphic_serializer_for_type(obj_type)) + + @classmethod + def get_polymorphic_types(cls): + """ + Return the list of accepted types. + """ + return cls._poly_type_serializer_map.keys() + + def to_representation(self, instance): + """ + Retrieve the appropriate polymorphic serializer and use this to handle representation. + """ + serializer_class = self.get_polymorphic_serializer_for_instance(instance) + return serializer_class(instance, context=self.context).to_representation(instance) + + def to_internal_value(self, data): + """ + Ensure that the given type is one of the expected polymorphic types, then retrieve the + appropriate polymorphic serializer and use this to handle internal value. + """ + received_type = data.get('type') + expected_types = self.get_polymorphic_types() + if received_type not in expected_types: + raise Conflict( + 'Incorrect relation type. Expected on of [{expected_types}], ' + 'received {received_type}.'.format( + expected_types=', '.join(expected_types), received_type=received_type)) + serializer_class = self.get_polymorphic_serializer_for_type(received_type) + self.__class__ = serializer_class + return serializer_class(data, context=self.context).to_internal_value(data) diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 261640c6..95e036ca 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -25,11 +25,17 @@ except ImportError: HyperlinkedRouterField = type(None) +POLYMORPHIC_ANCESTORS = () +for ancestor in getattr(settings, 'JSON_API_POLYMORPHIC_ANCESTORS', ()): + ancestor_class = import_class_from_dotted_path(ancestor) + POLYMORPHIC_ANCESTORS += (ancestor_class,) -def get_resource_name(context): + +def get_resource_name(context, expand_polymorphic_types=False): """ Return the name of a resource. """ + from . import serializers view = context.get('view') # Sanity check to make sure we have a view. @@ -51,7 +57,11 @@ def get_resource_name(context): except AttributeError: try: serializer = view.get_serializer_class() - return get_resource_type_from_serializer(serializer) + if issubclass(serializer, serializers.PolymorphicModelSerializer) and \ + expand_polymorphic_types: + return serializer.get_polymorphic_types() + else: + return get_resource_type_from_serializer(serializer) except AttributeError: try: resource_name = get_resource_type_from_model(view.model) @@ -86,6 +96,7 @@ def get_serializer_fields(serializer): pass return fields + def format_keys(obj, format_type=None): """ Takes either a dict or list and returns it with camelized keys only if @@ -141,12 +152,15 @@ def format_value(value, format_type=None): def format_relation_name(value, format_type=None): - warnings.warn("The 'format_relation_name' function has been renamed 'format_resource_type' and the settings are now 'JSON_API_FORMAT_TYPES' and 'JSON_API_PLURALIZE_TYPES'") + warnings.warn( + "The 'format_relation_name' function has been renamed 'format_resource_type' and " + "the settings are now 'JSON_API_FORMAT_TYPES' and 'JSON_API_PLURALIZE_TYPES'") if format_type is None: format_type = getattr(settings, 'JSON_API_FORMAT_RELATION_KEYS', None) pluralize = getattr(settings, 'JSON_API_PLURALIZE_RELATION_TYPE', None) return format_resource_type(value, format_type, pluralize) + def format_resource_type(value, format_type=None, pluralize=None): if format_type is None: format_type = getattr(settings, 'JSON_API_FORMAT_TYPES', False)