@@ -651,13 +651,13 @@ object _get_input_tensor(int time)
651
651
states = Nest . PackSequenceAs ( states , flat_final_states ) . ToTensors ( ) ;
652
652
if ( return_all_outputs )
653
653
{
654
- successive_outputs . Add ( output ) ;
655
- successive_states . Add ( states ) ;
654
+ successive_outputs = successive_outputs . MergeWith ( output ) ;
655
+ successive_outputs = successive_states . MergeWith ( states ) ;
656
656
}
657
657
else
658
658
{
659
- successive_outputs = new Tensors { output } ;
660
- successive_states = new Tensors { states } ;
659
+ successive_outputs = new Tensors ( output ) ;
660
+ successive_states = new Tensors ( states ) ;
661
661
}
662
662
663
663
}
@@ -722,16 +722,11 @@ object _get_input_tensor(int time)
722
722
// Get the time(0) input and compute the output for that, the output will
723
723
// be used to determine the dtype of output tensor array. Don't read from
724
724
// input_ta due to TensorArray clear_after_read default to True.
725
- var inps = new Tensors ( ) ;
726
- foreach ( var inp in flatted_inptus )
727
- {
728
- inps . Add ( inp [ 0 ] ) ;
729
- }
730
- var input_time_zero = Nest . PackSequenceAs ( inputs , inps ) . ToTensors ( ) ;
725
+ var input_time_zero = Nest . PackSequenceAs ( inputs , flatted_inptus . Select ( x => x [ 0 ] ) . ToArray ( ) ) . ToTensors ( ) ;
731
726
732
727
// output_time_zero is used to determine the cell output shape and its
733
728
// dtype. the value is discarded.
734
- ( output_time_zero , _ ) = step_function ( ( Tensor ) input_time_zero ,
729
+ ( output_time_zero , _ ) = step_function ( input_time_zero ,
735
730
constants is null ? initial_states : initial_states . MergeWith ( constants ) ) ;
736
731
737
732
int output_ta_size = return_all_outputs ? time_steps_t : 1 ;
@@ -816,6 +811,7 @@ object _get_input_tensor(int time)
816
811
817
812
Func < Tensor , Tensor > cond = ( time ) => ( time < time_steps_t ) ;
818
813
int parallel_iterations = 32 ;
814
+ new_states = states ;
819
815
if ( masking_fn != null )
820
816
{
821
817
// Mask for the T output will be base on the output of T - 1. In the
@@ -846,7 +842,7 @@ RNN step function.
846
842
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
847
843
var current_input = Nest . PackSequenceAs ( inputs , flat_current_input ) . ToTensors ( ) ;
848
844
var mask_t = masking_fn ( time ) ;
849
- var ( output , new_states_internal ) = step_function ( current_input , states . MergeWith ( constants ) ) ;
845
+ var ( output , new_states_internal ) = step_function ( current_input , new_states . MergeWith ( constants ) ) ;
850
846
// mask output
851
847
var flat_output = Nest . Flatten ( output ) . ToList ( ) ;
852
848
@@ -871,11 +867,12 @@ RNN step function.
871
867
new_states_internal = Nest . PackSequenceAs ( new_states , flat_final_state ) . ToTensors ( ) ;
872
868
873
869
var ta_index_to_write = return_all_outputs ? time : tf . constant ( 0 ) ;
874
- // TODO(Wanglongzhi2001),deal with zip output_ta_t
875
- foreach ( var ( ta , Out ) in zip ( output_ta_t , flat_new_output ) )
870
+ output_ta_t = zip ( output_ta_t , flat_new_output ) . Select ( item =>
876
871
{
877
- output_ta_t . Add ( ta . write ( ta_index_to_write , Out ) ) ;
878
- }
872
+ var ( ta , out_ ) = item ;
873
+ return ta . write ( ta_index_to_write , out_ ) ;
874
+ } ) . ToList ( ) ;
875
+
879
876
880
877
new_states_internal = Nest . PackSequenceAs ( initial_states , flat_new_state ) . ToTensors ( ) ;
881
878
@@ -921,15 +918,8 @@ Tensor _step(Tensor time)
921
918
}
922
919
var final_outputs = tf . while_loop ( cond : cond , body : _step , loop_vars : time , parallel_iterations : parallel_iterations ) ;
923
920
}
924
- //Tensors outputs = new Tensors();
925
- foreach ( var o in output_ta )
926
- {
927
- outputs . Add ( o . stack ( ) ) ;
928
- }
929
- foreach ( var o in outputs )
930
- {
931
- last_output . Add ( o [ - 1 ] ) ;
932
- }
921
+ outputs = outputs . MergeWith ( output_ta . Select ( o => o . stack ( ) ) . ToTensors ( ) ) ;
922
+ last_output = last_output . MergeWith ( outputs . Select ( o => o [ - 1 ] ) . ToTensors ( ) ) ;
933
923
outputs = Nest . PackSequenceAs ( output_time_zero , outputs ) . ToTensors ( ) ;
934
924
last_output = Nest . PackSequenceAs ( output_time_zero , last_output ) . ToTensors ( ) ;
935
925
0 commit comments