Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c825d12

Browse files
Seppo Enarvi and afrozenator
Seppo Enarvi authored and afrozenator committed
Fix decoding in prepend mode (#1726)
* Create an integer problem_0_steps variable.
* Save inputs to the feature "partial_targets" when prepend_mode is not "none".
* Removed a second call to update_hparams_for_universal_transformer(). Fixes the hyperparameter sets universal_transformer_big and universal_transformer_base_tpu.
* Fix a bug to make partial targets work for beam size > 1. The dimension of the multiplication of the partial targets was wrong: (a, b, c, d) --> (a, b, c, d, a, b, c, d). The correct multiplication is (a, b, c, d) --> (a, a, b, b, c, c, d, d), because the leading dimension is (batch_size * beam_size), not (beam_size * batch_size). Essentially, tf.tile needs to be replaced by tf.repeat, which is introduced in TF 1.15; this change is a workaround for TF 1.14.
1 parent 67ddb40 commit c825d12

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

tensor2tensor/models/research/universal_transformer.py

-2
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,6 @@ def universal_transformer_base():
458458
@registry.register_hparams
459459
def universal_transformer_base_tpu():
460460
hparams = universal_transformer_base()
461-
hparams = update_hparams_for_universal_transformer(hparams)
462461
transformer.update_hparams_for_tpu(hparams)
463462
hparams.add_step_timing_signal = False
464463
return hparams
@@ -467,7 +466,6 @@ def universal_transformer_base_tpu():
467466
@registry.register_hparams
468467
def universal_transformer_big():
469468
hparams = universal_transformer_base()
470-
hparams = update_hparams_for_universal_transformer(hparams)
471469
hparams.hidden_size = 2048
472470
hparams.filter_size = 8192
473471
return hparams

tensor2tensor/models/transformer.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -863,9 +863,15 @@ def symbols_to_logits_fn(ids, i, cache):
863863
vocab_size = tf.shape(ret)[1]
864864

865865
def forced_logits():
866+
# Workaround for: tf.one_hot(
867+
# tf.repeat(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
868+
# -1e9)
869+
# Can be replaced by the above in future versions (from tf 1.15).
866870
return tf.one_hot(
867-
tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
868-
-1e9)
871+
tf.reshape(tf.tile(
872+
tf.reshape(partial_targets[:, i], [-1, 1]),
873+
[1, beam_size]), [-1]),
874+
vocab_size, 0.0, -1e9)
869875

870876
ret = tf.cond(
871877
tf.less(i, partial_targets_length), forced_logits, lambda: ret)
@@ -1168,9 +1174,6 @@ def fast_decode(encoder_output,
11681174
"scores": decoding log probs from the beam search,
11691175
None if using greedy decoding (beam_size=1)
11701176
}
1171-
1172-
Raises:
1173-
NotImplementedError: If beam size > 1 with partial targets.
11741177
"""
11751178
if encoder_output is not None:
11761179
batch_size = common_layers.shape_list(encoder_output)[0]

tensor2tensor/utils/decoding.py

+14
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,13 @@ def _interactive_input_tensor_to_features_dict(feature_map, hparams):
927927
features["decode_length"] = (
928928
IMAGE_DECODE_LENGTH if input_is_image else inputs[1])
929929
features["inputs"] = x
930+
# Save inputs to "partial_targets" when prepending inputs to targets. Also
931+
# keep "inputs" as some models crash if they don't exist.
932+
if getattr(hparams, "prepend_mode", "none") != "none":
933+
shape = tf.shape(x)
934+
partial_targets = tf.reshape(x, [shape[0], shape[1]])
935+
partial_targets = tf.pad(partial_targets, [[0, 0], [0, 1]])
936+
features["partial_targets"] = partial_targets
930937
return features
931938

932939

@@ -957,6 +964,13 @@ def _decode_input_tensor_to_features_dict(feature_map, hparams):
957964
features["decode_length"] = (
958965
IMAGE_DECODE_LENGTH if input_is_image else tf.shape(x)[1] + 50)
959966
features["inputs"] = x
967+
# Save inputs to "partial_targets" when prepending inputs to targets. Also
968+
# keep "inputs" as some models crash if they don't exist.
969+
if getattr(hparams, "prepend_mode", "none") != "none":
970+
shape = tf.shape(x)
971+
partial_targets = tf.reshape(x, [shape[0], shape[1]])
972+
partial_targets = tf.pad(partial_targets, [[0, 0], [0, 1]])
973+
features["partial_targets"] = partial_targets
960974
return features
961975

962976

0 commit comments

Comments (0)