Skip to content

Commit 0cc25fb

Browse files
committed
Add a function(get_classification_statistics) to count the number of label categories for the image_dataset_from_directory method.
1 parent 7cd8292 commit 0cc25fb

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs

+32
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,37 @@ public partial class Preprocessing
88
{
99
public static string[] WHITELIST_FORMATS = new[] { ".bmp", ".gif", ".jpeg", ".jpg", ".png" };
1010

11+
/// <summary>
12+
/// Function that calculates the classification statistics for a given array of classified data.
13+
/// The function takes an array of classified data as input and returns a dictionary containing the count and percentage of each class in the input array.
14+
/// This function can be used to analyze the distribution of classes in a dataset or to evaluate the performance of a classification model.
15+
/// </summary>
16+
/// <remarks>
17+
/// code from copilot
18+
/// </remarks>
19+
/// <param name="label_ids"></param>
20+
/// <param name="label_class_names"></param>
21+
Dictionary<string, double> get_classification_statistics(int[] label_ids, string[] label_class_names)
22+
{
23+
var countDict = label_ids.GroupBy(x => x)
24+
.ToDictionary(g => g.Key, g => g.Count());
25+
var totalCount = label_ids.Length;
26+
var ratioDict = label_class_names.ToDictionary(name => name,
27+
name =>
28+
(double)(countDict.ContainsKey(Array.IndexOf(label_class_names, name))
29+
? countDict[Array.IndexOf(label_class_names, name)] : 0)
30+
/ totalCount);
31+
32+
print("Classification statistics:");
33+
foreach (string labelName in label_class_names)
34+
{
35+
double ratio = ratioDict[labelName];
36+
print($"{labelName}: {ratio * 100:F2}%");
37+
}
38+
39+
return ratioDict;
40+
}
41+
1142
/// <summary>
1243
/// Generates a `tf.data.Dataset` from image files in a directory.
1344
/// https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory
@@ -53,6 +84,7 @@ public IDatasetV2 image_dataset_from_directory(string directory,
5384
follow_links: follow_links);
5485

5586
(image_paths, label_list) = keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset);
87+
get_classification_statistics(label_list, class_name_list);
5688

5789
var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation);
5890
if (shuffle)

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public partial class Preprocessing
99

1010
/// <summary>
1111
/// 图片路径转为数据处理用的dataset
12+
/// 通常用于预测时读取图片
1213
/// </summary>
1314
/// <param name="image_paths"></param>
1415
/// <param name="image_size"></param>

0 commit comments

Comments
 (0)