
Commit f2f660c

Feat: tests
1 parent 2525677 commit f2f660c

File tree

1 file changed: +37 -0 lines changed


tests/test_serialization.py

@@ -46,6 +46,16 @@ def is_wrapper_tensor_subclass_available():
         return False
 
 
+def is_dtensor_available():
+    try:
+        from torch.distributed.device_mesh import init_device_mesh  # type: ignore[import] # noqa: F401
+        from torch.distributed.tensor import DTensor  # type: ignore[import] # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
 @pytest.fixture
 def dummy_state_dict() -> Dict[str, List[int]]:
     return {
@@ -250,6 +260,33 @@ def test_get_torch_storage_size():
     assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
 
 
+@requires("torch")
+@pytest.mark.skipif(not is_dtensor_available(), reason="requires torch with dtensor available")
+def test_get_torch_storage_size_dtensor():
+    # Testing truly distributed sharded tensors isn't easy (it would need a subprocess call to torchrun), so a single-rank replicated DTensor should be good enough.
+    import torch
+    import torch.distributed as dist
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor import DTensor, Replicate
+
+    if dist.is_available() and not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo",
+            store=dist.HashStore(),
+            rank=0,
+            world_size=1,
+        )
+
+    mesh = init_device_mesh("cpu", (1,))
+    local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
+    dt = DTensor.from_local(local, mesh, [Replicate()])
+
+    assert get_torch_storage_size(dt) == 5 * 2
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
 @requires("torch")
 @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
 def test_get_torch_storage_size_wrapper_tensor_subclass():
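
As the inline comment notes, exercising a truly sharded DTensor would require launching the check under torchrun rather than inside pytest. Below is a minimal sketch of what such a standalone check could look like; the script name, the huggingface_hub.serialization import path, and the assumption that get_torch_storage_size reports only the local shard's bytes are illustrative and not part of this commit.

# Hypothetical standalone script, e.g. dtensor_storage_size_check.py, launched with:
#   torchrun --nproc_per_node=2 dtensor_storage_size_check.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

from huggingface_hub.serialization import get_torch_storage_size  # assumed import path


def main():
    # torchrun sets RANK / WORLD_SIZE / MASTER_ADDR, so no explicit store is needed here.
    dist.init_process_group(backend="gloo")
    world_size = dist.get_world_size()

    mesh = init_device_mesh("cpu", (world_size,))
    local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)

    # Each rank contributes its own local shard; the global tensor is sharded along dim 0.
    dt = DTensor.from_local(local, mesh, [Shard(0)])

    # Assumption: get_torch_storage_size() reports the bytes of the local shard only.
    print(f"rank {dist.get_rank()}: {get_torch_storage_size(dt)} bytes")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()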
