
Commit 43c3705

Merge pull request #1189 from Wanglongzhi2001/master
feat: add the implementation of class_weight in model.fit
2 parents f16902d + a1c64ef commit 43c3705

File tree: 3 files changed, +84 -10 lines

  src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs  (+69 -1)
  src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs  (+11 -2)
  src/TensorFlowNET.Keras/Engine/Model.Fit.cs  (+4 -7)
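
For orientation before the diffs: the user-facing effect of this merge is that a class_weight dictionary passed to model.fit is forwarded to the data pipeline instead of raising NotImplementedException. Below is a minimal usage sketch in TensorFlow.NET; the model construction, the tiny dataset, and the exact layer/compile overloads (string activations, metrics as string[]) are illustrative assumptions and may differ slightly between TensorFlow.NET versions. Only the class_weight: argument itself is what this PR wires up.

// Hypothetical end-to-end use of the new class_weight argument
// (x_train, y_train and the model layout are placeholders, not part of the diff).
using System.Collections.Generic;
using Tensorflow.Keras;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

// Tiny illustrative dataset: 6 samples, 2 features, binary labels with a rare class 1.
var x_train = np.array(new float[,] { { 0, 1 }, { 1, 0 }, { 0, 0 }, { 1, 1 }, { 0, 1 }, { 1, 0 } });
var y_train = np.array(new float[] { 0, 0, 0, 0, 1, 1 });

var model = keras.Sequential(new List<ILayer>
{
    keras.layers.Dense(8, activation: "relu"),
    keras.layers.Dense(1, activation: "sigmoid")
});
model.compile(optimizer: keras.optimizers.Adam(),
              loss: keras.losses.BinaryCrossentropy(),
              metrics: new[] { "accuracy" });

// Keys must run from 0 to num_classes - 1 (validated in _make_class_weight_map_fn below);
// here the minority class 1 counts twice as heavily in the loss.
var class_weight = new Dictionary<int, float> { [0] = 1.0f, [1] = 2.0f };

model.fit(x_train, y_train, batch_size: 2, epochs: 3, class_weight: class_weight);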

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs  (+69 -1)

@@ -3,6 +3,8 @@
 using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;
 using Tensorflow.Keras.Utils;
+using Tensorflow.Util;
+using Tensorflow.Framework;

 namespace Tensorflow.Keras.Engine.DataAdapters
 {
@@ -24,6 +26,7 @@ public class DataHandler
         long _steps_per_execution_value;
         int _initial_epoch => args.InitialEpoch;
         int _epochs => args.Epochs;
+        NDArray _sample_weight => args.SampleWeight;
         IVariableV1 _steps_per_execution;

         public DataHandler(DataHandlerArgs args)
@@ -75,10 +78,75 @@ public DataHandler(DataHandlerArgs args)
             }

             _dataset = _adapter.GetDataset();
-            _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
             _current_step = 0;
             _step_increment = _steps_per_execution_value - 1;
             _insufficient_data = false;
+            _configure_dataset_and_inferred_steps(args.X, args.ClassWeight);
+        }
+
+        void _configure_dataset_and_inferred_steps(Tensors x, Dictionary<int, float> class_weight)
+        {
+            if (_dataset == null)
+            {
+                _dataset = _adapter.GetDataset();
+                _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
+            }
+
+            if (class_weight != null)
+            {
+                _dataset = _dataset.map(_make_class_weight_map_fn(class_weight));
+            }
+            _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
+        }
+
+
+        Func<Tensors, Tensors> _make_class_weight_map_fn(Dictionary<int, float> class_weight)
+        {
+            var class_ids = class_weight.Keys.OrderBy(key => key).ToList();
+            var expected_class_ids = range(class_ids[0], class_ids[class_ids.Count - 1] + 1);
+            if (!class_ids.SequenceEqual(expected_class_ids))
+            {
+                throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less "+
+                    $"than the number of classes, found {class_weight}");
+            }
+
+            var class_weight_list = new List<float>();
+            foreach (var class_id in class_ids)
+            {
+                class_weight_list.Add(class_weight[class_id]);
+            }
+            var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray());
+
+            Func<Tensors, Tensors> _class_weight_map_fn = (Tensors data) =>
+            {
+                var x = data[0];
+                var y = data[1];
+                var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight);
+
+                if (y.shape.rank > 2)
+                {
+                    throw new ValueError("`class_weight` not supported for 3+ dimensional targets.");
+                }
+
+                var y_classes = smart_module.smart_cond(
+                    y.shape.rank == 2 && y.shape[1] > 1,
+                    () => math_ops.argmax(y, dimension: 1),
+                    () => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64));
+
+                var cw = array_ops.gather(class_weight_tensor, y_classes);
+                if (sw != null)
+                {
+                    cw = tf.cast(cw, sw.dtype);
+                    cw *= sw;
+                }
+                else
+                {
+                    sw = cw;
+                }
+                return new Tensors { x, y, sw };
+            };
+
+            return _class_weight_map_fn;
         }

         long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
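
To make the added mapping concrete: for sparse integer labels, _class_weight_map_fn turns the class_weight dictionary into a lookup tensor and gathers one weight per sample, which then travels through the dataset as the sample-weight element. The standalone sketch below mirrors just that gather step using public TF.NET calls (tf.constant, tf.gather); it is an illustration of the idea, not the library's code path above.

using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;

// class_weight as a caller would pass it to fit().
var class_weight = new Dictionary<int, float> { [0] = 1.0f, [1] = 3.0f };

// Same idea as _make_class_weight_map_fn: order the keys 0..n-1 and build a lookup tensor.
var class_weight_tensor = tf.constant(
    class_weight.Keys.OrderBy(k => k).Select(k => class_weight[k]).ToArray());

// One batch of sparse integer labels.
var y = tf.constant(new long[] { 0, 1, 1, 0, 1 });

// Per-sample weights gathered by label: [1, 3, 3, 1, 3].
var sw = tf.gather(class_weight_tensor, y);
Console.WriteLine(sw.numpy());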

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs  (+11 -2)

@@ -164,11 +164,20 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
         }


-        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
+        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
+        {
+            (x,y) = data_handler.DataAdapter.Expand1d(x, y);
+            var y_pred = Apply(x, training: false);
+            var loss = compiled_loss.Call(y, y_pred);
+            compiled_metrics.update_state(y, y_pred);
+            return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
+        }
+
+        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
         {
             (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
             var y_pred = Apply(x, training: false);
-            var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
+            var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
             compiled_metrics.update_state(y, y_pred);
             return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
         }
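
The reason for splitting test_step into two overloads is that batches coming out of the DataHandler now carry either (x, y) or (x, y, sample_weight) once class_weight or sample_weight is in play. A hedged sketch of how a caller might dispatch between them, assuming Tensors exposes Length and indexers the way data[0]/data[1] are used in the map function above; the actual evaluate loop may do this differently:

// data is one element yielded by the dataset iterator: a Tensors of length 2 or 3.
Dictionary<string, float> run_test_step(DataHandler data_handler, Tensors data)
{
    return data.Length == 2
        ? test_step(data_handler, data[0], data[1])            // no per-sample weights
        : test_step(data_handler, data[0], data[1], data[2]);  // (x, y, sample_weight)
}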

src/TensorFlowNET.Keras/Engine/Model.Fit.cs  (+4 -7)

@@ -63,12 +63,6 @@ public ICallback fit(NDArray x, NDArray y,
                 ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
             }

-            // TODO(Wanglongzhi2001)
-            if (class_weight != null)
-            {
-                throw new NotImplementedException("class_weight is not implemented");
-            }
-
             var data_handler = new DataHandler(new DataHandlerArgs
             {
                 X = x,
@@ -78,6 +72,7 @@ public ICallback fit(NDArray x, NDArray y,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                ClassWeight = class_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -126,11 +121,12 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
             {
                 X = new Tensors(x.ToArray()),
                 Y = y,
+                SampleWeight = sample_weight,
                 BatchSize = batch_size,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
-                SampleWeight = sample_weight,
+                ClassWeight = class_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -174,6 +170,7 @@ public History fit(IDatasetV2 dataset,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                SampleWeight = sample_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
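
Because the map function above multiplies class weights into any existing sample weights (cw = tf.cast(cw, sw.dtype); cw *= sw), class_weight and sample_weight compose rather than conflict. A short hedged sketch, reusing the model, x_train, y_train and class_weight from the first example above (the weight values are illustrative):

using Tensorflow.NumPy;

// One weight per training sample, e.g. to down-weight a suspect label.
var sample_weight = np.array(new float[] { 1f, 1f, 1f, 0.5f, 1f, 1f });

// Effective per-sample weight becomes class_weight[y_i] * sample_weight[i].
model.fit(x_train, y_train,
          batch_size: 2,
          epochs: 3,
          class_weight: class_weight,
          sample_weight: sample_weight);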
