Skip to content

Commit 6d73580

Browse files
committed
Refactor ast to hold data as seperated type
1 parent 9f1a538 commit 6d73580

File tree

111 files changed

+22473
-19206
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+22473
-19206
lines changed

ast/asdl_rs.py

Lines changed: 99 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def visitSum(self, sum, name):
131131
if is_simple(sum):
132132
info.has_userdata = False
133133
else:
134+
for t in sum.types:
135+
self.typeinfo[t.name] = TypeInfo(t.name)
136+
self.add_children(t.name, t.fields)
134137
if len(sum.types) > 1:
135138
info.boxed = True
136139
if sum.attributes:
@@ -205,16 +208,49 @@ def simple_sum(self, sum, name, depth):
205208

206209
def sum_with_constructors(self, sum, name, depth):
207210
typeinfo = self.typeinfo[name]
208-
generics, generics_applied = self.get_generics(name, "U = ()", "U")
209211
enumname = rustname = get_rust_type(name)
210212
# all the attributes right now are for location, so if it has attrs we
211213
# can just wrap it in Located<>
212214
if sum.attributes:
213215
enumname = rustname + "Kind"
216+
217+
for t in sum.types:
218+
if not t.fields:
219+
continue
220+
self.emit_attrs(depth)
221+
self.typeinfo[t] = TypeInfo(t)
222+
t_generics, t_generics_applied = self.get_generics(t.name, "U = ()", "U")
223+
payload_name = f"{rustname}{t.name}"
224+
self.emit(f"pub struct {payload_name}{t_generics} {{", depth)
225+
for f in t.fields:
226+
self.visit(f, typeinfo, "pub ", depth + 1, t.name)
227+
self.emit("}", depth)
228+
self.emit(
229+
textwrap.dedent(
230+
f"""
231+
impl{t_generics_applied} From<{payload_name}{t_generics_applied}> for {enumname}{t_generics_applied} {{
232+
fn from(payload: {payload_name}{t_generics_applied}) -> Self {{
233+
{enumname}::{t.name}(payload)
234+
}}
235+
}}
236+
"""
237+
),
238+
depth,
239+
)
240+
241+
generics, generics_applied = self.get_generics(name, "U = ()", "U")
214242
self.emit_attrs(depth)
215243
self.emit(f"pub enum {enumname}{generics} {{", depth)
216244
for t in sum.types:
217-
self.visit(t, typeinfo, depth + 1)
245+
if t.fields:
246+
t_generics, t_generics_applied = self.get_generics(
247+
t.name, "U = ()", "U"
248+
)
249+
self.emit(
250+
f"{t.name}({rustname}{t.name}{t_generics_applied}),", depth + 1
251+
)
252+
else:
253+
self.emit(f"{t.name},", depth + 1)
218254
self.emit("}", depth)
219255
if sum.attributes:
220256
self.emit(
@@ -238,13 +274,18 @@ def visitField(self, field, parent, vis, depth, constructor=None):
238274
if fieldtype and fieldtype.has_userdata:
239275
typ = f"{typ}<U>"
240276
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
241-
if fieldtype and fieldtype.boxed and (not (parent.product or field.seq) or field.opt):
277+
if (
278+
fieldtype
279+
and fieldtype.boxed
280+
and (not (parent.product or field.seq) or field.opt)
281+
):
242282
typ = f"Box<{typ}>"
243283
if field.opt or (
244284
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
245285
# the expression to be unpacked goes in `values` with a `None` at the corresponding
246286
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
247-
constructor == "Dict" and field.name == "keys"
287+
constructor == "Dict"
288+
and field.name == "keys"
248289
):
249290
typ = f"Option<{typ}>"
250291
if field.seq:
@@ -344,14 +385,21 @@ def visitSum(self, sum, name, depth):
344385
)
345386
if is_located:
346387
self.emit("fold_located(folder, node, |folder, node| {", depth)
347-
enumname += "Kind"
388+
rustname = enumname + "Kind"
389+
else:
390+
rustname = enumname
348391
self.emit("match node {", depth + 1)
349392
for cons in sum.types:
350-
fields_pattern = self.make_pattern(cons.fields)
393+
fields_pattern = self.make_pattern(
394+
enumname, rustname, cons.name, cons.fields
395+
)
351396
self.emit(
352-
f"{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth + 2
397+
f"{fields_pattern[0]} {{ {fields_pattern[1]} }} {fields_pattern[2]} => {{",
398+
depth + 2,
399+
)
400+
self.gen_construction(
401+
fields_pattern[0], cons.fields, fields_pattern[2], depth + 3
353402
)
354-
self.gen_construction(f"{enumname}::{cons.name}", cons.fields, depth + 3)
355403
self.emit("}", depth + 2)
356404
self.emit("}", depth + 1)
357405
if is_located:
@@ -381,23 +429,33 @@ def visitProduct(self, product, name, depth):
381429
)
382430
if is_located:
383431
self.emit("fold_located(folder, node, |folder, node| {", depth)
384-
structname += "Data"
385-
fields_pattern = self.make_pattern(product.fields)
386-
self.emit(f"let {structname} {{ {fields_pattern} }} = node;", depth + 1)
387-
self.gen_construction(structname, product.fields, depth + 1)
432+
rustname = structname + "Data"
433+
else:
434+
rustname = structname
435+
fields_pattern = self.make_pattern(rustname, structname, None, product.fields)
436+
self.emit(f"let {rustname} {{ {fields_pattern[1]} }} = node;", depth + 1)
437+
self.gen_construction(rustname, product.fields, "", depth + 1)
388438
if is_located:
389439
self.emit("})", depth)
390440
self.emit("}", depth)
391441

