@@ -131,6 +131,9 @@ def visitSum(self, sum, name):
131
131
if is_simple (sum ):
132
132
info .has_userdata = False
133
133
else :
134
+ for t in sum .types :
135
+ self .typeinfo [t .name ] = TypeInfo (t .name )
136
+ self .add_children (t .name , t .fields )
134
137
if len (sum .types ) > 1 :
135
138
info .boxed = True
136
139
if sum .attributes :
@@ -205,16 +208,49 @@ def simple_sum(self, sum, name, depth):
205
208
206
209
def sum_with_constructors (self , sum , name , depth ):
207
210
typeinfo = self .typeinfo [name ]
208
- generics , generics_applied = self .get_generics (name , "U = ()" , "U" )
209
211
enumname = rustname = get_rust_type (name )
210
212
# all the attributes right now are for location, so if it has attrs we
211
213
# can just wrap it in Located<>
212
214
if sum .attributes :
213
215
enumname = rustname + "Kind"
216
+
217
+ for t in sum .types :
218
+ if not t .fields :
219
+ continue
220
+ self .emit_attrs (depth )
221
+ self .typeinfo [t ] = TypeInfo (t )
222
+ t_generics , t_generics_applied = self .get_generics (t .name , "U = ()" , "U" )
223
+ payload_name = f"{ rustname } { t .name } "
224
+ self .emit (f"pub struct { payload_name } { t_generics } {{" , depth )
225
+ for f in t .fields :
226
+ self .visit (f , typeinfo , "pub " , depth + 1 , t .name )
227
+ self .emit ("}" , depth )
228
+ self .emit (
229
+ textwrap .dedent (
230
+ f"""
231
+ impl{ t_generics_applied } From<{ payload_name } { t_generics_applied } > for { enumname } { t_generics_applied } {{
232
+ fn from(payload: { payload_name } { t_generics_applied } ) -> Self {{
233
+ { enumname } ::{ t .name } (payload)
234
+ }}
235
+ }}
236
+ """
237
+ ),
238
+ depth ,
239
+ )
240
+
241
+ generics , generics_applied = self .get_generics (name , "U = ()" , "U" )
214
242
self .emit_attrs (depth )
215
243
self .emit (f"pub enum { enumname } { generics } {{" , depth )
216
244
for t in sum .types :
217
- self .visit (t , typeinfo , depth + 1 )
245
+ if t .fields :
246
+ t_generics , t_generics_applied = self .get_generics (
247
+ t .name , "U = ()" , "U"
248
+ )
249
+ self .emit (
250
+ f"{ t .name } ({ rustname } { t .name } { t_generics_applied } )," , depth + 1
251
+ )
252
+ else :
253
+ self .emit (f"{ t .name } ," , depth + 1 )
218
254
self .emit ("}" , depth )
219
255
if sum .attributes :
220
256
self .emit (
@@ -238,13 +274,18 @@ def visitField(self, field, parent, vis, depth, constructor=None):
238
274
if fieldtype and fieldtype .has_userdata :
239
275
typ = f"{ typ } <U>"
240
276
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
241
- if fieldtype and fieldtype .boxed and (not (parent .product or field .seq ) or field .opt ):
277
+ if (
278
+ fieldtype
279
+ and fieldtype .boxed
280
+ and (not (parent .product or field .seq ) or field .opt )
281
+ ):
242
282
typ = f"Box<{ typ } >"
243
283
if field .opt or (
244
284
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
245
285
# the expression to be unpacked goes in `values` with a `None` at the corresponding
246
286
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
247
- constructor == "Dict" and field .name == "keys"
287
+ constructor == "Dict"
288
+ and field .name == "keys"
248
289
):
249
290
typ = f"Option<{ typ } >"
250
291
if field .seq :
@@ -344,14 +385,21 @@ def visitSum(self, sum, name, depth):
344
385
)
345
386
if is_located :
346
387
self .emit ("fold_located(folder, node, |folder, node| {" , depth )
347
- enumname += "Kind"
388
+ rustname = enumname + "Kind"
389
+ else :
390
+ rustname = enumname
348
391
self .emit ("match node {" , depth + 1 )
349
392
for cons in sum .types :
350
- fields_pattern = self .make_pattern (cons .fields )
393
+ fields_pattern = self .make_pattern (
394
+ enumname , rustname , cons .name , cons .fields
395
+ )
351
396
self .emit (
352
- f"{ enumname } ::{ cons .name } {{ { fields_pattern } }} => {{" , depth + 2
397
+ f"{ fields_pattern [0 ]} {{ { fields_pattern [1 ]} }} { fields_pattern [2 ]} => {{" ,
398
+ depth + 2 ,
399
+ )
400
+ self .gen_construction (
401
+ fields_pattern [0 ], cons .fields , fields_pattern [2 ], depth + 3
353
402
)
354
- self .gen_construction (f"{ enumname } ::{ cons .name } " , cons .fields , depth + 3 )
355
403
self .emit ("}" , depth + 2 )
356
404
self .emit ("}" , depth + 1 )
357
405
if is_located :
@@ -381,23 +429,33 @@ def visitProduct(self, product, name, depth):
381
429
)
382
430
if is_located :
383
431
self .emit ("fold_located(folder, node, |folder, node| {" , depth )
384
- structname += "Data"
385
- fields_pattern = self .make_pattern (product .fields )
386
- self .emit (f"let { structname } {{ { fields_pattern } }} = node;" , depth + 1 )
387
- self .gen_construction (structname , product .fields , depth + 1 )
432
+ rustname = structname + "Data"
433
+ else :
434
+ rustname = structname
435
+ fields_pattern = self .make_pattern (rustname , structname , None , product .fields )
436
+ self .emit (f"let { rustname } {{ { fields_pattern [1 ]} }} = node;" , depth + 1 )
437
+ self .gen_construction (rustname , product .fields , "" , depth + 1 )
388
438
if is_located :
389
439
self .emit ("})" , depth )
390
440
self .emit ("}" , depth )
391
441
392
- def make_pattern (self , fields ):
393
- return "," .join (rust_field (f .name ) for f in fields )
442
+ def make_pattern (self , rustname , pyname , fieldname , fields ):
443
+ if fields :
444
+ header = f"{ pyname } ::{ fieldname } ({ rustname } { fieldname } "
445
+ footer = ")"
446
+ else :
447
+ header = f"{ pyname } ::{ fieldname } "
448
+ footer = ""
394
449
395
- def gen_construction (self , cons_path , fields , depth ):
396
- self .emit (f"Ok({ cons_path } {{" , depth )
450
+ body = "," .join (rust_field (f .name ) for f in fields )
451
+ return header , body , footer
452
+
453
+ def gen_construction (self , header , fields , footer , depth ):
454
+ self .emit (f"Ok({ header } {{" , depth )
397
455
for field in fields :
398
456
name = rust_field (field .name )
399
457
self .emit (f"{ name } : Foldable::fold({ name } , folder)?," , depth + 1 )
400
- self .emit (" })" , depth )
458
+ self .emit (f"}} { footer } )" , depth )
401
459
402
460
403
461
class FoldModuleVisitor (TypeInfoEmitVisitor ):
@@ -514,33 +572,36 @@ def visitType(self, type, depth=0):
514
572
self .visit (type .value , type .name , depth )
515
573
516
574
def visitSum (self , sum , name , depth ):
517
- enumname = get_rust_type (name )
575
+ rustname = enumname = get_rust_type (name )
518
576
if sum .attributes :
519
- enumname += "Kind"
577
+ rustname = enumname + "Kind"
520
578
521
- self .emit (f"impl NamedNode for ast::{ enumname } {{" , depth )
579
+ self .emit (f"impl NamedNode for ast::{ rustname } {{" , depth )
522
580
self .emit (f"const NAME: &'static str = { json .dumps (name )} ;" , depth + 1 )
523
581
self .emit ("}" , depth )
524
- self .emit (f"impl Node for ast::{ enumname } {{" , depth )
582
+ self .emit (f"impl Node for ast::{ rustname } {{" , depth )
525
583
self .emit (
526
584
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {" , depth + 1
527
585
)
528
586
self .emit ("match self {" , depth + 2 )
529
587
for variant in sum .types :
530
- self .constructor_to_object (variant , enumname , depth + 3 )
588
+ self .constructor_to_object (variant , enumname , rustname , depth + 3 )
531
589
self .emit ("}" , depth + 2 )
532
590
self .emit ("}" , depth + 1 )
533
591
self .emit (
534
592
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {" ,
535
593
depth + 1 ,
536
594
)
537
- self .gen_sum_fromobj (sum , name , enumname , depth + 2 )
595
+ self .gen_sum_fromobj (sum , name , enumname , rustname , depth + 2 )
538
596
self .emit ("}" , depth + 1 )
539
597
self .emit ("}" , depth )
540
598
541
- def constructor_to_object (self , cons , enumname , depth ):
542
- fields_pattern = self .make_pattern (cons .fields )
543
- self .emit (f"ast::{ enumname } ::{ cons .name } {{ { fields_pattern } }} => {{" , depth )
599
+ def constructor_to_object (self , cons , enumname , rustname , depth ):
600
+ self .emit (f"ast::{ rustname } ::{ cons .name } " , depth )
601
+ if cons .fields :
602
+ fields_pattern = self .make_pattern (cons .fields )
603
+ self .emit (f"( ast::{ enumname } { cons .name } {{ { fields_pattern } }} )" , depth )
604
+ self .emit (" => {" , depth )
544
605
self .make_node (cons .name , cons .fields , depth + 1 )
545
606
self .emit ("}" , depth )
546
607
@@ -586,15 +647,20 @@ def make_node(self, variant, fields, depth):
586
647
def make_pattern (self , fields ):
587
648
return "," .join (rust_field (f .name ) for f in fields )
588
649
589
- def gen_sum_fromobj (self , sum , sumname , enumname , depth ):
650
+ def gen_sum_fromobj (self , sum , sumname , enumname , rustname , depth ):
590
651
if sum .attributes :
591
652
self .extract_location (sumname , depth )
592
653
593
654
self .emit ("let _cls = _object.class();" , depth )
594
655
self .emit ("Ok(" , depth )
595
656
for cons in sum .types :
596
657
self .emit (f"if _cls.is(Node{ cons .name } ::static_type()) {{" , depth )
597
- self .gen_construction (f"{ enumname } ::{ cons .name } " , cons , sumname , depth + 1 )
658
+ if cons .fields :
659
+ self .emit (f"ast::{ rustname } ::{ cons .name } (ast::{ enumname } { cons .name } {{" , depth + 1 )
660
+ self .gen_construction_fields (cons , sumname , depth + 1 )
661
+ self .emit ("})" , depth + 1 )
662
+ else :
663
+ self .emit (f"ast::{ rustname } ::{ cons .name } " , depth + 1 )
598
664
self .emit ("} else" , depth )
599
665
600
666
self .emit ("{" , depth )
@@ -610,13 +676,16 @@ def gen_product_fromobj(self, product, prodname, structname, depth):
610
676
self .gen_construction (structname , product , prodname , depth + 1 )
611
677
self .emit (")" , depth )
612
678
613
- def gen_construction (self , cons_path , cons , name , depth ):
614
- self .emit (f"ast::{ cons_path } {{" , depth )
679
+ def gen_construction_fields (self , cons , name , depth ):
615
680
for field in cons .fields :
616
681
self .emit (
617
682
f"{ rust_field (field .name )} : { self .decode_field (field , name )} ," ,
618
683
depth + 1 ,
619
684
)
685
+
686
+ def gen_construction (self , cons_path , cons , name , depth ):
687
+ self .emit (f"ast::{ cons_path } {{" , depth )
688
+ self .gen_construction_fields (cons , name , depth + 1 )
620
689
self .emit ("}" , depth )
621
690
622
691
def extract_location (self , typename , depth ):
0 commit comments