Skip to content

Commit 0b5ab78

Browse files
committed
Accept data as a generator
1 parent 0fe7e24 commit 0b5ab78

File tree

1 file changed

+64
-3
lines changed

1 file changed

+64
-3
lines changed

pygad/kerasga/kerasga.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@
33
import tensorflow.keras
44

55
def model_weights_as_vector(model):
6+
"""
7+
Reshapes the Keras model weight as a vector.
8+
9+
Parameters
10+
----------
11+
model : TYPE
12+
The Keras model.
13+
14+
Returns
15+
-------
16+
TYPE
17+
The weights as a 1D vector.
18+
19+
"""
620
weights_vector = []
721

822
for layer in model.layers: # model.get_weights():
@@ -15,6 +29,22 @@ def model_weights_as_vector(model):
1529
return numpy.array(weights_vector)
1630

1731
def model_weights_as_matrix(model, weights_vector):
32+
"""
33+
Reshapes the PyGAD 1D solution as a Keras weight matrix.
34+
35+
Parameters
36+
----------
37+
model : TYPE
38+
The Keras model.
39+
weights_vector : TYPE
40+
The PyGAD solution as a 1D vector.
41+
42+
Returns
43+
-------
44+
weights_matrix : TYPE
45+
The Keras weights as a matrix.
46+
47+
"""
1848
weights_matrix = []
1949

2050
start = 0
@@ -37,14 +67,45 @@ def model_weights_as_matrix(model, weights_vector):
3767

3868
return weights_matrix
3969

40-
def predict(model, solution, data, verbose=1):
70+
def predict(model,
71+
solution,
72+
data,
73+
batch_size=None,
74+
verbose=0,
75+
steps=None):
76+
"""
77+
Use the PyGAD's solution to make predictions using the Keras model.
78+
79+
Parameters
80+
----------
81+
model : TYPE
82+
The Keras model.
83+
solution : TYPE
84+
A single PyGAD solution as 1D vector.
85+
data : TYPE
86+
The data or a generator.
87+
batch_size : TYPE, optional
88+
The batch size (i.e. number of samples per step or batch). The default is None. Check documentation of the Keras Model.predict() method for more information.
89+
verbose : TYPE, optional
90+
Verbosity mode. The default is 0. Check documentation of the Keras Model.predict() method for more information.
91+
steps : TYPE, optional
92+
The total number of steps (batches of samples). The default is None. Check documentation of the Keras Model.predict() method for more information.
93+
94+
Returns
95+
-------
96+
predictions : TYPE
97+
The Keras model predictions.
98+
99+
"""
41100
# Fetch the parameters of the best solution.
42101
solution_weights = model_weights_as_matrix(model=model,
43102
weights_vector=solution)
44103
_model = tensorflow.keras.models.clone_model(model)
45104
_model.set_weights(solution_weights)
46-
predictions = _model(data)
47-
predictions = predictions.numpy()
105+
predictions = _model.predict(x=data,
106+
batch_size=batch_size,
107+
verbose=verbose,
108+
steps=steps)
48109

49110
return predictions
50111

0 commit comments

Comments
 (0)