392-
def make_pattern(self, fields):
393-
return ",".join(rust_field(f.name) for f in fields)
442+
def make_pattern(self, rustname, pyname, fieldname, fields):
443+
if fields:
444+
header = f"{pyname}::{fieldname}({rustname}{fieldname}"
445+
footer = ")"
446+
else:
447+
header = f"{pyname}::{fieldname}"
448+
footer = ""
394449

395-
def gen_construction(self, cons_path, fields, depth):
396-
self.emit(f"Ok({cons_path} {{", depth)
450+
body = ",".join(rust_field(f.name) for f in fields)
451+
return header, body, footer
452+
453+
def gen_construction(self, header, fields, footer, depth):
454+
self.emit(f"Ok({header} {{", depth)
397455
for field in fields:
398456
name = rust_field(field.name)
399457
self.emit(f"{name}: Foldable::fold({name}, folder)?,", depth + 1)
400-
self.emit("})", depth)
458+
self.emit(f"}}{footer})", depth)
401459

402460

403461
class FoldModuleVisitor(TypeInfoEmitVisitor):
@@ -514,33 +572,36 @@ def visitType(self, type, depth=0):
514572
self.visit(type.value, type.name, depth)
515573

516574
def visitSum(self, sum, name, depth):
517-
enumname = get_rust_type(name)
575+
rustname = enumname = get_rust_type(name)
518576
if sum.attributes:
519-
enumname += "Kind"
577+
rustname = enumname + "Kind"
520578

521-
self.emit(f"impl NamedNode for ast::{enumname} {{", depth)
579+
self.emit(f"impl NamedNode for ast::{rustname} {{", depth)
522580
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
523581
self.emit("}", depth)
524-
self.emit(f"impl Node for ast::{enumname} {{", depth)
582+
self.emit(f"impl Node for ast::{rustname} {{", depth)
525583
self.emit(
526584
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
527585
)
528586
self.emit("match self {", depth + 2)
529587
for variant in sum.types:
530-
self.constructor_to_object(variant, enumname, depth + 3)
588+
self.constructor_to_object(variant, enumname, rustname, depth + 3)
531589
self.emit("}", depth + 2)
532590
self.emit("}", depth + 1)
533591
self.emit(
534592
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
535593
depth + 1,
536594
)
537-
self.gen_sum_fromobj(sum, name, enumname, depth + 2)
595+
self.gen_sum_fromobj(sum, name, enumname, rustname, depth + 2)
538596
self.emit("}", depth + 1)
539597
self.emit("}", depth)
540598

541-
def constructor_to_object(self, cons, enumname, depth):
542-
fields_pattern = self.make_pattern(cons.fields)
543-
self.emit(f"ast::{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth)
599+
def constructor_to_object(self, cons, enumname, rustname, depth):
600+
self.emit(f"ast::{rustname}::{cons.name}", depth)
601+
if cons.fields:
602+
fields_pattern = self.make_pattern(cons.fields)
603+
self.emit(f"( ast::{enumname}{cons.name} {{ {fields_pattern} }} )", depth)
604+
self.emit(" => {", depth)
544605
self.make_node(cons.name, cons.fields, depth + 1)
545606
self.emit("}", depth)
546607

@@ -586,15 +647,20 @@ def make_node(self, variant, fields, depth):
586647
def make_pattern(self, fields):
587648
return ",".join(rust_field(f.name) for f in fields)
588649

589-
def gen_sum_fromobj(self, sum, sumname, enumname, depth):
650+
def gen_sum_fromobj(self, sum, sumname, enumname, rustname, depth):
590651
if sum.attributes:
591652
self.extract_location(sumname, depth)
592653

593654
self.emit("let _cls = _object.class();", depth)
594655
self.emit("Ok(", depth)
595656
for cons in sum.types:
596657
self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
597-
self.gen_construction(f"{enumname}::{cons.name}", cons, sumname, depth + 1)
658+
if cons.fields:
659+
self.emit(f"ast::{rustname}::{cons.name} (ast::{enumname}{cons.name} {{", depth + 1)
660+
self.gen_construction_fields(cons, sumname, depth + 1)
661+
self.emit("})", depth + 1)
662+
else:
663+
self.emit(f"ast::{rustname}::{cons.name}", depth + 1)
598664
self.emit("} else", depth)
599665

600666
self.emit("{", depth)
@@ -610,13 +676,16 @@ def gen_product_fromobj(self, product, prodname, structname, depth):
610676
self.gen_construction(structname, product, prodname, depth + 1)
611677
self.emit(")", depth)
612678

613-
def gen_construction(self, cons_path, cons, name, depth):
614-
self.emit(f"ast::{cons_path} {{", depth)
679+
def gen_construction_fields(self, cons, name, depth):
615680
for field in cons.fields:
616681
self.emit(
617682
f"{rust_field(field.name)}: {self.decode_field(field, name)},",
618683
depth + 1,
619684
)
685+
686+
def gen_construction(self, cons_path, cons, name, depth):
687+
self.emit(f"ast::{cons_path} {{", depth)
688+
self.gen_construction_fields(cons, name, depth + 1)
620689
self.emit("}", depth)
621690

622691
def extract_location(self, typename, depth):

0 commit comments

Comments
 (0)