diff --git a/README.md b/README.md index 91a5349f..38a70cfb 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,8 @@ class User(SQLAlchemyObjectType): only_fields = ("name",) # exclude specified fields exclude_fields = ("last_name",) + # Rename specified fields + rename_fields = {'name': 'first_name'} class Query(graphene.ObjectType): users = graphene.List(User) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 3ba23a8a..9582be6d 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -58,6 +58,7 @@ class Article(Base): __tablename__ = "articles" id = Column(Integer(), primary_key=True) headline = Column(String(100)) + description = Column(String(100)) pub_date = Column(Date()) reporter_id = Column(Integer(), ForeignKey("reporters.id")) diff --git a/graphene_sqlalchemy/tests/test_reflected.py b/graphene_sqlalchemy/tests/test_reflected.py index c8a1a70f..5f20e12e 100644 --- a/graphene_sqlalchemy/tests/test_reflected.py +++ b/graphene_sqlalchemy/tests/test_reflected.py @@ -17,5 +17,17 @@ class Meta: def test_objecttype_registered(): assert issubclass(Reflected, ObjectType) assert Reflected._meta.model == ReflectedEditor - assert list(Reflected._meta.fields.keys()) == ["editor_id", "name"] + assert list(Reflected._meta.fields) == ["editor_id", "name"] + +class ReflectedWithFieldRenamed(SQLAlchemyObjectType): + class Meta: + model = ReflectedEditor + registry = registry + rename_fields = { + "name": "editor_name", + } + + +def test_objecttype_registered_with_rename_fields(): + assert list(ReflectedWithFieldRenamed._meta.fields) == ["editor_id", "editor_name"] diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index c11ec351..88289ccb 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -47,7 +47,7 @@ def test_sqlalchemy_interface(): def test_objecttype_registered(): assert issubclass(Character, ObjectType) assert Character._meta.model == Reporter - assert list(Character._meta.fields.keys()) == [ + assert list(Character._meta.fields) == [ "id", "first_name", "last_name", @@ -87,9 +87,10 @@ class Meta: interfaces = (Node,) assert issubclass(Human, ObjectType) - assert list(Human._meta.fields.keys()) == [ + assert list(Human._meta.fields) == [ "id", "headline", + "description", "pub_date", "reporter_id", "reporter", @@ -114,7 +115,7 @@ class Meta: def test_custom_objecttype_registered(): assert issubclass(CustomCharacter, ObjectType) assert CustomCharacter._meta.model == Reporter - assert list(CustomCharacter._meta.fields.keys()) == [ + assert list(CustomCharacter._meta.fields) == [ "id", "first_name", "last_name", @@ -157,7 +158,7 @@ class Meta: def test_objecttype_with_custom_options(): assert issubclass(ReporterWithCustomOptions, ObjectType) assert ReporterWithCustomOptions._meta.model == Reporter - assert list(ReporterWithCustomOptions._meta.fields.keys()) == [ + assert list(ReporterWithCustomOptions._meta.fields) == [ "custom_field", "id", "first_name", @@ -181,3 +182,37 @@ class Meta: resolver, TestConnection, ReporterWithCustomOptions, None, None ) assert result is not None + + +class HumanWithFieldRenamed(SQLAlchemyObjectType): + + publication_timestamp = Int() + + class Meta: + model = Article + registry = registry + interfaces = (Node,) + rename_fields = { + "id": "article_id", + "headline": "title", + "pub_date": "publication_timestamp", + "reporter_id": "journalist_id", + "reporter": "journalist", + } + + +def test_objecttype_with_rename_fields(): + assert issubclass(HumanWithFieldRenamed, ObjectType) + assert HumanWithFieldRenamed._meta.model == Article + assert list(HumanWithFieldRenamed._meta.fields) == [ + "article_id", + "title", + "description", + "publication_timestamp", + "journalist_id", + "journalist", + "id", # Graphene Node ID + ] + replaced_field = HumanWithFieldRenamed._meta.fields["publication_timestamp"] + assert isinstance(replaced_field, Field) + assert replaced_field.type == Int diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index e8a05c8f..c256b44b 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -19,8 +19,10 @@ from .utils import get_query, is_mapped_class, is_mapped_instance -def construct_fields(model, registry, only_fields, exclude_fields): +def construct_fields(model, registry, only_fields, exclude_fields, rename_fields): inspected_model = sqlalchemyinspect(model) + if rename_fields is None: + rename_fields = {} fields = OrderedDict() @@ -33,7 +35,8 @@ def construct_fields(model, registry, only_fields, exclude_fields): # in there. Or when we exclude this field in exclude_fields continue converted_column = convert_sqlalchemy_column(column, registry) - fields[name] = converted_column + field_name = rename_fields.get(name, name) + fields[field_name] = converted_column for name, composite in inspected_model.composites.items(): is_not_in_only = only_fields and name not in only_fields @@ -44,7 +47,8 @@ def construct_fields(model, registry, only_fields, exclude_fields): # in there. Or when we exclude this field in exclude_fields continue converted_composite = convert_sqlalchemy_composite(composite, registry) - fields[name] = converted_composite + field_name = rename_fields.get(name, name) + fields[field_name] = converted_composite for hybrid_item in inspected_model.all_orm_descriptors: @@ -61,7 +65,8 @@ def construct_fields(model, registry, only_fields, exclude_fields): continue converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item) - fields[name] = converted_hybrid_property + field_name = rename_fields.get(name, name) + fields[field_name] = converted_hybrid_property # Get all the columns for the relationships on the model for relationship in inspected_model.relationships: @@ -74,7 +79,8 @@ def construct_fields(model, registry, only_fields, exclude_fields): continue converted_relationship = convert_sqlalchemy_relationship(relationship, registry) name = relationship.key - fields[name] = converted_relationship + field_name = rename_fields.get(name, name) + fields[field_name] = converted_relationship return fields @@ -95,6 +101,7 @@ def __init_subclass_with_meta__( skip_registry=False, only_fields=(), exclude_fields=(), + rename_fields=None, connection=None, connection_class=None, use_connection=None, @@ -116,7 +123,7 @@ def __init_subclass_with_meta__( ).format(cls.__name__, registry) sqla_fields = yank_fields_from_attrs( - construct_fields(model, registry, only_fields, exclude_fields), _as=Field + construct_fields(model, registry, only_fields, exclude_fields, rename_fields), _as=Field ) if use_connection is None and interfaces: