Feat: support DTensor when saving #3042

Open · S1ro1 wants to merge 2 commits into main from transformers-save-dtensor
Conversation

@S1ro1 (Member) commented on May 1, 2025

This enables transformers to use save_pretrained when the model is sharded with DTensor. It shouldn't break anything, as this path simply failed before.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@S1ro1 S1ro1 changed the title from "tmp: return tensor.nbytes for get_torch_storage_size" to "Feat: support DTensor when saving" May 2, 2025
@S1ro1 S1ro1 marked this pull request as ready for review May 2, 2025 12:05
@hanouticelina (Contributor) left a comment:

Thanks @S1ro1 for the PR! I left a comment about the storage size computation of a DTensor.

Comment on lines +765 to +772
try:
from torch.distributed.tensor import DTensor

if isinstance(tensor, DTensor):
# this returns the size of the FULL tensor in bytes
return tensor.nbytes
except ImportError:
pass
@hanouticelina (Contributor):

I'm not familiar with DTensor, but if the tensor is indeed a DTensor and the import fails at line 766, would it be okay to fall back to tensor.untyped_storage().nbytes()?
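For context, here is a minimal sketch of how the DTensor branch from this PR sits next to the fallback discussed above; this is illustrative only, not the actual huggingface_hub implementation, which handles more cases:

import torch

def get_torch_storage_size(tensor: torch.Tensor) -> int:
    try:
        from torch.distributed.tensor import DTensor

        if isinstance(tensor, DTensor):
            # Per this PR: size of the FULL (global) tensor in bytes, not just the local shard.
            return tensor.nbytes
    except ImportError:
        # torch is too old to provide DTensor, so `tensor` cannot be one.
        pass

    # Regular tensors: size of the underlying storage in bytes.
    return tensor.untyped_storage().nbytes()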

@S1ro1 (Member Author):

I'm not 100% sure and will have to test locally, but I'm pretty sure that would fail with "has no method untyped_storage". In any case, this import shouldn't ever fail if the tensor is a DTensor. It's wrapped in try/except to avoid version checking, as DTensor requires torch >= 2.1 (ish).
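For comparison, the explicit version check that the try/except avoids could look roughly like this (hypothetical sketch, assuming the packaging library; not part of the PR):

import torch
from packaging import version

# Only attempt the import on torch versions expected to ship DTensor
# (>= 2.1-ish, per the discussion above).
if version.parse(torch.__version__).release >= (2, 1):
    from torch.distributed.tensor import DTensor
else:
    DTensor = None

Catching ImportError sidesteps the version parsing entirely, which is why the PR keeps the try/except.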

@hanouticelina (Contributor):

> DTensor is torch >= 2.1 (ish)

okay then all good!

@hanouticelina (Contributor) left a comment:

Looks good! @S1ro1 could you add a simple test for get_torch_storage_size with DTensor in https://github.com/huggingface/huggingface_hub/blob/main/tests/test_serialization.py if possible?

this one should be enough:

@requires("torch")
def test_get_torch_storage_size_dtensor():
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Replicate

    if dist.is_available() and not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            store=dist.HashStore(),
            rank=0,
            world_size=1,
        )

    mesh = init_device_mesh("cpu", (1,))
    local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
    dt = DTensor.from_local(local, mesh, [Replicate()])

    assert get_torch_storage_size(dt) == 5 * 2

(written with the help of pytorch documentation and Claude)

@S1ro1 (Member Author) commented on May 6, 2025

Yes, I'll add something similar, sure. I suppose testing multi-process (Shard()) would also be nice; I'll see if I can throw something together with subprocess and torchrun.
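A rough sketch of what such a multi-process test could look like (hypothetical, not part of the PR; it assumes get_torch_storage_size is importable from huggingface_hub and reports the full-tensor size as described above). The worker source has to be written out and launched through torch.distributed.run, which is exactly the extra machinery being weighed here:

import subprocess
import sys
import tempfile
import textwrap

# Hypothetical worker script: each rank holds a 4-element fp16 shard of the global tensor.
WORKER_SRC = textwrap.dedent(
    """
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard

    from huggingface_hub import get_torch_storage_size

    dist.init_process_group(backend="gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    local = torch.ones(4, dtype=torch.float16)
    dt = DTensor.from_local(local, mesh, [Shard(0)])
    # Full tensor: world_size * 4 fp16 elements, 2 bytes each.
    assert get_torch_storage_size(dt) == dist.get_world_size() * 4 * 2
    dist.destroy_process_group()
    """
)

def run_sharded_storage_size_check(nproc: int = 2) -> None:
    # Write the worker to a temp file and launch it with torchrun (torch.distributed.run).
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(WORKER_SRC)
        script = f.name
    subprocess.run(
        [sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc}", script],
        check=True,
    )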

@S1ro1 S1ro1 force-pushed the transformers-save-dtensor branch from 2768611 to 94abfb4 Compare May 9, 2025 14:24
@S1ro1 (Member Author) commented on May 9, 2025

@hanouticelina I've added the test as suggested. It's probably not worth adding more complex test cases, as those would require torch.distributed.run, which would mean writing the test source as a string and running it with subprocess; IMO that's not worth it.
LMK your thoughts; apart from this, it should be good to merge now.

Test failures seem to be unrelated.

@S1ro1 S1ro1 force-pushed the transformers-save-dtensor branch from a9589e1 to f2f660c Compare May 9, 2025 14:34