Feat: support DTensor when saving #3042
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks @S1ro1 for the PR! I left a comment about the storage size computation of a `DTensor`.
```python
try:
    from torch.distributed.tensor import DTensor

    if isinstance(tensor, DTensor):
        # this returns the size of the FULL tensor in bytes
        return tensor.nbytes
except ImportError:
    pass
```
I'm not familiar with `DTensor`, but if the tensor is indeed a `DTensor` and the import at line 766 fails, would it be okay to fall back to `tensor.untyped_storage().nbytes()`?
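For context, here is a minimal sketch (hypothetical, not the actual huggingface_hub code; the function name is made up) of the control flow under discussion, with the proposed fallback as the final return:

```python
import torch


def storage_size_sketch(tensor: torch.Tensor) -> int:
    """Hypothetical helper: DTensor-aware size with a plain-tensor fallback."""
    try:
        from torch.distributed.tensor import DTensor

        if isinstance(tensor, DTensor):
            # for a DTensor, .nbytes reports the size of the FULL tensor
            return tensor.nbytes
    except ImportError:
        # torch is too old to provide DTensor, so `tensor` cannot be one
        pass
    # regular torch.Tensor: size of the underlying storage in bytes
    return tensor.untyped_storage().nbytes()
```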
I'm not 100% sure, will have to test locally, but I'm pretty sure that would fail with a `has no method untyped_storage` error. But this import shouldn't ever fail if the tensor is a `DTensor`. It's wrapped in try/except to avoid version checking, as `DTensor` is torch >= 2.1 (ish).
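For contrast, a hypothetical explicit version gate, i.e. the kind of check the try/except lets the code avoid (`packaging` used here purely for illustration):

```python
import torch
from packaging import version

# hypothetical alternative: gate the import on the torch version instead of
# catching ImportError; the PR deliberately avoids pinning an exact version
if version.parse(torch.__version__) >= version.parse("2.1"):
    from torch.distributed.tensor import DTensor
```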
> DTensor is torch >= 2.1 (ish)

Okay, then all good!
Looks good! @S1ro1, could you add a simple test for `get_torch_storage_size` with `DTensor` in https://github.com/huggingface/huggingface_hub/blob/main/tests/test_serialization.py if possible? This one should be enough:
@requires("torch")
def test_get_torch_storage_size_dtensor():
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate
if dist.is_available() and not dist.is_initialized():
dist.init_process_group(
backend="gloo",
store=dist.HashStore(),
rank=0,
world_size=1,
)
mesh = init_device_mesh("cpu", (1,))
local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
dt = DTensor.from_local(local, mesh, [Replicate()])
assert get_torch_storage_size(dt) == 5 * 2
(written with the help of the PyTorch documentation and Claude)
Yes, will add something similar, sure. I suppose testing multi-process (…)
(force-pushed from 2768611 to 94abfb4)
@hanouticelina I've added the test as suggested; it's probably not worth adding more complex test cases, as those require … Test failures seem to be unrelated.
(force-pushed from a9589e1 to f2f660c)
This enables transformers to use `save_pretrained` when the model was sharded with `DTensor`. It shouldn't break anything, as this path simply failed before.