Skip to content

Commit d4f1c34

Browse files
committed
fix _apply_dense for Optimizer.
1 parent d3b681e commit d4f1c34

31 files changed

+299
-108
lines changed

src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ public Tensor dense(Tensor inputs,
193193
Name = name
194194
});
195195

196-
throw new NotImplementedException("");
197-
//return layer.apply(inputs).Item1;
196+
return layer.Apply(inputs);
198197
}
199198

200199
/// <summary>

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ public Tensor dropout(Tensor x, Tensor keep_prob = null, Tensor noise_shape = nu
6666
Tensor keep = null;
6767
if (keep_prob != null)
6868
keep = 1.0f - keep_prob;
69-
70-
return nn_ops.dropout_v2(x, rate: rate.Value, noise_shape: noise_shape, seed: seed, name: name);
69+
var rate_tensor = rate.HasValue ? tf.constant(rate.Value) : keep;
70+
return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name);
7171
}
7272

7373
/// <summary>

src/TensorFlowNET.Core/Framework/meta_graph.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public static (Dictionary<string, IVariableV1>, ITensorOrOperation[]) import_sco
150150
var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
151151
scope: scope_to_prepend_to_names);
152152
var var_list = new Dictionary<string, IVariableV1>();
153-
// variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);
153+
variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);
154154

155155
return (var_list, imported_return_elements);
156156
}
@@ -277,6 +277,11 @@ private static void add_collection_def(MetaGraphDef meta_graph_def,
277277
var proto = x_ref_var.to_proto(export_scope);
278278
col_def.BytesList.Value.Add(proto.ToByteString());
279279
}
280+
else if(x is ResourceVariable x_res_var)
281+
{
282+
var proto = x_res_var.to_proto(export_scope);
283+
col_def.BytesList.Value.Add(proto.ToByteString());
284+
}
280285
}
281286
break;
282287
case List<RefVariable> collection_list:

src/TensorFlowNET.Core/Functions/c_api.function.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,23 @@ public partial class c_api
3131
/// <param name="output_func_def"></param>
3232
/// <param name="status"></param>
3333
[DllImport(TensorFlowLibName)]
34-
public static extern void TF_FunctionToFunctionDef(IntPtr func, IntPtr output_func_def, SafeStatusHandle status);
34+
public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status);
3535

36+
[DllImport(TensorFlowLibName)]
37+
public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name,
38+
bool append_hash_to_fn_name,
39+
int num_opers, IntPtr[] opers,
40+
int ninputs, TF_Output[] inputs,
41+
int noutputs, TF_Output[] outputs,
42+
IntPtr output_names,
43+
IntPtr opts,
44+
string description,
45+
SafeStatusHandle status);
46+
47+
[DllImport(TensorFlowLibName)]
48+
public static extern IntPtr TF_FunctionName(IntPtr func);
3649

50+
[DllImport(TensorFlowLibName)]
51+
public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status);
3752
}
3853
}

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,9 @@ public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
327327
var output_shape = op.outputs[0]._shape_tuple();
328328

329329
Tensor result, factor_tensor;
330-
if(input_shape != null &&
331-
output_shape != null)
330+
if(tf.executing_eagerly()
331+
&& input_shape != null
332+
&& output_shape != null)
332333
{
333334
var input_size = np.prod(input_shape);
334335
var output_size = np.prod(output_shape);
@@ -339,11 +340,7 @@ public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
339340
{
340341
var input_shape_tensor = array_ops.shape(op.inputs[0]);
341342
var output_shape_tensor = array_ops.shape(op.outputs[0]);
342-
var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor));
343-
throw new NotImplementedException("");
344-
#pragma warning disable CS0162 // Unreachable code detected
345-
factor_tensor = null;
346-
#pragma warning restore CS0162 // Unreachable code detected
343+
factor_tensor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor));
347344
}
348345

