diff --git a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs index b72412265..4e2790ab1 100644 --- a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs @@ -4,17 +4,21 @@ namespace Tensorflow.Keras.Losses { public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc { + private bool _from_logits = false; public SparseCategoricalCrossentropy( bool from_logits = false, string reduction = null, string name = null) : - base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ } + base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) + { + _from_logits = from_logits; + } public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) { target = tf.cast(target, dtype: TF_DataType.TF_INT64); - if (!from_logits) + if (!_from_logits) { var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); output = tf.clip_by_value(output, epsilon, 1 - epsilon);