Skip to content

Commit 7cd8292

Browse files
committed
fix per_image_standardization run bug
1 parent ed1a8d2 commit 7cd8292

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/TensorFlowNET.Core/Operations/image_ops_impl.cs

+7-6
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,12 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat
102102
{
103103
throw new ValueError("\'image\' must be fully defined.");
104104
}
105-
for (int x = 1; x < 4; x++)
105+
var dims = image_shape["-3:"];
106+
foreach (var dim in dims.dims)
106107
{
107-
if (image_shape.dims[x] == 0)
108+
if (dim == 0)
108109
{
109-
throw new ValueError(String.Format("inner 3 dims of \'image.shape\' must be > 0: {0}", image_shape));
110+
throw new ValueError("inner 3 dimensions of \'image\' must be > 0: " + image_shape);
110111
}
111112
}
112113

@@ -965,9 +966,9 @@ public static Tensor per_image_standardization(Tensor image)
965966
if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype))
966967
image = convert_image_dtype(image, dtypes.float32);
967968

968-
var num_pixels_ = array_ops.shape(image).dims;
969-
num_pixels_ = num_pixels_.Skip(num_pixels_.Length - 3).Take(num_pixels_.Length - (num_pixels_.Length - 3)).ToArray();
970-
Tensor num_pixels = math_ops.reduce_prod(new Tensor(num_pixels_));
969+
var x = image.shape["-3:"];
970+
var num_pixels = math_ops.reduce_prod(x);
971+
971972
Tensor image_mean = math_ops.reduce_mean(image, axis: new(-1, -2, -3), keepdims: true);
972973

973974
var stddev = math_ops.reduce_std(image, axis: new(-1, -2, -3), keepdims: true);

0 commit comments

Comments
 (0)