Skip to content

Commit f3ec499

Browse files
committed
fix _apply_sparse for ResourceVariable.
1 parent a883898 commit f3ec499

File tree

11 files changed

+39
-67
lines changed

11 files changed

+39
-67
lines changed

src/TensorFlowNET.Core/APIs/tf.ops.cs

+1-4
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@ public void add_to_collections<T>(List<string> names, T value)
3030
public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
3131
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
3232

33-
public Tensor assign(RefVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
34-
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
35-
36-
public Tensor assign(ResourceVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
33+
public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
3734
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
3835

3936
public void device(string device_name)

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public Layer(LayerArgs args)
121121
/// <param name="input"></param>
122122
/// <param name="is_training"></param>
123123
/// <returns></returns>
124-
public Tensor Apply(Tensor inputs, bool is_training = false)
124+
public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null)
125125
{
126126
Tensor outputs = null;
127127

@@ -135,9 +135,9 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
135135

136136
string nameScope = "";
137137
if (eager)
138-
{
139138
nameScope = name;
140-
}
139+
else
140+
nameScope = _name_scope();
141141

142142
// using var graph = tf.keras.backend.get_graph().as_default();
143143
if (!inputs.IsEagerTensor)
@@ -148,7 +148,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false)
148148
if (!built)
149149
MaybeBuild(inputs);
150150

151-
outputs = call(inputs, is_training: is_training);
151+
outputs = call(inputs, is_training: is_training, state: state);
152152

153153
outputs = _set_connectivity_metadata_(inputs, outputs);
154154
_handle_activity_regularization(inputs, outputs);

src/TensorFlowNET.Core/Layers/Layer.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ public Tensor[] __call__(Tensor inputs,
8888
{
8989
_current_scope = scope2;
9090
// Actually call layer
91-
outputs = base.Apply(inputs);
91+
outputs = base.Apply(inputs,
92+
is_training: training == null ? false : false,
93+
state: state);
9294
});
9395

9496

src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ protected override Tensor call(Tensor inputs, bool is_training = false, Tensor s
7171
{
7272
// Most basic RNN: output = new_state = act(W * input + U * state + B).
7373
var concat = array_ops.concat(new[] { inputs, state }, 1);
74-
var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable);
75-
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
74+
var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor());
75+
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor());
7676
var output = _activation(gate_inputs, null);
7777
return output;
7878
}

src/TensorFlowNET.Core/Operations/Operation.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ public void _update_input(int index, Tensor tensor)
326326
// the updated inputs are reloaded from the c_api
327327
lock (Locks.ProcessWide)
328328
{
329-
c_api.UpdateEdge(_graph, output, input, tf.Status.Handle);
329+
// c_api.UpdateEdge(_graph, output, input, tf.Status.Handle);
330330
//var updated_inputs = inputs;
331331
tf.Status.Check();
332332
}

src/TensorFlowNET.Core/Operations/embedding_ops.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public static Tensor _embedding_lookup_and_transform(IVariableV1 @params,
7474
ids = ops.convert_to_tensor(ids, name: "ids");
7575
if (np == 1)
7676
{
77-
var gather = array_ops.gather(@params, ids, name: name);
77+
var gather = array_ops.gather(@params.AsTensor(), ids, name: name);
7878
var result = _clip(gather, ids, max_norm);
7979

8080
return array_ops.identity(result);

src/TensorFlowNET.Core/Operations/math_ops.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -706,11 +706,12 @@ public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null)
706706
=> tf_with(ops.name_scope(name, "Pow", new { x, y }), scope =>
707707
{
708708
name = scope;
709-
var x_tensor = ops.convert_to_tensor(x, name: "x");
710-
var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype());
711709

712710
if (tf.executing_eagerly())
713711
{
712+
var x_tensor = ops.convert_to_tensor(x, name: "x");
713+
var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype());
714+
714715
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
715716
"Pow", name,
716717
null,
@@ -719,7 +720,7 @@ public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null)
719720
return results[0];
720721
}
721722

722-
var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x_tensor, y_tensor });
723+
var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x, y });
723724

724725
return _op.output;
725726
});

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
1111
<Company>SciSharp STACK</Company>
1212
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
13-
<Copyright>Apache 2.0</Copyright>
13+
<Copyright>Apache 2.0, Haiping Chen $([System.DateTime]::UtcNow.ToString(yyyy))</Copyright>
1414
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
1515
<RepositoryType>git</RepositoryType>
1616
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl>

src/TensorFlowNET.Core/Training/AdamOptimizer.cs

+10-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.9
5252
_dtype = dtype;
5353
}
5454

