Commit dc18cda (parent e6aeb91)

efficientvit (mit) msa attention q/k/v ops need to be in float32 to train w/o NaN

1 file changed: +7 −3 lines changed

timm/models/efficientvit_mit.py (+7 −3)
```diff
@@ -250,9 +250,13 @@ def forward(self, x):
         k = self.kernel_func(k)
         v = F.pad(v, (0, 1), mode="constant", value=1.)

-        kv = k.transpose(-1, -2) @ v
-        out = q @ kv
-        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        with torch.amp.autocast(device_type=v.device.type, enabled=False):
+            kv = k.transpose(-1, -2) @ v
+            out = q @ kv
+            out = out[..., :-1] / (out[..., -1:] + self.eps)
+        out = out.to(dtype)

         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
```
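The pattern in this patch can be illustrated as a standalone sketch: cast the attention operands to float32, disable autocast around the kv aggregation and normalization (where low-precision accumulation can overflow and produce NaN), then cast the result back to the original dtype. The function name `linear_attn_fp32`, the `eps` default, and the tensor shapes below are illustrative assumptions, not part of the timm API:

```python
import torch
import torch.nn.functional as F

def linear_attn_fp32(q, k, v, eps=1e-15):
    # Hypothetical sketch of the commit's fix: run the linear-attention
    # aggregation in float32 even under autocast, then restore the dtype.
    dtype = v.dtype
    q, k, v = q.float(), k.float(), v.float()
    with torch.amp.autocast(device_type=v.device.type, enabled=False):
        # Pad v with a ones channel so the denominator is computed jointly.
        v = F.pad(v, (0, 1), mode="constant", value=1.)
        kv = k.transpose(-1, -2) @ v   # (dim, dim + 1) aggregation
        out = q @ kv                   # (N, dim + 1)
        # Split the ones channel back out as the normalizer.
        out = out[..., :-1] / (out[..., -1:] + eps)
    return out.to(dtype)
```

Calling this with float16 inputs returns a float16 result whose intermediate matmuls and division were performed in float32, which is the property the commit relies on to avoid NaN during training.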
