Skip to content

Commit 2a17b9c

Browse files
committed
add RefVariable override of state_ops.assign #271
1 parent 967fc43 commit 2a17b9c

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

src/TensorFlowNET.Core/Train/AdamOptimizer.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indi
4646
var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
4747
var m = get_slot(var, "m");
4848
var m_scaled_g_values = grad * (1 - beta1_t);
49-
var mul = m * beta1_t;
50-
var m_t = state_ops.assign(m, mul, use_locking: _use_locking);
49+
var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
5150
with(ops.control_dependencies(new[] { m_t }), delegate
5251
{
5352
m_t = scatter_add(m, indices, m_scaled_g_values);

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,26 @@ public static Tensor assign(Tensor @ref, object value,
6767
return _result[0];
6868
}
6969

70+
public static Tensor assign(RefVariable @ref, object value,
71+
bool validate_shape = true,
72+
bool use_locking = true,
73+
string name = null)
74+
{
75+
var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });
76+
77+
var _result = _op.outputs;
78+
var _inputs_flat = _op.inputs;
79+
80+
var _attrs = new Dictionary<string, object>();
81+
_attrs["T"] = _op.get_attr("T");
82+
_attrs["validate_shape"] = _op.get_attr("validate_shape");
83+
_attrs["use_locking"] = _op.get_attr("use_locking");
84+
85+
_execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name);
86+
87+
return _result[0];
88+
}
89+
7090
public static Tensor assign_sub(RefVariable @ref,
7191
Tensor value,
7292
bool use_locking = false,

src/TensorFlowNET.Core/Variables/state_ops.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,18 @@ public static Tensor assign(Tensor @ref, object value,
4040
//return @ref.assign(value, name: name);
4141
}
4242

43+
public static Tensor assign(RefVariable @ref, object value,
44+
bool validate_shape = true,
45+
bool use_locking = true,
46+
string name = null)
47+
{
48+
return gen_state_ops.assign(@ref,
49+
value,
50+
validate_shape: validate_shape,
51+
use_locking: use_locking,
52+
name: name);
53+
}
54+
4355
public static Tensor assign_sub(RefVariable @ref,
4456
Tensor value,
4557
bool use_locking = false,

0 commit comments

Comments
 (0)