diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index faedb8d2..3c483a0d 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -146,12 +146,20 @@ def convert_column_to_float(type, column, registry=None): @convert_sqlalchemy_type.register(types.Enum) def convert_enum_to_enum(type, column, registry=None): - try: - items = type.enum_class.__members__.items() - except AttributeError: - items = zip(type.enums, type.enums) + enum_type = None + if registry is not None: + enum_type = registry.get_type_for_enum(type.name) + if enum_type is None: + try: + items = type.enum_class.__members__.items() + except AttributeError: + items = zip(type.enums, type.enums) + enum_type = Enum(type.name, items) + if registry is not None: + registry.register_type_for_enum(type.name, enum_type) + return Field( - Enum(type.name, items), + enum_type, description=get_column_doc(column), required=not (is_column_nullable(column)), ) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 460053f2..30bfa712 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -3,6 +3,7 @@ def __init__(self): self._registry = {} self._registry_models = {} self._registry_composites = {} + self._registry_enums = {} def register(self, cls): from .types import SQLAlchemyObjectType @@ -27,6 +28,12 @@ def register_composite_converter(self, composite, converter): def get_converter_for_composite(self, composite): return self._registry_composites.get(composite) + def register_type_for_enum(self, enum_type_name, graphene_enum): + self._registry_enums[enum_type_name] = graphene_enum + + def get_type_for_enum(self, enum_type_name): + return self._registry_enums.get(enum_type_name) + registry = None diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index be7d5cd2..4facd146 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -21,12 +21,14 @@ class Editor(Base): editor_id = Column(Integer(), primary_key=True) name = Column(String(100)) +PetKindEnum = Enum("cat", "dog", name="pet_kind") class Pet(Base): __tablename__ = "pets" id = Column(Integer(), primary_key=True) name = Column(String(30)) - pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False) + pet_kind = Column(PetKindEnum, nullable=False) + pet_kind_again = Column(PetKindEnum) reporter_id = Column(Integer(), ForeignKey("reporters.id"))