@@ -273,13 +273,11 @@ def main(args):
    pp_rank = pp_mesh.get_local_rank()
    tp_group = tp_mesh.get_group()
    pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=}, {tp_group_size=}")
+    logger.info(f"{pp_degree=}, {tp_degree=}")

    # Convenience variables
    first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1

    # Assuming same number of GPUs per node
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -297,18 +295,22 @@ def main(args):
    if rank == 0:
        logger.info(f"Model: {model}")

-    mbs = 1  # number of micro-batches
-    mb_size = 4  # micro-batch size
-    batch_size = mbs * mb_size  # total batch size
-
+    # Batch size. Since we push batches dynamically through the pipeline rather
+    # than chunking them, this is effectively the micro-batch size in the
+    # pipeline sense, so it is interchangeable with the micro-batch size below.
+    batch_size = 4
    seqlen_prefill = 1024  # sequence length
    dim = 4096  # embedding dimension

    # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
+    # When decoding is done for certain micro-batches, we can reuse the KV cache
+    # lanes.
+    # TODO: bump up the lane count
+    pipeline_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

    # Load weights
    logger.info(f"Loading weights for {pp_rank=} on {device=}")
@@ -317,7 +319,7 @@ def main(args):
    model.to(device)

    logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
    )

    # info on stage size and params
@@ -330,17 +332,16 @@ def main(args):

    # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
    input_pos = torch.arange(seqlen_prefill, device=device)
-    model.setup_input_pos(input_pos)
    model.eval()

    # Helper function to get example inputs and outputs for the stages.
    def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
        activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
        )
        logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
        )
        example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
        example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        output_args=example_outputs,
        group=pp_group,
    )
-    # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+
+    # Create schedule
+    # The number of micro-batches for the schedule is 1, because each step()
+    # call only pushes 1 micro-batch into the pipeline. But we can continuously
+    # push new micro-batches into the pipeline as they arrive, achieving the
+    # same pipelining effect.
+    prefiller = ScheduleGPipe(prefill_stage, 1)

    prompt = [
        "What is a computer?",
@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    s = set(prompt_lengths)
    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

-    # with CUDATrackTime() as timer:
    # Need these global ids due to the API definition of dist.send and recv
    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
    last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    num_tokens = 40

    # Prefill phase
-    # Run context input through pipeline, in 1 step
-    with torch.no_grad():
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
+    lane = 0
+    kwargs = {"input_pos": input_pos, "cache_lane": lane}
+    with torch.no_grad(), CUDATrackTime() as timer:
        if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence, **kwargs)
        elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step(**kwargs)
        else:  # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step(**kwargs)
+
+    logger.info(
+        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )

    # Decode the output -- first generated token
    if pp_rank == last_pp_rank:
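Once the lane count is bumped up (see the TODO above), the single-micro-batch schedule can still yield a pipelining effect by issuing one `step()` call per incoming micro-batch, each on its own cache lane. A rough sketch of that possible follow-up, where `prompt_batches` is a hypothetical list of already-padded token tensors and the other names come from the surrounding script, not part of this change:

with torch.no_grad():
    for lane, tokens in enumerate(prompt_batches):  # prompt_batches is hypothetical
        kwargs = {"input_pos": input_pos, "cache_lane": lane}
        if pp_rank == first_pp_rank:
            prefiller.step(tokens, **kwargs)
        elif pp_rank == last_pp_rank:
            output = prefiller.step(**kwargs)  # logits for this micro-batch
        else:  # middle pp ranks
            prefiller.step(**kwargs)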
@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # seqlen = 1 now
    seqlen_decode = 1
    input_pos = torch.tensor([prompt_lengths[0]], device=device)
-    model.setup_input_pos(input_pos)

    # Create decode stage
    logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        group=pp_group,
    )
    # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decorder = ScheduleGPipe(decode_stage, 1)

    # Decoding
-    with torch.no_grad():
+    with torch.no_grad(), CUDATrackTime() as timer:
        for step in range(num_tokens - 1):
+            kwargs = {"input_pos": input_pos, "cache_lane": lane}
            # sendrecv between last and first ranks, only if:
            # first_pp_rank != last_pp_rank.
            if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

            # Run data through pipeline
            if pp_rank == first_pp_rank:
-                output = decode_schedule.step(new_token)
+                output = decorder.step(new_token, **kwargs)
            elif pp_rank == last_pp_rank:
-                output = decode_schedule.step()
+                output = decorder.step(**kwargs)
            else:  # middle pp ranks
-                decode_schedule.step()
+                decorder.step(**kwargs)

            # Decode the output
            if pp_rank == last_pp_rank:
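The global ids computed earlier (`first_pp_rank_global_id`, `last_pp_rank_global_id`) exist because `dist.send` and `dist.recv` take global ranks even when a `group` is passed. The handoff of the freshly generated token from the last stage back to the first stage, referenced by the "sendrecv" comment above, looks roughly like the sketch below; the exact call site lives outside the hunks shown here, so treat the tensor name as an assumption.

# Sketch of the last->first token handoff between decode steps (assumed tensor name).
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
    dist.send(new_token, dst=first_pp_rank_global_id, group=pp_group)
elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
    dist.recv(new_token, src=last_pp_rank_global_id, group=pp_group)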
@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                )  # decode_results[i][0]

            input_pos += 1
-            model.setup_input_pos(input_pos)
+
+    logger.info(
+        f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )

    # Display the decoding results
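`CUDATrackTime` is used above as a context manager exposing `get_time()` and `unit`, but its implementation is not part of this diff. A plausible stand-in based on CUDA events, useful for reproducing the timing logs locally, could look like the following; this is an assumption, not the repo's actual helper.

import torch

class CUDATrackTime:
    # Sketch of an event-based GPU timer with the interface used above (assumed, not the repo's helper).
    unit = "ms"

    def __enter__(self):
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start.record()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.end.record()
        torch.cuda.synchronize()  # make sure both events have completed
        self.elapsed = self.start.elapsed_time(self.end)  # milliseconds
        return False

    def get_time(self):
        return f"{self.elapsed:.2f}"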