349346
result = math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype));

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Te
128128
[RegisterGradient("Conv2D")]
129129
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
130130
{
131-
var dilations = op.get_attr<int[]>("dilations");
132-
var strides = op.get_attr<int[]>("strides");
131+
var dilations = op.get_attr_list<int>("dilations");
132+
var strides = op.get_attr_list<int>("strides");
133133
var padding = op.get_attr<string>("padding");
134-
var explicit_paddings = op.get_attr<int[]>("explicit_paddings");
134+
var explicit_paddings = op.get_attr_list<int>("explicit_paddings");
135135
var use_cudnn_on_gpu = op.get_attr<bool>("use_cudnn_on_gpu");
136136
var data_format = op.get_attr<string>("data_format");
137137
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
@@ -287,8 +287,8 @@ public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads)
287287
op.inputs[0],
288288
op.outputs[0],
289289
grad,
290-
op.get_attr("ksize") as int[],
291-
op.get_attr("strides") as int[],
290+
op.get_attr_list<int>("ksize"),
291+
op.get_attr_list<int>("strides"),
292292
padding: op.get_attr("padding").ToString(),
293293
data_format: op.get_attr("data_format").ToString())
294294
};

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,12 +293,6 @@ public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes
293293

294294
_create_op_helper(op, compute_device);
295295

296-
/*Console.Write($"create_op: {op_type} '{node_def.Name}'");
297-
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
298-
Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
299-
Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
300-
Console.WriteLine();*/
301-
302296
return op;
303297
}
304298

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ public partial class c_api
139139
/// <param name="status">TF_Status*</param>
140140
[DllImport(TensorFlowLibName)]
141141
public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);
142-
142+
143143
/// <summary>
144144
/// Returns the number of dimensions of the Tensor referenced by `output`
145145
/// in `graph`.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class TensorLikeDataAdapterArgs
8+
{
9+
public Tensor X { get; set; }
10+
public Tensor Y { get; set; }
11+
public int BatchSize { get; set; }
12+
public int Steps { get; set; }
13+
public int Epochs { get; set; }
14+
public bool Shuffle { get; set; }
15+
}
16+
}

src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ public class DataHandler
2727

2828
public DataHandler(DataHandlerArgs args)
2929
{
30+
this.args = args;
3031

32+
var adapter_cls = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs { });
3133
}
3234
}
3335
}

src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
45
using static Tensorflow.Binding;
56

67
namespace Tensorflow.Keras.Engine.DataAdapters
@@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
1011
/// </summary>
1112
public class TensorLikeDataAdapter : IDataAdapter
1213
{
13-
public TensorLikeDataAdapter()
14+
public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
1415
{
1516
tf.data.Dataset.range(5);
1617
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Engine
6+
{
7+
public partial class Layer
8+
{
9+
Dictionary<Layer, object> trainable_state;
10+
Dictionary<Layer, object> _get_trainable_state()
11+
{
12+
trainable_state = new Dictionary<Layer, object>();
13+
throw new NotImplementedException("");
14+
}
15+
16+
void _set_trainable_state(Dictionary<Layer, object> trainable_state)
17+
{
18+
throw new NotImplementedException("");
19+
}
20+
}
21+
}

src/TensorFlowNET.Core/Keras/Engine/Model.cs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
using NumSharp;
1+
using static Tensorflow.Binding;
22
using System;
33
using Tensorflow.Keras.ArgsDefinition;
4+
using Tensorflow.Keras.Engine.DataAdapters;
45
using Tensorflow.Keras.Losses;
56
using Tensorflow.Keras.Optimizers;
67

@@ -21,6 +22,7 @@ public class Model : Layer
2122
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword
2223
string loss;
2324
IOptimizer optimizer;
25+
IVariableV1 _steps_per_execution;
2426

2527
public Model(ModelArgs args)
2628
: base(args)
@@ -37,10 +39,25 @@ public void compile(string optimizerName, string lossName)
3739
break;
3840
}
3941

42+
int experimental_steps_per_execution = 1;
43+
_configure_steps_per_execution(experimental_steps_per_execution);
44+
45+
_reset_compile_cache();
46+
4047
loss = lossName;
4148
_is_compiled = true;
49+
}
50+
51+
void _configure_steps_per_execution(int steps_per_execution)
52+
{
53+
_steps_per_execution = tf.Variable(steps_per_execution,
54+
dtype: TF_DataType.TF_INT64,
55+
aggregation: VariableAggregation.OnlyFirstReplica);
56+
}
57+
58+
void _reset_compile_cache()
59+
{
4260

43-
// Prepare list of loss functions, same size of model outputs.
4461
}
4562

