@@ -510,7 +510,7 @@ Tensor swap_batch_timestep(Tensor input_t)
510
510
}
511
511
512
512
}
513
-
513
+
514
514
// tf.where needs its condition tensor to be the same shape as its two
515
515
// result tensors, but in our case the condition (mask) tensor is
516
516
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
@@ -535,7 +535,7 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
535
535
{
536
536
mask_t = tf . expand_dims ( mask_t , - 1 ) ;
537
537
}
538
- var multiples = Enumerable . Repeat ( 1 , fixed_dim ) . ToArray ( ) . concat ( input_t . shape . as_int_list ( ) . ToList ( ) . GetRange ( fixed_dim , input_t . rank ) ) ;
538
+ var multiples = Enumerable . Repeat ( 1 , fixed_dim ) . ToArray ( ) . concat ( input_t . shape . as_int_list ( ) . Skip ( fixed_dim ) . ToArray ( ) ) ;
539
539
return tf . tile ( mask_t , multiples ) ;
540
540
}
541
541
@@ -570,9 +570,6 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
570
570
// individually. The result of this will be a tuple of lists, each of
571
571
// the item in tuple is list of the tensor with shape (batch, feature)
572
572
573
-
574
-
575
-
576
573
Tensors _process_single_input_t ( Tensor input_t )
577
574
{
578
575
var unstaked_input_t = array_ops . unstack ( input_t ) ; // unstack for time_step dim
@@ -609,7 +606,7 @@ object _get_input_tensor(int time)
609
606
var mask_list = tf . unstack ( mask ) ;
610
607
if ( go_backwards )
611
608
{
612
- mask_list . Reverse ( ) ;
609
+ mask_list . Reverse ( ) . ToArray ( ) ;
613
610
}
614
611
615
612
for ( int i = 0 ; i < time_steps ; i ++ )
@@ -629,9 +626,10 @@ object _get_input_tensor(int time)
629
626
}
630
627
else
631
628
{
632
- prev_output = successive_outputs [ successive_outputs . Length - 1 ] ;
629
+ prev_output = successive_outputs . Last ( ) ;
633
630
}
634
631
632
+ // output could be a tensor
635
633
output = tf . where ( tiled_mask_t , output , prev_output ) ;
636
634
637
635
var flat_states = Nest . Flatten ( states ) . ToList ( ) ;
@@ -661,13 +659,13 @@ object _get_input_tensor(int time)
661
659
}
662
660
663
661
}
664
- last_output = successive_outputs [ successive_outputs . Length - 1 ] ;
665
- new_states = successive_states [ successive_states . Length - 1 ] ;
662
+ last_output = successive_outputs . Last ( ) ;
663
+ new_states = successive_states . Last ( ) ;
666
664
outputs = tf . stack ( successive_outputs ) ;
667
665
668
666
if ( zero_output_for_mask )
669
667
{
670
- last_output = tf . where ( _expand_mask ( mask_list [ mask_list . Length - 1 ] , last_output ) , last_output , tf . zeros_like ( last_output ) ) ;
668
+ last_output = tf . where ( _expand_mask ( mask_list . Last ( ) , last_output ) , last_output , tf . zeros_like ( last_output ) ) ;
671
669
outputs = tf . where ( _expand_mask ( mask , outputs , fixed_dim : 2 ) , outputs , tf . zeros_like ( outputs ) ) ;
672
670
}
673
671
else // mask is null
@@ -689,8 +687,8 @@ object _get_input_tensor(int time)
689
687
successive_states = new Tensors { newStates } ;
690
688
}
691
689
}
692
- last_output = successive_outputs [ successive_outputs . Length - 1 ] ;
693
- new_states = successive_states [ successive_states . Length - 1 ] ;
690
+ last_output = successive_outputs . Last ( ) ;
691
+ new_states = successive_states . Last ( ) ;
694
692
outputs = tf . stack ( successive_outputs ) ;
695
693
}
696
694
}
@@ -701,6 +699,8 @@ object _get_input_tensor(int time)
701
699
// Create input tensor array, if the inputs is nested tensors, then it
702
700
// will be flattened first, and tensor array will be created one per
703
701
// flattened tensor.
702
+
703
+
704
704
var input_ta = new List < TensorArray > ( ) ;
705
705
for ( int i = 0 ; i < flatted_inptus . Count ; i ++ )
706
706
{
@@ -719,6 +719,7 @@ object _get_input_tensor(int time)
719
719
}
720
720
}
721
721
722
+
722
723
// Get the time(0) input and compute the output for that, the output will
723
724
// be used to determine the dtype of output tensor array. Don't read from
724
725
// input_ta due to TensorArray clear_after_read default to True.
@@ -773,7 +774,7 @@ object _get_input_tensor(int time)
773
774
return res ;
774
775
} ;
775
776
}
776
- // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
777
+ // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor), it could be an integer or tensor
777
778
else if ( input_length is Tensor )
778
779
{
779
780
if ( go_backwards )
0 commit comments