diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
index 0ced407a8..318b8b142 100644
--- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
@@ -102,11 +102,12 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat
{
throw new ValueError("\'image\' must be fully defined.");
}
- for (int x = 1; x < 4; x++)
+ var dims = image_shape["-3:"];
+ foreach (var dim in dims.dims)
{
- if (image_shape.dims[x] == 0)
+ if (dim == 0)
{
- throw new ValueError(String.Format("inner 3 dims of \'image.shape\' must be > 0: {0}", image_shape));
+ throw new ValueError("inner 3 dimensions of \'image\' must be > 0: " + image_shape);
}
}
@@ -965,9 +966,9 @@ public static Tensor per_image_standardization(Tensor image)
if (Array.Exists(new[] { dtypes.float16, dtypes.float32 }, orig_dtype => orig_dtype == orig_dtype))
image = convert_image_dtype(image, dtypes.float32);
- var num_pixels_ = array_ops.shape(image).dims;
- num_pixels_ = num_pixels_.Skip(num_pixels_.Length - 3).Take(num_pixels_.Length - (num_pixels_.Length - 3)).ToArray();
- Tensor num_pixels = math_ops.reduce_prod(new Tensor(num_pixels_));
+ var x = image.shape["-3:"];
+ var num_pixels = math_ops.reduce_prod(x);
+
Tensor image_mean = math_ops.reduce_mean(image, axis: new(-1, -2, -3), keepdims: true);
var stddev = math_ops.reduce_std(image, axis: new(-1, -2, -3), keepdims: true);
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
index f42d12cde..377ac4de7 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
@@ -8,6 +8,37 @@ public partial class Preprocessing
{
public static string[] WHITELIST_FORMATS = new[] { ".bmp", ".gif", ".jpeg", ".jpg", ".png" };
+ ///
+ /// Function that calculates the classification statistics for a given array of classified data.
+ /// The function takes an array of classified data as input and returns a dictionary containing the count and percentage of each class in the input array.
+ /// This function can be used to analyze the distribution of classes in a dataset or to evaluate the performance of a classification model.
+ ///
+ ///
+ /// code from copilot
+ ///
+ ///
+ ///
+ Dictionary get_classification_statistics(int[] label_ids, string[] label_class_names)
+ {
+ var countDict = label_ids.GroupBy(x => x)
+ .ToDictionary(g => g.Key, g => g.Count());
+ var totalCount = label_ids.Length;
+ var ratioDict = label_class_names.ToDictionary(name => name,
+ name =>
+ (double)(countDict.ContainsKey(Array.IndexOf(label_class_names, name))
+ ? countDict[Array.IndexOf(label_class_names, name)] : 0)
+ / totalCount);
+
+ print("Classification statistics:");
+ foreach (string labelName in label_class_names)
+ {
+ double ratio = ratioDict[labelName];
+ print($"{labelName}: {ratio * 100:F2}%");
+ }
+
+ return ratioDict;
+ }
+
///
/// Generates a `tf.data.Dataset` from image files in a directory.
/// https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory
@@ -53,6 +84,7 @@ public IDatasetV2 image_dataset_from_directory(string directory,
follow_links: follow_links);
(image_paths, label_list) = keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset);
+ get_classification_statistics(label_list, class_name_list);
var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation);
if (shuffle)
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
index eaa762d89..232f81eb5 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
@@ -9,6 +9,7 @@ public partial class Preprocessing
///
/// 图片路径转为数据处理用的dataset
+ /// 通常用于预测时读取图片
///
///
///