Skip to content

Commit dfd9dd0

Browse files
authored
Merge pull request #1126 from lingbai-kong/parse_imdb
Add pad preprocessing for `imdb` dataset
2 parents 3acfc1d + 4efa0a8 commit dfd9dd0

File tree

1 file changed

+22
-2
lines changed
  • src/TensorFlowNET.Keras/Datasets

1 file changed

+22
-2
lines changed

src/TensorFlowNET.Keras/Datasets/Imdb.cs

+22-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ public DatasetPass load_data(string path = "imdb.npz",
4040
int oov_char= 2,
4141
int index_from = 3)
4242
{
43+
if (maxlen == -1) throw new InvalidArgumentError("maxlen must be assigned.");
44+
4345
var dst = Download();
4446

4547
var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
@@ -51,7 +53,7 @@ public DatasetPass load_data(string path = "imdb.npz",
5153
x_train_string[i] = lines[i].Substring(2);
5254
}
5355

54-
var x_train = np.array(x_train_string);
56+
var x_train = keras.preprocessing.sequence.pad_sequences(PraseData(x_train_string), maxlen: maxlen);
5557

5658
File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
5759
var x_test_string = new string[lines.Length];
@@ -62,7 +64,7 @@ public DatasetPass load_data(string path = "imdb.npz",
6264
x_test_string[i] = lines[i].Substring(2);
6365
}
6466

65-
var x_test = np.array(x_test_string);
67+
var x_test = keras.preprocessing.sequence.pad_sequences(PraseData(x_test_string), maxlen: maxlen);
6668

6769
return new DatasetPass
6870
{
@@ -93,5 +95,23 @@ string Download()
9395
return dst;
9496
// return Path.Combine(dst, file_name);
9597
}
98+
99+
/// <summary>
/// Converts textual integer lists (for example <c>"[1, 14, 22]"</c>) into integer arrays.
/// Square brackets and spaces are stripped from each string and the remainder is split on commas.
/// </summary>
/// <param name="x">One bracketed, comma-separated list of integers per element.</param>
/// <returns>
/// One <c>int[]</c> per input string, in input order. A string that contains no numbers
/// (such as <c>"[]"</c> or <c>""</c>) yields an empty array.
/// </returns>
/// <exception cref="System.ArgumentNullException"><paramref name="x"/> is null.</exception>
/// <exception cref="System.FormatException">A token is not a valid integer.</exception>
/// <exception cref="System.OverflowException">A token does not fit in an <c>int</c>.</exception>
protected IEnumerable<int[]> PraseData(string[] x)
{
    if (x == null)
    {
        throw new System.ArgumentNullException(nameof(x));
    }

    // The number of sequences is known up front, so size the list once.
    var data_list = new List<int[]>(x.Length);

    foreach (var list_string in x)
    {
        var cleaned_list_string = list_string.Replace("[", "").Replace("]", "").Replace(" ", "");

        // RemoveEmptyEntries makes an empty list ("[]") produce zero tokens instead of a
        // single empty token that int.Parse would reject with a FormatException.
        string[] number_strings = cleaned_list_string.Split(
            new[] { ',' },
            System.StringSplitOptions.RemoveEmptyEntries);

        int[] numbers = new int[number_strings.Length];
        for (int j = 0; j < number_strings.Length; j++)
        {
            // The input is machine-written data, so parse it independently of the
            // current thread culture.
            numbers[j] = int.Parse(
                number_strings[j],
                System.Globalization.NumberStyles.Integer,
                System.Globalization.CultureInfo.InvariantCulture);
        }

        data_list.Add(numbers);
    }

    return data_list;
}
96116
}
97117
}

0 commit comments

Comments
 (0)