diff --git a/example/tests/integration/test_includes.py b/example/tests/integration/test_includes.py index 05c59131..8e2a7a2a 100644 --- a/example/tests/integration/test_includes.py +++ b/example/tests/integration/test_includes.py @@ -2,12 +2,19 @@ from django.core.urlresolvers import reverse from example.tests.utils import load_json +import mock pytestmark = pytest.mark.django_db -def test_included_data_on_list(multiple_entries, client): - response = client.get(reverse("entry-list") + '?include=comments&page_size=5') + +@mock.patch('rest_framework_json_api.utils.get_default_included_resources_from_serializer', new=lambda s: ['comments']) +def test_default_included_data_on_list(multiple_entries, client): + return test_included_data_on_list(multiple_entries=multiple_entries, client=client, query='?page_size=5') + + +def test_included_data_on_list(multiple_entries, client, query='?include=comments&page_size=5'): + response = client.get(reverse("entry-list") + query) included = load_json(response.content).get('included') assert len(load_json(response.content)['data']) == len(multiple_entries), 'Incorrect entry count' @@ -18,8 +25,13 @@ def test_included_data_on_list(multiple_entries, client): assert comment_count == expected_comment_count, 'List comment count is incorrect' -def test_included_data_on_detail(single_entry, client): - response = client.get(reverse("entry-detail", kwargs={'pk': single_entry.pk}) + '?include=comments') +@mock.patch('rest_framework_json_api.utils.get_default_included_resources_from_serializer', new=lambda s: ['comments']) +def test_default_included_data_on_detail(single_entry, client): + return test_included_data_on_detail(single_entry=single_entry, client=client, query='') + + +def test_included_data_on_detail(single_entry, client, query='?include=comments'): + response = client.get(reverse("entry-detail", kwargs={'pk': single_entry.pk}) + query) included = load_json(response.content).get('included') assert [x.get('type') for x in included] == ['comments'], 'Detail included types are incorrect' diff --git a/requirements-development.txt b/requirements-development.txt index 6aa243bd..78ccdc91 100644 --- a/requirements-development.txt +++ b/requirements-development.txt @@ -4,3 +4,4 @@ pytest-django pytest-factoryboy fake-factory tox +mock diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 1c66c927..0a11063e 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -415,12 +415,6 @@ def render(self, data, accepted_media_type=None, renderer_context=None): if resource_name == 'errors': return self.render_errors(data, accepted_media_type, renderer_context) - include_resources_param = request.query_params.get('include') if request else None - if include_resources_param: - included_resources = include_resources_param.split(',') - else: - included_resources = list() - json_api_data = data json_api_included = list() # initialize json_api_meta with pagination meta or an empty dict @@ -433,6 +427,13 @@ def render(self, data, accepted_media_type=None, renderer_context=None): serializer = getattr(serializer_data, 'serializer', None) + # Build a list of included resources + include_resources_param = request.query_params.get('include') if request else None + if include_resources_param: + included_resources = include_resources_param.split(',') + else: + included_resources = utils.get_default_included_resources_from_serializer(serializer) + if serializer is not None: # Get the serializer fields diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 261640c6..bd7d013d 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -232,6 +232,13 @@ def get_resource_type_from_serializer(serializer): return get_resource_type_from_model(serializer.Meta.model) +def get_default_included_resources_from_serializer(serializer): + try: + return list(serializer.JSONAPIMeta.included_resources) + except AttributeError: + return [] + + def get_included_serializers(serializer): included_serializers = copy.copy(getattr(serializer, 'included_serializers', dict()))