Skip to content

Commit 3006c86

Browse files
authored
Merge pull request #1145 from dogvane/master
Thanks for your contribution.
2 parents ed1a8d2 + 0cc25fb commit 3006c86

File tree

3 files changed

+40
-6
lines changed

3 files changed

+40
-6
lines changed

src/TensorFlowNET.Core/Operations/image_ops_impl.cs

+7-6
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,12 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat
102102
{
103103
throw new ValueError("\'image\' must be fully defined.");
104104
}
105-
for (int x = 1; x < 4; x++)
105+
var dims = image_shape["-3:"];
106+
foreach (var dim in dims.dims)
106107
{
107-
if (image_shape.dims[x] == 0)
108+
if (dim == 0)
108109
{
109-
throw new ValueError(String.Format("inner 3 dims of \'image.shape\' must be > 0: {0}", image_shape));
110+
throw new ValueError("inner 3 dimensions of \'image\' must be > 0: " + image_shape);
110111
}
111112
}
112113

@@ -965,9 +966,9 @@ public static Tensor per_image_standardization(Tensor image)
965966
if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype))
966967
image = convert_image_dtype(image, dtypes.float32);
967968

968-
var num_pixels_ = array_ops.shape(image).dims;
969-
num_pixels_ = num_pixels_.Skip(num_pixels_.Length - 3).Take(num_pixels_.Length - (num_pixels_.Length - 3)).ToArray();
970-
Tensor num_pixels = math_ops.reduce_prod(new Tensor(num_pixels_));
969+
var x = image.shape["-3:"];
970+
var num_pixels = math_ops.reduce_prod(x);
971+
971972
Tensor image_mean = math_ops.reduce_mean(image, axis: new(-1, -2, -3), keepdims: true);
972973

973974
var stddev = math_ops.reduce_std(image, axis: new(-1, -2, -3), keepdims: true);

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs

+32
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,37 @@ public partial class Preprocessing
88
{
99
public static string[] WHITELIST_FORMATS = new[] { ".bmp", ".gif", ".jpeg", ".jpg", ".png" };
1010

11+
/// <summary>
12+
/// Function that calculates the classification statistics for a given array of classified data.
13+
/// The function takes an array of classified data as input and returns a dictionary containing the count and percentage of each class in the input array.
14+
/// This function can be used to analyze the distribution of classes in a dataset or to evaluate the performance of a classification model.
15+
/// </summary>
16+
/// <remarks>
17+
/// code from copilot
18+
/// </remarks>
19+
/// <param name="label_ids"></param>
20+
/// <param name="label_class_names"></param>
21+
Dictionary<string, double> get_classification_statistics(int[] label_ids, string[] label_class_names)
22+
{
23+
var countDict = label_ids.GroupBy(x => x)
24+
.ToDictionary(g => g.Key, g => g.Count());
25+
var totalCount = label_ids.Length;
26+
var ratioDict = label_class_names.ToDictionary(name => name,
27+
name =>
28+
(double)(countDict.ContainsKey(Array.IndexOf(label_class_names, name))
29+
? countDict[Array.IndexOf(label_class_names, name)] : 0)
30+
/ totalCount);
31+
32+
print("Classification statistics:");
33+
foreach (string labelName in label_class_names)
34+
{
35+
double ratio = ratioDict[labelName];
36+
print($"{labelName}: {ratio * 100:F2}%");
37+
}
38+
39+
return ratioDict;
40+
}
41+
1142
/// <summary>
1243
/// Generates a `tf.data.Dataset` from image files in a directory.
1344
/// https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory
@@ -53,6 +84,7 @@ public IDatasetV2 image_dataset_from_directory(string directory,
5384
follow_links: follow_links);
5485

5586
(image_paths, label_list) = keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset);
87+
get_classification_statistics(label_list, class_name_list);
5688

5789
var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation);
5890
if (shuffle)

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public partial class Preprocessing
99

1010
/// <summary>
1111
/// 图片路径转为数据处理用的dataset
12+
/// 通常用于预测时读取图片
1213
/// </summary>
1314
/// <param name="image_paths"></param>
1415
/// <param name="image_size"></param>

0 commit comments

Comments
 (0)