Commit 2cf4016

[Distributed] Add lanes to KV cache (#1174)
* [WIP][Distributed] Add lanes to KV cache
* Compatibility change
* Naming
* Remove setup_input_pos
* Add timer
* Remove mbs
1 parent dc832fb commit 2cf4016

3 files changed (+61, −51 lines)

dist_run.py

+47 −32
@@ -273,13 +273,11 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=}, {tp_group_size=}")
+    logger.info(f"{pp_degree=}, {tp_degree=}")

     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1

     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

@@ -297,18 +295,22 @@ def main(args):
     if rank == 0:
         logger.info(f"Model: {model}")

-    mbs = 1  # number of micro-batches
-    mb_size = 4  # micro-batch size
-    batch_size = mbs * mb_size  # total batch size
-
+    # Batch size. Since we push batches dynamically through the pipeline rather
+    # than chunking them, this is effectively micro-batch size in pipeline
+    # sense. Thus it is interchangeable with micro-batch size below.
+    batch_size = 4
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension

     # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
+    # When decoding is done for certain micro-batches, we can reuse the KV cache
+    # lanes.
+    # TODO: bump up the lane count
+    pipeline_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")

@@ -317,7 +319,7 @@ def main(args):
     model.to(device)

     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )

     # info on stage size and params

@@ -330,17 +332,16 @@

     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen_prefill, device=device)
-    model.setup_input_pos(input_pos)
     model.eval()

     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
         activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)

@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
-    # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+
+    # Create schedule
+    # Number of micro-batches for the schedule is 1, because each step() call we
+    # only push 1 micro-batch into the pipeline. But we can continuously push
+    # new micro-batches into the pipeline as they arrive, achieving same
+    # pipelining effect.
+    prefiller = ScheduleGPipe(prefill_stage, 1)

     prompt = [
         "What is a computer?",

@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     s = set(prompt_lengths)
     assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

-    # with CUDATrackTime() as timer:
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)

@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40

     # Prefill phase
-    # Run context input through pipeline, in 1 step
-    with torch.no_grad():
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
+    lane = 0
+    kwargs = {"input_pos": input_pos, "cache_lane": lane}
+    with torch.no_grad(), CUDATrackTime() as timer:
         if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step(**kwargs)
         else:  # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step(**kwargs)
+
+    logger.info(
+        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )

     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:

@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
-    model.setup_input_pos(input_pos)

     # Create decode stage
     logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")

@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decorder = ScheduleGPipe(decode_stage, 1)

     # Decoding
-    with torch.no_grad():
+    with torch.no_grad(), CUDATrackTime() as timer:
         for step in range(num_tokens - 1):
+            kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:

@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decode_schedule.step(new_token)
+                output = decorder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decode_schedule.step()
+                output = decorder.step(**kwargs)
             else:  # middle pp ranks
-                decode_schedule.step()
+                decorder.step(**kwargs)

             # Decode the output
             if pp_rank == last_pp_rank:

@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                )  # decode_results[i][0]

             input_pos += 1
-            model.setup_input_pos(input_pos)
+
+    logger.info(
+        f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )

     # Display the decoding results
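
To make the new driver pattern concrete, the sketch below mirrors only the call shape used in dist_run.py; ToyStage and ToySchedule are made-up stand-ins (a real run uses PipelineStage and ScheduleGPipe from torch.distributed.pipelining on initialized process groups). The schedule is built for one micro-batch, each step() pushes a single micro-batch, and input_pos plus cache_lane travel as keyword arguments so every stage knows which KV cache lane to touch.

import torch


class ToyStage(torch.nn.Module):
    """Stand-in for a pipeline stage; a real stage holds model layers."""

    def forward(self, x, input_pos=None, cache_lane: int = 0):
        # A real stage would read/write kv_cache[cache_lane] at input_pos.
        return x + 1


class ToySchedule:
    """Stand-in for ScheduleGPipe(stage, 1): one micro-batch per step()."""

    def __init__(self, stage):
        self.stage = stage

    def step(self, *args, **kwargs):
        return self.stage(*args, **kwargs)


prefiller = ToySchedule(ToyStage())

lane = 0                             # KV-cache lane assigned to this micro-batch
input_pos = torch.arange(8)          # prefill positions
kwargs = {"input_pos": input_pos, "cache_lane": lane}

padded_sequence = torch.zeros(4, 8)  # one micro-batch (toy values)
output = prefiller.step(padded_sequence, **kwargs)

# Decode keeps calling step() with a one-element input_pos and the same lane;
# a second in-flight micro-batch would simply use cache_lane=1.
print(output.shape)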

torchchat/export.py

+2 −2
@@ -152,9 +152,9 @@ def __init__(self, attention: Attention):
        self.wo = attention.wo

        max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache.k_cache.shape
+            attention.kv_cache[0].k_cache.shape
        )
-        cache_dtype = attention.kv_cache.k_cache.dtype
+        cache_dtype = attention.kv_cache[0].k_cache.dtype
        self.kv_cache = CustomKVCache(
            max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
        )
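
Since attention.kv_cache is now a per-lane container rather than a single module, the export path reads the cache geometry from lane 0 (every lane is constructed with the same shape and dtype). A small illustrative sketch of that access pattern, using a SimpleNamespace stand-in and a made-up shape rather than torchchat's classes:

from types import SimpleNamespace

import torch

# Stand-in for an Attention module whose kv_cache is now a list of per-lane caches.
lane0 = SimpleNamespace(k_cache=torch.zeros(1, 8, 128, 64))  # (batch, heads, seq, head_dim)
kv_cache = [lane0]  # in torchchat this is an nn.ModuleList of KVCache

# Metadata comes from lane 0; all lanes share the same geometry.
max_batch_size, n_heads, max_seq_length, head_dim = kv_cache[0].k_cache.shape
cache_dtype = kv_cache[0].k_cache.dtype
print(max_batch_size, n_heads, max_seq_length, head_dim, cache_dtype)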

torchchat/model.py

+12 −17
@@ -606,7 +606,7 @@ def __init__(self, config: TransformerArgs) -> None:
        self.max_batch_size = -1
        self.max_seq_length = -1

-    def setup_caches(self, max_batch_size, max_seq_length):
+    def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size

@@ -620,7 +620,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
            # parallelism may have been applied there and the `n_local_heads``
            # value being adjusted.
            b.attention.setup_cache(
-                max_batch_size, max_seq_length,
+                max_batch_size, max_seq_length, cache_lanes=cache_lanes
            )

        freqs_cis = precompute_freqs_cis(

@@ -653,22 +653,15 @@ def distribute(self, device_mesh: DeviceMesh):
            ColwiseParallel(output_layouts=Replicate()),
        )

-    # This is a temporary solution to pass input_pos to non-0 pipeline stages
-    # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
-    def setup_input_pos(self, input_pos: Tensor) -> None:
-        self._input_pos = input_pos
-
-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
-        # TODO: find a better way to pass input_pos to non-0 pipeline stages
-        input_pos = input_pos if input_pos is not None else self._input_pos
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        if self.tok_embeddings:
            x = self.tok_embeddings(x)

        for _, layer in self.layers.items():
-            x = layer(x, input_pos, freqs_cis, mask)
+            x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)

        if self.norm:
            x = self.norm(x)

@@ -691,7 +684,7 @@ def distribute(self, device_mesh: DeviceMesh):
        self.feed_forward.distribute(device_mesh)

    def forward(
-        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
+        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))

@@ -723,15 +716,16 @@ def __init__(self, config: TransformerArgs):
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

-    def setup_cache(self, max_batch_size, max_seq_length):
+    def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
        n_local_heads = self.n_local_heads
        # If TP is enabled, the heads would be divided and assigned to different ranks
        if hasattr(self, "tp_degree"):
            n_local_heads = self.n_local_heads // self.tp_degree

-        self.kv_cache = KVCache(
-            max_batch_size, max_seq_length, n_local_heads, self.head_dim
-        )
+        self.kv_cache = nn.ModuleList([
+            KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim)
+            for _ in range(cache_lanes)
+        ])

    def load_hook(self, state_dict, prefix, *args):
        # if prefix + "wq.weight" in state_dict:

@@ -784,6 +778,7 @@ def forward(
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
+        cache_lane: int = 0,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

@@ -809,7 +804,7 @@ def forward(
        q, k, v = (x.transpose(1, 2) for x in (q, k, v))

        if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
+            k, v = self.kv_cache[cache_lane].update(input_pos, k, v)

        k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
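
Putting the model-side pieces together, here is a hedged, self-contained sketch of the lane mechanism; SimpleKVCache and LanedAttention are simplified stand-ins for torchchat's KVCache and Attention. setup_cache allocates one cache per lane in an nn.ModuleList, and forward touches only the lane it is given, so concurrent micro-batches never clobber each other's cached keys and values.

import torch
import torch.nn as nn


class SimpleKVCache(nn.Module):
    """Simplified stand-in for torchchat's KVCache."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos, k, v):
        # Write new keys/values at input_pos and return the full cache views.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache


class LanedAttention(nn.Module):
    """Illustrates only the cache-lane bookkeeping, not real attention."""

    def __init__(self, n_heads=8, head_dim=64):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, head_dim
        self.kv_cache = None

    def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
        # One independent KV cache per in-flight micro-batch ("lane").
        self.kv_cache = nn.ModuleList(
            SimpleKVCache(max_batch_size, max_seq_length, self.n_heads, self.head_dim)
            for _ in range(cache_lanes)
        )

    def forward(self, k, v, input_pos, cache_lane: int = 0):
        # Each micro-batch reads and writes only its own lane.
        return self.kv_cache[cache_lane].update(input_pos, k, v)


attn = LanedAttention()
attn.setup_cache(max_batch_size=4, max_seq_length=16, cache_lanes=2)
k = v = torch.randn(4, 8, 1, 64)  # one decoded token per sequence
k_full, v_full = attn(k, v, input_pos=torch.tensor([0]), cache_lane=1)
print(k_full.shape)  # torch.Size([4, 8, 16, 64]); lane 0 remains untouched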