4663
public void compile(string optimizerName, ILossFunc lossName)
@@ -70,6 +87,20 @@ public Tensor predict(Tensor x,
7087
int workers = 1,
7188
bool use_multiprocessing = false)
7289
{
90+
var data_handler = new DataHandler(new DataHandlerArgs
91+
{
92+
X = x,
93+
BatchSize = batch_size,
94+
StepsPerEpoch = steps,
95+
InitialEpoch = 0,
96+
Epochs = 1,
97+
MaxQueueSize = max_queue_size,
98+
Workers = workers,
99+
UseMultiprocessing = use_multiprocessing,
100+
Model = this,
101+
StepsPerExecution = _steps_per_execution
102+
});
103+
73104
throw new NotImplementedException("");
74105
}
75106
}

src/TensorFlowNET.Core/Keras/Layers/Embedding.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.Linq;
1718
using Tensorflow.Keras.ArgsDefinition;
1819
using Tensorflow.Keras.Engine;
1920
using static Tensorflow.Binding;
@@ -44,6 +45,9 @@ public Embedding(EmbeddingArgs args)
4445
if (args.InputShape == null)
4546
args.InputShape = args.InputLength;
4647

48+
if (args.BatchInputShape == null)
49+
args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray();
50+
4751
embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer;
4852
SupportsMasking = mask_zero;
4953
}

src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@ public Dense Dense(int units,
3434

3535
/// <summary>
3636
/// Turns positive integers (indexes) into dense vectors of fixed size.
37+
/// This layer can only be used as the first layer in a model.
38+
/// e.g. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
39+
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
3740
/// </summary>
38-
/// <param name="input_dim"></param>
39-
/// <param name="output_dim"></param>
40-
/// <param name="embeddings_initializer"></param>
41+
/// <param name="input_dim">Size of the vocabulary, i.e. maximum integer index + 1.</param>
42+
/// <param name="output_dim">Dimension of the dense embedding.</param>
43+
/// <param name="embeddings_initializer">Initializer for the embeddings matrix (see keras.initializers).</param>
4144
/// <param name="mask_zero"></param>
4245
/// <returns></returns>
4346
public Embedding Embedding(int input_dim,

src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ public TruncatedNormal(float mean = 0.0f,
3636

3737
public Tensor Apply(InitializerArgs args)
3838
{
39-
if (args.DType == TF_DataType.DtInvalid)
40-
args.DType = this.dtype;
41-
return random_ops.truncated_normal(args.Shape, mean, stddev, dtype : dtype, seed: seed);
39+
if (args.DType != TF_DataType.DtInvalid)
40+
dtype = args.DType;
41+
return random_ops.truncated_normal(args.Shape, mean, stddev, dtype: dtype, seed: seed);
4242
}
4343
}
4444
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,35 @@ private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, Map
230230
public virtual T get_attr<T>(string name)
231231
=> (T)get_attr(name);
232232

233+
public virtual T[] get_attr_list<T>(string name)
234+
{
235+
if (tf.executing_eagerly())
236+
return (T[])get_attr(name);
237+
238+
AttrValue x = null;
239+
240+
lock (Locks.ProcessWide)
241+
{
242+
using var buf = new Buffer();
243+
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
244+
tf.Status.Check(true);
245+
246+
x = AttrValue.Parser.ParseFrom(buf.DangerousMemoryBlock.Stream());
247+
}
248+
249+
string oneof_value = x.ValueCase.ToString();
250+
if (string.IsNullOrEmpty(oneof_value))
251+
return null;
252+
253+
switch (typeof(T).Name)
254+
{
255+
case nameof(Int32):
256+
return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray();
257+
default:
258+
return null;
259+
}
260+
}
261+
233262
public virtual object get_attr(string name)
234263
{
235264
AttrValue x = null;
@@ -250,7 +279,7 @@ public virtual object get_attr(string name)
250279
if (oneof_value == "list")
251280
throw new NotImplementedException($"Unsupported field type in {x.ToString()}");
252281

253-
if (oneof_value == "type")
282+
if (string.Equals("type", oneof_value, StringComparison.OrdinalIgnoreCase))
254283
return x.Type;
255284

256285
object result = x.GetType().GetProperty(oneof_value).GetValue(x);

0 commit comments

Comments
 (0)