Skip to content

Commit 9cd8681

Browse files
authored
Merge pull request #1140 from dogvane/master
fix same bug
2 parents 8574881 + 7165304 commit 9cd8681

File tree

8 files changed

+53
-6
lines changed

8 files changed

+53
-6
lines changed

src/TensorFlowNET.Core/Gradients/nn_grad.cs

+17
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,23 @@ public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads)
365365
};
366366
}
367367

368+
[RegisterGradient("AvgPool")]
369+
public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads)
370+
{
371+
Tensor grad = grads[0];
372+
373+
return new Tensor[]
374+
{
375+
gen_nn_ops.avg_pool_grad(
376+
array_ops.shape(op.inputs[0]),
377+
grad,
378+
op.get_attr_list<int>("ksize"),
379+
op.get_attr_list<int>("strides"),
380+
op.get_attr("padding").ToString(),
381+
op.get_attr("data_format").ToString())
382+
};
383+
}
384+
368385
/// <summary>
369386
/// Return the gradients for TopK.
370387
/// </summary>

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1515
List<ILayer> Layers { get; }
1616
List<INode> InboundNodes { get; }
1717
List<INode> OutboundNodes { get; }
18-
Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null);
18+
Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null);
1919
List<IVariableV1> TrainableVariables { get; }
2020
List<IVariableV1> TrainableWeights { get; }
2121
List<IVariableV1> NonTrainableWeights { get; }

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_Data
145145
throw new NotImplementedException("_zero_state_tensors");
146146
}
147147

148-
public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null)
148+
public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null)
149149
{
150150
throw new NotImplementedException();
151151
}

src/TensorFlowNET.Keras/Engine/Layer.Apply.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public partial class Layer
1313
/// <param name="state"></param>
1414
/// <param name="training"></param>
1515
/// <returns></returns>
16-
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null)
16+
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null)
1717
{
1818
if (callContext.Value == null)
1919
callContext.Value = new CallContext();

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ public History fit(IDatasetV2 dataset,
142142
int verbose = 1,
143143
List<ICallback> callbacks = null,
144144
IDatasetV2 validation_data = null,
145+
int validation_step = 10, // 间隔多少次会进行一次验证
145146
bool shuffle = true,
146147
int initial_epoch = 0,
147148
int max_queue_size = 10,
@@ -164,11 +165,11 @@ public History fit(IDatasetV2 dataset,
164165
});
165166

166167

167-
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
168+
return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
168169
train_step_func: train_step_function);
169170
}
170171

171-
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
172+
History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
172173
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
173174
{
174175
stop_training = false;
@@ -207,6 +208,9 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
207208

208209
if (validation_data != null)
209210
{
211+
if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0)
212+
continue;
213+
210214
var val_logs = evaluate(validation_data);
211215
foreach(var log in val_logs)
212216
{

src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
393393
}
394394
}
395395

396-
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
396+
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null)
397397
{
398398
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
399399
if (optional_args is not null && rnn_optional_args is null)

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ public IDatasetV2 image_dataset_from_directory(string directory,
5858
if (shuffle)
5959
dataset = dataset.shuffle(batch_size * 8, seed: seed);
6060
dataset = dataset.batch(batch_size);
61+
dataset.class_names = class_name_list;
6162
return dataset;
6263
}
6364

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs

+25
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,31 @@ namespace Tensorflow.Keras
66
{
77
public partial class Preprocessing
88
{
9+
10+
/// <summary>
11+
/// 图片路径转为数据处理用的dataset
12+
/// </summary>
13+
/// <param name="image_paths"></param>
14+
/// <param name="image_size"></param>
15+
/// <param name="num_channels"></param>
16+
/// <param name="interpolation">
17+
/// 用于调整大小的插值方法。支持`bilinear`、`nearest`、`bicubic`、`area`、`lanczos3`、`lanczos5`、`gaussian`、`mitchellcubic`。
18+
/// 默认为`'bilinear'`。
19+
/// </param>
20+
/// <returns></returns>
21+
public IDatasetV2 paths_to_dataset(string[] image_paths,
22+
Shape image_size,
23+
int num_channels = 3,
24+
int num_classes = 6,
25+
string interpolation = "bilinear")
26+
{
27+
var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);
28+
var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation));
29+
var label_ds = dataset_utils.labels_to_dataset(new int[num_classes] , "", num_classes);
30+
31+
return img_ds;
32+
}
33+
934
public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths,
1035
Shape image_size,
1136
int num_channels,

0 commit comments

Comments
 (0)