Commit 8649d76

fix: segfault when logits_all=False. Closes #1319
1 parent f96de6d commit 8649d76

File tree

1 file changed: +10 -8 lines changed

llama_cpp/llama.py

+10 -8
```diff
@@ -535,14 +535,16 @@ def eval(self, tokens: Sequence[int]):
         # Save tokens
         self.input_ids[n_past : n_past + n_tokens] = batch
         # Save logits
-        rows = n_tokens
-        cols = self._n_vocab
-        offset = (
-            0 if self.context_params.logits_all else n_tokens - 1
-        )  # NOTE: Only save the last token logits if logits_all is False
-        self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
-            :
-        ] = self._ctx.get_logits()[offset * cols : rows * cols]
+        if self.context_params.logits_all:
+            rows = n_tokens
+            cols = self._n_vocab
+            logits = self._ctx.get_logits()[: rows * cols]
+            self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
+        else:
+            rows = 1
+            cols = self._n_vocab
+            logits = self._ctx.get_logits()[: rows * cols]
+            self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
         # Update n_tokens
         self.n_tokens += n_tokens

```
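For context, the removed slice `self._ctx.get_logits()[offset * cols : rows * cols]` indexed from `offset * cols`, which lies past the single row of logits the backend exposes when `logits_all=False`; that out-of-bounds read is presumably what triggered the segfault. Below is a minimal, standalone sketch of the patched branch structure using NumPy. The names `scores`, `n_vocab`, and the `fake_get_logits` helper are illustrative stand-ins, not the library's API, and the buffer sizes are toy values.

```python
import numpy as np

# Toy sizes standing in for the real context length and vocabulary size.
n_ctx, n_vocab = 8, 5
scores = np.zeros((n_ctx, n_vocab), dtype=np.single)


def fake_get_logits(n_tokens: int, logits_all: bool) -> np.ndarray:
    # Stand-in for the real logits buffer: when logits_all is False, only the
    # last token's logits are available, i.e. n_vocab floats rather than
    # n_tokens * n_vocab.
    rows = n_tokens if logits_all else 1
    return np.random.rand(rows * n_vocab).astype(np.single)


def save_logits(n_past: int, n_tokens: int, logits_all: bool) -> None:
    # Mirrors the branch structure introduced by the patch above.
    if logits_all:
        rows, cols = n_tokens, n_vocab
        logits = fake_get_logits(n_tokens, logits_all)[: rows * cols]
        scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
    else:
        rows, cols = 1, n_vocab
        logits = fake_get_logits(n_tokens, logits_all)[: rows * cols]
        scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits


save_logits(n_past=0, n_tokens=4, logits_all=False)  # fills only row 3
save_logits(n_past=4, n_tokens=4, logits_all=True)   # fills rows 4..7
```

In this sketch the `logits_all=False` branch reads and writes exactly one row, so nothing is sliced beyond the single row of logits the backend provides.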
