Skip to content

Commit 33333df

Browse files
authored
Merge pull request #1018 from Wanglongzhi2001/master
Fix the bug of non-convergence when use SparseCategoricalCrossentropy
2 parents ef687ae + 5506f00 commit 33333df

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,21 @@ namespace Tensorflow.Keras.Losses
44
{
55
public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc
66
{
7+
private bool _from_logits = false;
78
public SparseCategoricalCrossentropy(
89
bool from_logits = false,
910
string reduction = null,
1011
string name = null) :
11-
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ }
12+
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name)
13+
{
14+
_from_logits = from_logits;
15+
}
1216

1317
public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
1418
{
1519
target = tf.cast(target, dtype: TF_DataType.TF_INT64);
1620

17-
if (!from_logits)
21+
if (!_from_logits)
1822
{
1923
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype);
2024
output = tf.clip_by_value(output, epsilon, 1 - epsilon);

0 commit comments

Comments
 (0)