diff --git a/example/tests/unit/test_renderer_class_methods.py b/example/tests/unit/test_renderer_class_methods.py index 61208a32..fc97dee0 100644 --- a/example/tests/unit/test_renderer_class_methods.py +++ b/example/tests/unit/test_renderer_class_methods.py @@ -37,6 +37,45 @@ def test_build_json_resource_obj(): assert JSONRenderer.build_json_resource_obj( serializer.fields, resource, resource_instance, 'user') == output +def test_can_override_methods(): + """ + Make sure extract_attributes and extract_relationships can be overriden. + """ + resource = { + 'pk': 1, + 'username': 'Alice', + } + + serializer = ResourceSerializer(data={'username': 'Alice'}) + serializer.is_valid() + resource_instance = serializer.save() + + output = { + 'type': 'user', + 'id': '1', + 'attributes': { + 'username': 'Alice' + }, + } + + class CustomRenderer(JSONRenderer): + extract_attributes_was_overriden = False + extract_relationships_was_overriden = False + + @classmethod + def extract_attributes(cls, fields, resource): + cls.extract_attributes_was_overriden = True + return super(CustomRenderer, cls).extract_attributes(fields, resource) + + @classmethod + def extract_relationships(cls, fields, resource, resource_instance): + cls.extract_relationships_was_overriden = True + return super(CustomRenderer, cls).extract_relationships(fields, resource, resource_instance) + + assert CustomRenderer.build_json_resource_obj( + serializer.fields, resource, resource_instance, 'user') == output + assert CustomRenderer.extract_attributes_was_overriden + assert CustomRenderer.extract_relationships_was_overriden def test_extract_attributes(): fields = { diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 04a9aaff..098c6700 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -38,8 +38,8 @@ class JSONRenderer(renderers.JSONRenderer): media_type = 'application/vnd.api+json' format = 'vnd.api+json' - @staticmethod - def extract_attributes(fields, resource): + @classmethod + def extract_attributes(cls, fields, resource): data = OrderedDict() for field_name, field in six.iteritems(fields): # ID is always provided in the root of JSON API so remove it from attributes @@ -67,8 +67,8 @@ def extract_attributes(fields, resource): return utils.format_keys(data) - @staticmethod - def extract_relationships(fields, resource, resource_instance): + @classmethod + def extract_relationships(cls, fields, resource, resource_instance): # Avoid circular deps from rest_framework_json_api.relations import ResourceRelatedField @@ -242,8 +242,8 @@ def extract_relationships(fields, resource, resource_instance): return utils.format_keys(data) - @staticmethod - def extract_included(fields, resource, resource_instance, included_resources): + @classmethod + def extract_included(cls, fields, resource, resource_instance, included_resources): # this function may be called with an empty record (example: Browsable Interface) if not resource_instance: return @@ -322,12 +322,12 @@ def extract_included(fields, resource, resource_instance, included_resources): utils.get_resource_type_from_instance(nested_resource_instance) ) included_data.append( - JSONRenderer.build_json_resource_obj( + cls.build_json_resource_obj( serializer_fields, serializer_resource, nested_resource_instance, resource_type ) ) included_data.extend( - JSONRenderer.extract_included( + cls.extract_included( serializer_fields, serializer_resource, nested_resource_instance, new_included_resources ) ) @@ -340,20 +340,20 @@ def extract_included(fields, resource, resource_instance, included_resources): serializer_fields = utils.get_serializer_fields(field) if serializer_data: included_data.append( - JSONRenderer.build_json_resource_obj( + cls.build_json_resource_obj( serializer_fields, serializer_data, relation_instance, relation_type) ) included_data.extend( - JSONRenderer.extract_included( + cls.extract_included( serializer_fields, serializer_data, relation_instance, new_included_resources ) ) return utils.format_keys(included_data) - @staticmethod - def extract_meta(serializer, resource): + @classmethod + def extract_meta(cls, serializer, resource): if hasattr(serializer, 'child'): meta = getattr(serializer.child, 'Meta', None) else: @@ -366,8 +366,8 @@ def extract_meta(serializer, resource): }) return data - @staticmethod - def extract_root_meta(serializer, resource): + @classmethod + def extract_root_meta(cls, serializer, resource): many = False if hasattr(serializer, 'child'): many = True @@ -380,14 +380,14 @@ def extract_root_meta(serializer, resource): data.update(json_api_meta) return data - @staticmethod - def build_json_resource_obj(fields, resource, resource_instance, resource_name): + @classmethod + def build_json_resource_obj(cls, fields, resource, resource_instance, resource_name): resource_data = [ ('type', resource_name), ('id', encoding.force_text(resource_instance.pk) if resource_instance else None), - ('attributes', JSONRenderer.extract_attributes(fields, resource)), + ('attributes', cls.extract_attributes(fields, resource)), ] - relationships = JSONRenderer.extract_relationships(fields, resource, resource_instance) + relationships = cls.extract_relationships(fields, resource, resource_instance) if relationships: resource_data.append(('relationships', relationships)) # Add 'self' link if field is present and valid