File tree 1 file changed +10
-8
lines changed
1 file changed +10
-8
lines changed Original file line number Diff line number Diff line change @@ -535,14 +535,16 @@ def eval(self, tokens: Sequence[int]):
535
535
# Save tokens
536
536
self .input_ids [n_past : n_past + n_tokens ] = batch
537
537
# Save logits
538
- rows = n_tokens
539
- cols = self ._n_vocab
540
- offset = (
541
- 0 if self .context_params .logits_all else n_tokens - 1
542
- ) # NOTE: Only save the last token logits if logits_all is False
543
- self .scores [n_past + offset : n_past + n_tokens , :].reshape (- 1 )[
544
- :
545
- ] = self ._ctx .get_logits ()[offset * cols : rows * cols ]
538
+ if self .context_params .logits_all :
539
+ rows = n_tokens
540
+ cols = self ._n_vocab
541
+ logits = self ._ctx .get_logits ()[: rows * cols ]
542
+ self .scores [n_past : n_past + n_tokens , :].reshape (- 1 )[: :] = logits
543
+ else :
544
+ rows = 1
545
+ cols = self ._n_vocab
546
+ logits = self ._ctx .get_logits ()[: rows * cols ]
547
+ self .scores [n_past + n_tokens - 1 , :].reshape (- 1 )[: :] = logits
546
548
# Update n_tokens
547
549
self .n_tokens += n_tokens
548
550
You can’t perform that action at this time.
0 commit comments