@@ -46,6 +46,16 @@ def is_wrapper_tensor_subclass_available():
         return False
 
 
+def is_dtensor_available():
+    try:
+        from torch.distributed.device_mesh import init_device_mesh  # type: ignore[import] # noqa: F401
+        from torch.distributed.tensor import DTensor  # type: ignore[import] # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
 @pytest.fixture
 def dummy_state_dict() -> Dict[str, List[int]]:
     return {
@@ -250,6 +260,33 @@ def test_get_torch_storage_size():
     assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
 
 
+@requires("torch")
+@pytest.mark.skipif(not is_dtensor_available(), reason="requires torch with dtensor available")
+def test_get_torch_storage_size_dtensor():
+    # Testing a truly sharded tensor would require a torchrun subprocess, so a single-rank replicated DTensor should be good enough.
+    import torch
+    import torch.distributed as dist
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor import DTensor, Replicate
+
+    if dist.is_available() and not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo",
+            store=dist.HashStore(),
+            rank=0,
+            world_size=1,
+        )
+
+    mesh = init_device_mesh("cpu", (1,))
+    local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
+    dt = DTensor.from_local(local, mesh, [Replicate()])
+
+    assert get_torch_storage_size(dt) == 5 * 2
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
 @requires("torch")
 @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
 def test_get_torch_storage_size_wrapper_tensor_subclass():
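The expected value in the new test follows from the placement: with a world size of 1 and Replicate(), the DTensor's local shard is the full 5-element float16 tensor, so its storage is 5 * 2 bytes. Below is a minimal sketch of that size computation, using only the public DTensor.to_local() and Tensor.untyped_storage() APIs; the helper name is hypothetical and this is not necessarily how get_torch_storage_size is implemented.

import torch


def local_storage_nbytes(t: torch.Tensor) -> int:
    # A DTensor wraps a local shard; unwrap it so the size reflects what this
    # process actually stores. Plain tensors pass through unchanged.
    try:
        from torch.distributed.tensor import DTensor

        if isinstance(t, DTensor):
            t = t.to_local()
    except ImportError:
        pass
    return t.untyped_storage().nbytes()


# For the replicated DTensor in the test above, the local shard is the full
# tensor: 5 elements * 2 bytes (float16) == 10 bytes.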