Skip to content

Commit 44bdddc

Browse files
authored
Merge pull request #1205 from Wanglongzhi2001/fix_boolean_mask
fix: fix the bug of boolean_mask
2 parents 079b9a3 + 4e42d7f commit 44bdddc

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,9 @@ public static Tensor _transpose_batch_time(Tensor x)
428428
return x;
429429

430430
var x_rank = array_ops.rank(x);
431-
var con1 = new object[]
431+
var con1 = new Tensor[]
432432
{
433-
new []{1, 0 },
433+
new Tensor(new int[]{0, 2}),
434434
math_ops.range(2, x_rank)
435435
};
436436
var x_t = array_ops.transpose(x, array_ops.concat(con1, 0));

src/TensorFlowNET.Core/Operations/array_ops.cs

+9-4
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boo
166166
throw new ValueError("mask cannot be scalar.");
167167

168168
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 }));
169+
if (leading_size.rank == 0)
170+
{
171+
leading_size = expand_dims(leading_size, 0);
172+
}
173+
169174
var shape1 = concat(new[]
170175
{
171176
shape(tensor_tensor)[$":{axis}"],
@@ -185,7 +190,7 @@ public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boo
185190

186191
private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
187192
{
188-
var indices = squeeze(where(mask), axis: new[] { 1 });
193+
var indices = squeeze(where_v2(mask), axis: new[] { 1 });
189194
return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis));
190195
}
191196

@@ -940,12 +945,12 @@ public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
940945
/// <returns></returns>
941946
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
942947
{
943-
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
948+
return gen_array_ops.concat_v2(values, axis, name: name);
944949
}
945950

946-
public static Tensor concat(object[] values, int axis, string name = "concat")
951+
public static Tensor concat(Tensor[] values, Axis axis, string name = "concat")
947952
{
948-
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis));
953+
return gen_array_ops.concat_v2(values, axis, name: name);
949954
}
950955

951956
/// <summary>

src/TensorFlowNET.Core/Operations/nn_ops.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ private static Tensor _flatten_outer_dims(Tensor logits)
287287
new[] { math_ops.subtract(rank, 1) },
288288
new[] { constant_op.constant(1) });
289289

290-
var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
290+
var ops = array_ops.concat(new Tensor[] { new Tensor(new int[] {1}), last_dim_size }, 0);
291291
var output = array_ops.reshape(logits, ops);
292292

293293
// Set output shape if known.

test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System;
44
using System.Linq;
55
using static Tensorflow.Binding;
6+
using Tensorflow;
67

78
namespace TensorFlowNET.UnitTest.Basics
89
{
@@ -60,14 +61,14 @@ public void batch_to_space_nd()
6061
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>()));
6162
}
6263

63-
[TestMethod, Ignore]
64+
[TestMethod]
6465
public void boolean_mask()
6566
{
67+
if (!tf.executing_eagerly())
68+
tf.enable_eager_execution();
6669
var tensor = new[] { 0, 1, 2, 3 };
6770
var mask = np.array(new[] { true, false, true, false });
6871
var masked = tf.boolean_mask(tensor, mask);
69-
var sess = tf.Session();
70-
var result = sess.run(masked);
7172
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>()));
7273
}
7374
}

0 commit comments

Comments
 (0)