
[bug] use_sliding_window doesn't work as expected #38002


ZhiyuLi-Nvidia opened this issue May 7, 2025 · 2 comments · May be fixed by #38045
ZhiyuLi-Nvidia commented May 7, 2025

System Info

  • transformers: main
  • pytorch, cuda: any version

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModelForCausalLM
import numpy as np
import torch


MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

input_ids = torch.randint(0, 100, (1, 8192))  # sequence longer than the 4096-token sliding window

with torch.no_grad():
    output_raw = model(input_ids)

# correct situation: should match the original model, since use_sliding_window is False
model_no_sliding = AutoModelForCausalLM.from_pretrained(MODEL_NAME, sliding_window=None)

with torch.no_grad():
    output_non_sliding = model_no_sliding(input_ids)


np.testing.assert_allclose(output_raw.logits[:, :4096], output_non_sliding.logits[:, :4096])

# bug: the logits beyond position 4096 unexpectedly differ, as if sliding_window=4096 were active
np.testing.assert_allclose(output_raw.logits[:, 4096:], output_non_sliding.logits[:, 4096:])

Expected behavior


What is expected:

  "sliding_window": 4096,
  "use_sliding_window": false,

use_sliding_window is set to false in the deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B config, so we expect the sliding window to be disabled. In other words, the results should be identical regardless of the sliding_window value.

However, the results are different in the repro script.
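
As a quick sanity check (a hypothetical snippet; the expected values come from the checkpoint config quoted above), the two flags can be inspected directly:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
# per the config above, this should print: 4096 False
print(config.sliding_window, config.use_sliding_window)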

Root cause

The attention mask is modified according to sliding_window without checking use_sliding_window:

if config.get_text_config().sliding_window is not None:
    # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
    # the check is needed to verify is current checkpoint was trained with sliding window or not
    if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
        sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
            cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
        )
        diagonal_attend_mask.bitwise_or_(sliding_attend_mask)

If we add some printing under this conditional block, we can clearly see that the attention mask is modified even with use_sliding_window=false.
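
For example, a print like the following inside that block (illustrative only, reusing the names from the snippet above) fires for the 8192-token input even though the checkpoint sets use_sliding_window to false:

# illustrative debug print, not part of the library code
print(
    "sliding mask applied; use_sliding_window =",
    config.get_text_config().use_sliding_window,
    "| positions newly masked:",
    sliding_attend_mask.sum().item(),
)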

@Rocketknight1
Member

Not sure who's the right code owner here - cc @gante @zucchini-nlp, but feel free to tag Arthur instead if you think he's more appropriate!

@zucchini-nlp
Member

Nice catch! Indeed, I think we should be checking both values, given that in the attention layer we pass the sliding window only when use_sliding_window is set. I can open a PR for that 🤗
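
For illustration, a minimal sketch of what checking both values could look like (an assumption on my part; the actual change in #38045 may differ):

# sketch only: additionally require use_sliding_window before masking beyond the window;
# getattr() guards configs that do not define the flag at all
text_config = config.get_text_config()
if text_config.sliding_window is not None and getattr(text_config, "use_sliding_window", True):
    if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
        sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
            cache_position.reshape(-1, 1) - text_config.sliding_window
        )
        diagonal_attend_mask.bitwise_or_(sliding_attend_mask)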

@zucchini-nlp linked a pull request on May 9, 2025 that will close this issue.