3
3
using Tensorflow . Keras . ArgsDefinition ;
4
4
using static Tensorflow . Binding ;
5
5
using Tensorflow . Keras . Utils ;
6
+ using Tensorflow . Util ;
7
+ using Tensorflow . Framework ;
6
8
7
9
namespace Tensorflow . Keras . Engine . DataAdapters
8
10
{
@@ -24,6 +26,7 @@ public class DataHandler
24
26
long _steps_per_execution_value ;
25
27
int _initial_epoch => args . InitialEpoch ;
26
28
int _epochs => args . Epochs ;
29
+ NDArray _sample_weight => args . SampleWeight ;
27
30
IVariableV1 _steps_per_execution ;
28
31
29
32
public DataHandler ( DataHandlerArgs args )
@@ -75,10 +78,75 @@ public DataHandler(DataHandlerArgs args)
75
78
}
76
79
77
80
_dataset = _adapter . GetDataset ( ) ;
78
- _inferred_steps = _infer_steps ( args . StepsPerEpoch , _dataset ) ;
79
81
_current_step = 0 ;
80
82
_step_increment = _steps_per_execution_value - 1 ;
81
83
_insufficient_data = false ;
84
+ _configure_dataset_and_inferred_steps ( args . X , args . ClassWeight ) ;
85
+ }
86
+
87
+ void _configure_dataset_and_inferred_steps ( Tensors x , Dictionary < int , float > class_weight )
88
+ {
89
+ if ( _dataset == null )
90
+ {
91
+ _dataset = _adapter . GetDataset ( ) ;
92
+ _inferred_steps = _infer_steps ( args . StepsPerEpoch , _dataset ) ;
93
+ }
94
+
95
+ if ( class_weight != null )
96
+ {
97
+ _dataset = _dataset . map ( _make_class_weight_map_fn ( class_weight ) ) ;
98
+ }
99
+ _inferred_steps = _infer_steps ( args . StepsPerEpoch , _dataset ) ;
100
+ }
101
+
102
+
103
+ Func < Tensors , Tensors > _make_class_weight_map_fn ( Dictionary < int , float > class_weight )
104
+ {
105
+ var class_ids = class_weight . Keys . OrderBy ( key => key ) . ToList ( ) ;
106
+ var expected_class_ids = range ( class_ids [ 0 ] , class_ids [ class_ids . Count - 1 ] + 1 ) ;
107
+ if ( ! class_ids . SequenceEqual ( expected_class_ids ) )
108
+ {
109
+ throw new ValueError ( "Expected `class_weight` to be a dict with keys from 0 to one less " +
110
+ $ "than the number of classes, found { class_weight } ") ;
111
+ }
112
+
113
+ var class_weight_list = new List < float > ( ) ;
114
+ foreach ( var class_id in class_ids )
115
+ {
116
+ class_weight_list . Add ( class_weight [ class_id ] ) ;
117
+ }
118
+ var class_weight_tensor = tf . convert_to_tensor ( class_weight_list . ToArray ( ) ) ;
119
+
120
+ Func < Tensors , Tensors > _class_weight_map_fn = ( Tensors data ) =>
121
+ {
122
+ var x = data [ 0 ] ;
123
+ var y = data [ 1 ] ;
124
+ var sw = _sample_weight == null ? null : ops . convert_to_tensor ( _sample_weight ) ;
125
+
126
+ if ( y . shape . rank > 2 )
127
+ {
128
+ throw new ValueError ( "`class_weight` not supported for 3+ dimensional targets." ) ;
129
+ }
130
+
131
+ var y_classes = smart_module . smart_cond (
132
+ y . shape . rank == 2 && y . shape [ 1 ] > 1 ,
133
+ ( ) => math_ops . argmax ( y , dimension : 1 ) ,
134
+ ( ) => math_ops . cast ( tf . reshape ( y , ( - 1 ) ) , TF_DataType . TF_INT64 ) ) ;
135
+
136
+ var cw = array_ops . gather ( class_weight_tensor , y_classes ) ;
137
+ if ( sw != null )
138
+ {
139
+ cw = tf . cast ( cw , sw . dtype ) ;
140
+ cw *= sw ;
141
+ }
142
+ else
143
+ {
144
+ sw = cw ;
145
+ }
146
+ return new Tensors { x , y , sw } ;
147
+ } ;
148
+
149
+ return _class_weight_map_fn ;
82
150
}
83
151
84
152
long _infer_steps ( int steps_per_epoch , IDatasetV2 dataset )
0 commit comments