@@ -905,13 +905,29 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = null)
                var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes);
                var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true);
                var ab_matmul = matmul(a_reshape, b_reshape);
-               var dims = new List<int>();
-               dims.AddRange(a_free_dims);
-               dims.AddRange(b_free_dims);
-               if (ab_matmul.shape.Equals(dims))
-                   return ab_matmul;
+               if (a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list)
+               {
+                   var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray();
+                   if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims))
+                   {
+                       return ab_matmul;
+                   }
+                   else
+                   {
+                       return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name);
+                   }
+               }
                else
-                   return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
+               {
+                   var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32);
+                   var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32);
+                   var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name);
+                   if (a_free_dims_static is not null && b_free_dims_static is not null)
+                   {
+                       product.shape = new Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray());
+                   }
+                   return product;
+               }
            });
        }

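Reading note on the hunk above: the new branch handles two cases, free dimensions known statically as `int[]` (concatenate them and compare against the matmul shape) and free dimensions only available as tensors (concatenate with `array_ops.concat` and re-attach the static shape when it is known). In both cases the result shape of a tensordot is the free dims of `a` followed by the free dims of `b`. A minimal, self-contained sketch of that shape rule, in plain C# with a made-up helper name and example shapes, not TensorFlow.NET API:

```csharp
using System;
using System.Linq;

class TensordotShapeSketch
{
    // The result shape of a tensordot is the free (non-contracted) dims of `a`
    // followed by the free dims of `b`.
    static int[] TensordotOutputShape(int[] shapeA, int[] aAxes, int[] shapeB, int[] bAxes)
    {
        var aFree = Enumerable.Range(0, shapeA.Length).Where(i => !aAxes.Contains(i)).Select(i => shapeA[i]);
        var bFree = Enumerable.Range(0, shapeB.Length).Where(i => !bAxes.Contains(i)).Select(i => shapeB[i]);
        return aFree.Concat(bFree).ToArray();
    }

    static void Main()
    {
        // Contract the last axis of a [2,3,4] tensor with the first axis of a [4,5] tensor.
        var shape = TensordotOutputShape(new[] { 2, 3, 4 }, new[] { 2 }, new[] { 4, 5 }, new[] { 0 });
        Console.WriteLine(string.Join(",", shape)); // prints 2,3,5
    }
}
```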
@@ -927,14 +943,42 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = null)
                return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
                    Binding.range(0, axe).ToArray());
            }
-           else
+           else if (axes.rank == 1)
            {
+               if (axes.shape[0] != 2)
+               {
+                   throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
+               }
                (int a_axe, int b_axe) = (axes[0], axes[1]);
                return (new[] { a_axe }, new[] { b_axe });
            }
+           else if (axes.rank == 2)
+           {
+               if (axes.shape[0] != 2)
+               {
+                   throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
+               }
+               int[] a_axes = new int[axes.shape[1]];
+               int[] b_axes = new int[axes.shape[1]];
+               for (int i = 0; i < a_axes.Length; i++)
+               {
+                   a_axes[i] = axes[0, i];
+                   b_axes[i] = axes[1, i];
+                   if (a_axes[i] == -1 || b_axes[i] == -1)
+                   {
+                       throw new ValueError($"Different number of contraction axes `a` and `b`, " +
+                           $"{len(a_axes)} != {len(b_axes)}.");
+                   }
+               }
+               return (a_axes, b_axes);
+           }
+           else
+           {
+               throw new ValueError($"Invalid rank {axes.rank} to make tensor dot.");
+           }
        }

-       static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
+       static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
        {
            if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple))))
            {
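The `_tensordot_axes` changes in the hunk above accept three forms of `axes`: a scalar n (contract the last n axes of `a` against the first n axes of `b`), a length-2 vector (one axis of `a` paired with one axis of `b`), and a 2 x k matrix (row 0 lists the axes of `a`, row 1 the matching axes of `b`). A small plain-C# sketch of that normalization, with hypothetical helper names and no NDArray dependency, only to illustrate the mapping:

```csharp
using System;
using System.Linq;

static class TensordotAxesSketch
{
    // Scalar form: contract the last n axes of `a` with the first n axes of `b`.
    static (int[] aAxes, int[] bAxes) FromScalar(int aRank, int n)
        => (Enumerable.Range(aRank - n, n).ToArray(), Enumerable.Range(0, n).ToArray());

    // Length-2 vector form: a single (axis of a, axis of b) pair, e.g. new[] { 2, 0 }.
    static (int[] aAxes, int[] bAxes) FromPair(int[] pair)
        => (new[] { pair[0] }, new[] { pair[1] });

    // 2 x k matrix form: row 0 lists axes of `a`, row 1 the matching axes of `b`.
    static (int[] aAxes, int[] bAxes) FromMatrix(int[,] m)
    {
        int k = m.GetLength(1);
        var aAxes = new int[k];
        var bAxes = new int[k];
        for (int i = 0; i < k; i++)
        {
            aAxes[i] = m[0, i];
            bAxes[i] = m[1, i];
        }
        return (aAxes, bAxes);
    }

    static void Main()
    {
        Console.WriteLine(string.Join(",", FromScalar(aRank: 3, n: 1).aAxes));               // 2
        Console.WriteLine(string.Join(",", FromPair(new[] { 2, 0 }).bAxes));                 // 0
        Console.WriteLine(string.Join(",", FromMatrix(new[,] { { 1, 2 }, { 0, 1 } }).aAxes)); // 1,2
    }
}
```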
@@ -977,6 +1021,58 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = null)
                var reshaped_a = array_ops.reshape(a_trans, new_shape);
                return (reshaped_a, free_dims, free_dims);
            }
+           else
+           {
+               int[] free_dims_static;
+               Tensor converted_shape_a, converted_axes, converted_free;
+               if (a.shape.ndim != -1)
+               {
+                   var shape_a = a.shape.as_int_list();
+                   for (int i = 0; i < axes.Length; i++)
+                   {
+                       if (axes[i] < 0)
+                       {
+                           axes[i] += shape_a.Length;
+                       }
+                   }
+                   var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray();
+
+                   var axes_dims = axes.Select(i => shape_a[i]);
+                   var free_dims = free.Select(i => shape_a[i]).ToArray();
+                   free_dims_static = free_dims;
+                   converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
+                   converted_free = ops.convert_to_tensor(free, dtypes.int32, "free");
+                   converted_shape_a = array_ops.shape(a);
+               }
+               else
+               {
+                   free_dims_static = null;
+                   converted_shape_a = array_ops.shape(a);
+                   var rank_a = array_ops.rank(a);
+                   converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
+                   converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a);
+                   (converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)),
+                       converted_axes, dtypes.int32);
+               }
+               var converted_free_dims = array_ops.gather(converted_shape_a, converted_free);
+               var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes);
+               var prod_free_dims = reduce_prod(converted_free_dims);
+               var prod_axes_dims = reduce_prod(converted_axes_dims);
+               Tensor reshaped_a;
+               if (flipped)
+               {
+                   var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0);
+                   var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims });
+                   reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
+               }
+               else
+               {
+                   var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0);
+                   var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims });
+                   reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
+               }
+               return (reshaped_a, converted_free_dims, free_dims_static);
+           }

            throw new NotImplementedException("_tensordot_reshape");
        }
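As a reading aid for the long `else` branch above: both the static-shape path and the dynamic-shape path compute the same plan, a permutation that groups free axes on one side and contracted axes on the other, plus a 2-element shape built from their products, so the contraction reduces to a single 2-D matmul. A small self-contained sketch of that plan for statically known shapes, in plain C# with a hypothetical `PlanReshape` helper, not part of the library:

```csharp
using System;
using System.Linq;

static class TensordotReshapeSketch
{
    // Move the contracted axes to one side, keep the free axes on the other,
    // and flatten each group so the contraction becomes one 2-D matmul.
    static (int[] perm, int[] newShape, int[] freeDims) PlanReshape(int[] shape, int[] axes, bool flipped)
    {
        var free = Enumerable.Range(0, shape.Length).Where(i => !axes.Contains(i)).ToArray();
        var freeDims = free.Select(i => shape[i]).ToArray();
        int prodFree = freeDims.Aggregate(1, (x, y) => x * y);
        int prodAxes = axes.Select(i => shape[i]).Aggregate(1, (x, y) => x * y);

        // `flipped` corresponds to the right-hand operand: contracted axes come first there.
        var perm = flipped ? axes.Concat(free).ToArray() : free.Concat(axes).ToArray();
        var newShape = flipped ? new[] { prodAxes, prodFree } : new[] { prodFree, prodAxes };
        return (perm, newShape, freeDims);
    }

    static void Main()
    {
        // A [2,3,4] tensor contracted over its last axis: permute to [0,1,2], flatten to [6,4].
        var (perm, newShape, freeDims) = PlanReshape(new[] { 2, 3, 4 }, new[] { 2 }, flipped: false);
        Console.WriteLine(string.Join(",", perm));     // 0,1,2
        Console.WriteLine(string.Join(",", newShape)); // 6,4
        Console.WriteLine(string.Join(",", freeDims)); // 2,3
    }
}
```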