55+
public override Operation _apply_sparse(IndexedSlices grad, ResourceVariable var)
56+
{
57+
return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) =>
58+
{
59+
return state_ops.scatter_add(x, i, v, use_locking: _use_locking);
60+
});
61+
}
62+
5563
public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
5664
{
5765
return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) =>
@@ -91,15 +99,15 @@ private Operation _apply_sparse_shared(Tensor grad, IVariableV1 var, Tensor indi
9199
var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
92100
var m = get_slot(var, "m");
93101
var m_scaled_g_values = grad * (1 - beta1_t);
94-
var m_t = state_ops.assign(m.AsTensor(), m.AsTensor() * beta1_t, use_locking: _use_locking);
102+
var m_t = state_ops.assign(m, m.AsTensor() * beta1_t, use_locking: _use_locking);
95103
tf_with(ops.control_dependencies(new[] { m_t }), delegate
96104
{
97105
m_t = scatter_add(m, indices, m_scaled_g_values);
98106
});
99107

100108
var v = get_slot(var, "v");
101109
var v_scaled_g_values = (grad * grad) * (1 - beta2_t);
102-
var v_t = state_ops.assign(v.AsTensor(), v.AsTensor() * beta2_t, use_locking: _use_locking);
110+
var v_t = state_ops.assign(v, v.AsTensor() * beta2_t, use_locking: _use_locking);
103111
tf_with(ops.control_dependencies(new[] { v_t }), delegate
104112
{
105113
v_t = scatter_add(v, indices, v_scaled_g_values);

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

+4-34
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public static Tensor variable_v2(int[] shape, TF_DataType dtype, string name = n
5656
/// <param name="validate_shape"></param>
5757
/// <param name="use_locking"></param>
5858
/// <param name="name"></param>
59-
public static Tensor assign(Tensor @ref, object value,
59+
public static Tensor assign<T>(T @ref, object value,
6060
bool validate_shape = true,
6161
bool use_locking = true,
6262
string name = null)
@@ -74,40 +74,10 @@ public static Tensor assign(Tensor @ref, object value,
7474
return _result[0];
7575
}
7676

77-
public static Tensor assign(RefVariable @ref, object value,
78-
bool validate_shape = true,
79-
bool use_locking = true,
80-
string name = null)
81-
{
82-
var _op = tf.OpDefLib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });
83-
84-
var _result = _op.outputs;
85-
var _inputs_flat = _op.inputs;
86-
87-
var _attrs = new Dictionary<string, object>();
88-
_attrs["T"] = _op.get_attr("T");
89-
_attrs["validate_shape"] = _op.get_attr("validate_shape");
90-
_attrs["use_locking"] = _op.get_attr("use_locking");
91-
92-
return _result[0];
93-
}
94-
95-
public static Tensor assign(ResourceVariable @ref, object value,
96-
bool validate_shape = true,
97-
bool use_locking = true,
98-
string name = null)
77+
public static Tensor assign_add<T>(IVariableV1 @ref, T value, bool use_locking = false, string name = null)
9978
{
100-
var _op = tf.OpDefLib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });
101-
102-
var _result = _op.outputs;
103-
var _inputs_flat = _op.inputs;
104-
105-
var _attrs = new Dictionary<string, object>();
106-
_attrs["T"] = _op.get_attr("T");
107-
_attrs["validate_shape"] = _op.get_attr("validate_shape");
108-
_attrs["use_locking"] = _op.get_attr("use_locking");
109-
110-
return _result[0];
79+
var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking });
80+
return _op.outputs[0];
11181
}
11282

11383
public static Tensor assign_sub(IVariableV1 @ref,

src/TensorFlowNET.Core/Variables/state_ops.cs

+8-14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System;
18+
using static Tensorflow.Binding;
1819

1920
namespace Tensorflow
2021
{
@@ -54,19 +55,7 @@ public static Tensor assign(Tensor @ref, object value,
5455
return @ref.assign((Tensor)value, name: name);
5556
}
5657

57-
public static Tensor assign(RefVariable @ref, object value,
58-
bool validate_shape = true,
59-
bool use_locking = true,
60-
string name = null)
61-
{
62-
return gen_state_ops.assign(@ref,
63-
value,
64-
validate_shape: validate_shape,
65-
use_locking: use_locking,
66-
name: name);
67-
}
68-
69-
public static Tensor assign(ResourceVariable @ref, object value,
58+
public static Tensor assign<T>(T @ref, object value,
7059
bool validate_shape = true,
7160
bool use_locking = true,
7261
string name = null)
@@ -110,7 +99,12 @@ public static ITensorOrOperation assign_add<T>(IVariableV1 @ref,
11099
T value,
111100
bool use_locking = false,
112101
string name = null)
113-
=> @ref.assign_add(value, use_locking: use_locking, name: name);
102+
{
103+
if(tf.executing_eagerly())
104+
return @ref.assign_add(value, use_locking: use_locking, name: name);
105+
else
106+
return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);
107+
}
114108

115109
public static Tensor scatter_add(IVariableV1 @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
116110
{

0 commit comments

Comments
 (0)