Commit fc5033e: Better inputs, experiments

1 parent 75279ff

8 files changed, +204 -46 lines

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ safetensors
 deepspeed==0.7.7
 -e ./transformers
 flash-attn
+einops

 # TODO: Analysis only
 py-markdown-table
Lines changed: 14 additions & 14 deletions
@@ -1,20 +1,20 @@

 # Santacoder
-./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 5 0
-./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0
-./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 5 0
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 5 0 v2_
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0 v2_
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 5 0 v2_

-./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 11 1
-./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1
-./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 11 1 v2_
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1 v2_
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1 v2_

 # Large model
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 # OOM?
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0 v2_
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0 v2_
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0 v2_
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 v2_ # OOM?

-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1
-./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 # OOM?
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1 v2_ 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1 v2_ 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1 v2_ 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 v2_ 1 # OOM?
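The new trailing arguments feed the two parameters added to run_benchmark_breakdown.sh below: v2_ becomes FILE_PREFIX ($7), and the final 1 on the step-1 large-model runs overrides CYCLES ($8). A minimal sketch of the full positional mapping, assuming $1-$3 are the model name, model path, and batch size (only $4-$8 appear in the script diff):

# Hypothetical mapping of the script's positional arguments; $1-$3 are
# assumptions, $4-$8 match the defaults in run_benchmark_breakdown.sh below.
names = [
    "model_name", "model_path", "batch_size",  # assumed from usage
    "max_new_tokens",                          # ${4:-2040}
    "token_step",                              # ${5:-5}
    "step_id",                                 # ${6:-""}
    "file_prefix",                             # ${7:-""}
    "cycles",                                  # ${8:-10}
]
args = "large_model ./data/large-model 8 8190 29 1 v2_ 1".split()
print(dict(zip(names, args)))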

scripts/run_benchmark_breakdown.sh

Lines changed: 30 additions & 10 deletions
@@ -10,6 +10,8 @@ MAX_NEW_TOKENS=${4:-2040}
 # Prime number to see key length padding effect.
 TOKEN_STEP=${5:-5}
 STEP_ID=${6:-""}
+FILE_PREFIX=${7:-""}
+CYCLES=${8:-10}

 SAVE_DIR=data/benchmarks/v2
 #BATCH_SIZES="1 2 4 8 16 24 32 48 64 96 128 160 224 256"
@@ -19,37 +21,55 @@ RUN="python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda --cu
 RUNTIME=("" "pre_allocate_kv_cache=True" "pre_allocate_kv_cache=True inference_runner=3")
 RUNTIME_NAMES=("base" "pre_allocate" "graph")

-ATTN_NAME=("jit" "flash" "torch" "torchflash" "torchmem" "torchcpp")
+ATTN=( \
+    "attention_implementation=0" \
+    "attention_implementation=1" \
+    "attention_implementation=1 --pad_generated_tokens=0.5" \
+    "attention_implementation=2" \
+    "attention_implementation=0 fused_softmax=False" \
+    "attention_implementation=0 fused_softmax=True" \
+    "attention_implementation=3" \
+    "attention_implementation=4" \
+    "attention_implementation=5" \
+)
+ATTN_NAME=( \
+    "default" \
+    "flash" \
+    "flash_unpad_50" \
+    "torch" \
+    "no_jit" \
+    "jit" \
+    "torchflash" \
+    "torchmem" \
+    "torchcpp" \
+)


 STEP=("--no_prefill" "--no_cache")
 STEP_NAME=("decode" "prefill")

-COMMON="--pretrained_model=$MODEL_PATH --tokenizer=$MODEL_PATH --cycles=10 --max_input_length=1 --max_new_tokens=$MAX_NEW_TOKENS --key_length_step=$TOKEN_STEP --batch_size=$BATCH_SIZE"
+COMMON="--pretrained_model=$MODEL_PATH --tokenizer=$MODEL_PATH --cycles=$CYCLES --max_input_length=1 --max_new_tokens=$MAX_NEW_TOKENS --key_length_step=$TOKEN_STEP --batch_size=$BATCH_SIZE predict_last_token=True"

 run () { # run(step, runtime, attn)
-    FILE_NAME="$SAVE_DIR"/"$MODEL_NAME"_bs_"$BATCH_SIZE"_tok_"$MAX_NEW_TOKENS"_step_"$TOKEN_STEP"_"${STEP_NAME[$1]}"/"${RUNTIME_NAMES[$2]}"_"${ATTN_NAME[$3]}".json
+    FILE_NAME="$SAVE_DIR"/"$MODEL_NAME"_bs_"$BATCH_SIZE"_tok_"$MAX_NEW_TOKENS"_step_"$TOKEN_STEP"_"${STEP_NAME[$1]}"/"$FILE_PREFIX""${RUNTIME_NAMES[$2]}"_"${ATTN_NAME[$3]}".json
     if [ -f "$FILE_NAME" ];
     then
         echo "Skipping existing $FILE_NAME"
     else
-        $RUN $COMMON ${STEP[$1]} ${RUNTIME[$2]} "attention_implementation=$3" --save="$FILE_NAME"
+        $RUN $COMMON ${RUNTIME[$2]} ${ATTN[$3]} ${STEP[$1]} --save="$FILE_NAME"
     fi
 }

 if [ "${STEP_ID}" -eq "0" ]
 then
-    # Decode
+    # Decode (default attn only)
     for runtime in {0..2}
     do
-        for attn in {0..5}
-        do
-            run 0 $runtime $attn
-        done
+        run 0 $runtime 0
     done
 else
     # Prefill (all runtimes are the same)
-    for attn in {0..5}
+    for attn in {0..2}
     do
         run 1 0 $attn
     done
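For illustration, `run 1 0 2` now benchmarks prefill with the base runtime and ATTN[2] ("attention_implementation=1 --pad_generated_tokens=0.5"), and writes its results to a path built from the new FILE_PREFIX. A sketch of that path construction (not part of the commit), assuming the santacoder settings above:

# Sketch: output path for `run 1 0 2` with MODEL_NAME=santacoder,
# BATCH_SIZE=32, MAX_NEW_TOKENS=2040, TOKEN_STEP=5, FILE_PREFIX=v2_.
save_dir = "data/benchmarks/v2"
file_name = (
    f"{save_dir}/santacoder_bs_32_tok_2040_step_5_prefill/"
    f"v2_base_flash_unpad_50.json"
)
# run() skips the benchmark when this file already exists, so FILE_PREFIX
# lets new experiment variants coexist with older results on disk.
print(file_name)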

src/main.py

Lines changed: 14 additions & 2 deletions
@@ -11,7 +11,7 @@
 from src.metrics import Metrics
 from src.pipeline import Pipeline, get_pipeline_class
 from src.profile import get_profiler, logger
-from src.utils import configure_logging, get_dummy_batch, log_dict, log_rank_n, parse_config_args
+from src.utils import configure_logging, get_input_batch, log_dict, log_rank_n, parse_config_args


 def get_arg_parser() -> ArgumentParser:
@@ -40,6 +40,10 @@ def get_arg_parser() -> ArgumentParser:
     # Input and output
     parser.add_argument("--batch_size", "-b", default=1, type=int)
     parser.add_argument("--max_input_length", "-i", default=-1, type=int)
+    parser.add_argument("--sample_dir", "-d")
+    parser.add_argument("--input_pad_ratio", "--pad", default=0, type=float)
+    parser.add_argument("--pad_generated_tokens", "--pad_g", default=0, type=float)
+    parser.add_argument("--input_seed", "--seed", default=0, type=int)
     parser.add_argument("--max_new_tokens", "-g", default=100, type=int)

     # Cleanup
@@ -67,7 +71,6 @@ def main(argv: Optional[List[str]] = None) -> None:
     parser = get_arg_parser()
     args = parser.parse_args(argv)
     config_args = parse_config_args(args.config_args)
-    inputs = get_dummy_batch(args.batch_size, args.max_input_length)
     separate_profile = args.profile and args.profile_cycles is not None
     warmup = args.profile if args.warmup is None else args.warmup
     if separate_profile:
@@ -94,6 +97,14 @@ def main(argv: Optional[List[str]] = None) -> None:
         fast_init=args.fast_init,
         trust_remote_code=args.trust_remote_code,
     )
+    inputs = get_input_batch(
+        args.batch_size,
+        args.max_input_length,
+        pipeline.tokenizer,
+        args.input_pad_ratio,
+        args.input_seed,
+        args.sample_dir,
+    )

     all_metrics = []

@@ -145,6 +156,7 @@ def main(argv: Optional[List[str]] = None) -> None:
         breakdown_latency=args.breakdown_latency,
         key_length_step=args.key_length_step,
         ignore_oom=args.ignore_oom,
+        pad_generated_tokens=args.pad_generated_tokens,
     )
     if args.profile:
         p.step()
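Note that input construction now happens after the pipeline is built, since get_input_batch needs the tokenizer. get_input_batch itself lives in src/utils.py, which this commit does not show; from the call site above, a plausible (hypothetical) shape is:

def get_input_batch(batch_size, max_input_length, tokenizer, input_pad_ratio=0, input_seed=0, sample_dir=None):
    """Hypothetical stub: build the list of input texts for the benchmark.

    Unlike the old get_dummy_batch(batch_size, max_input_length), this
    presumably takes the tokenizer so input lengths can be controlled in
    tokens, reads real samples from sample_dir when given, and shortens a
    seeded random fraction (input_pad_ratio) of the inputs so the batch is
    ragged and exercises padding.
    """
    raise NotImplementedError  # see src/utils.py in the repository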

src/parse_breakdown_results.py

Lines changed: 27 additions & 6 deletions
@@ -8,6 +8,8 @@ def get_arg_parser() -> ArgumentParser:
     parser = ArgumentParser()
     parser.add_argument("input_dir", type=Path)
     parser.add_argument("--title")
+    parser.add_argument("--size", nargs=2, type=float)
+    parser.add_argument("--save_dir", "--save", type=Path)
     return parser


@@ -22,27 +24,30 @@ def read_data(input_file: Path):
     return data


-def plot(data, title=None):
+def plot(data, title=None, size=None):
     import matplotlib.pyplot as plt

-    fig = plt.figure()
+    fig = plt.figure(figsize=size)
     ax = fig.add_subplot()

-    for dat in data:
+    cmap = plt.get_cmap("tab20").colors
+    cmap = cmap[::2] + cmap[1::2]
+
+    for i, dat in enumerate(data):
         latency_data = dat["Latency (generate breakdown)"]
         ax.plot(
             [int(k) for k in latency_data.keys()],
             [v * 1000 for v in latency_data.values()],
             label=dat["Setting"],
             linewidth=1,
+            color=cmap[i],
         )  # , linestyle=":")#, markersize=1, marker="o")

     ax.set_title(title)
     ax.set_xlabel("Sequence length")
     ax.set_ylabel("Latency (ms)")
     ax.legend()
-    fig.show()
-    input("Press enter to continue")
+    return fig


 def main(argv: Optional[List[str]] = None) -> None:
@@ -53,7 +58,23 @@ def main(argv: Optional[List[str]] = None) -> None:
     if len(data) == 0:
         raise RuntimeError(f"No data to show.")

-    plot(data, args.title)
+    title = args.title
+    dirname = args.input_dir.stem
+    if title is None:
+        try:
+            name, _, bs, _, _, _, _, step = dirname.rsplit("_", 7)
+            title = f"{name} {step}, bs = {bs}"
+        except ValueError:
+            title = dirname
+
+    fig = plot(data, title, args.size)
+    fig.show()
+    if args.save_dir:
+        save_path = (args.save_dir / dirname).with_suffix(".jpg")
+        fig.savefig(save_path)
+        print(f"Figure saved to {save_path}")
+
+    input("Press enter to continue")


 if __name__ == "__main__":
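Two details worth unpacking: rsplit("_", 7) splits from the right, so model names that themselves contain underscores (large_model) survive intact, and the tab20 reordering picks the ten saturated shades before the ten light ones, keeping early curves distinguishable. A quick check:

import matplotlib.pyplot as plt

# Auto-title: directory names follow run_benchmark_breakdown.sh's pattern.
dirname = "large_model_bs_8_tok_8190_step_29_decode"
name, _, bs, _, _, _, _, step = dirname.rsplit("_", 7)
print(f"{name} {step}, bs = {bs}")  # -> large_model decode, bs = 8

# Color order: darks (tab20 even indices) first, then lights.
cmap = plt.get_cmap("tab20").colors
cmap = cmap[::2] + cmap[1::2]
print(len(cmap))  # 20 distinct colors, saturated ones first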

src/pipeline.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def _generate_custom(
216216
breakdown_latency: bool = False,
217217
key_length_step: int = 1,
218218
ignore_oom: bool = False,
219+
pad_generated_tokens: float = 0,
219220
):
220221
t0 = self._get_time(breakdown_latency)
221222
batch_size, input_length = inputs["input_ids"].shape
@@ -227,7 +228,13 @@ def _generate_custom(
227228

228229
attention_mask = torch.empty([batch_size, output_length], dtype=torch.bool, device=self.device)
229230
attention_mask[:, :input_length].copy_(inputs["attention_mask"])
230-
attention_mask[:, input_length:].fill_(True)
231+
if pad_generated_tokens > 0:
232+
attention_mask[:, input_length:].copy_(
233+
torch.empty_like(attention_mask[:, input_length:], dtype=torch.float32).uniform_()
234+
> pad_generated_tokens
235+
)
236+
else:
237+
attention_mask[:, input_length:].fill_(True)
231238

232239
position_ids = attention_mask.long().cumsum(-1, dtype=torch.int64) - 1
233240
# TODO: Useless?
@@ -301,6 +308,7 @@ def __call__(
301308
breakdown_latency=False,
302309
key_length_step: int = 1,
303310
ignore_oom: bool = False,
311+
pad_generated_tokens: float = 0,
304312
) -> Tuple[List[str], Dict[str, Any]]:
305313
t0 = self._get_time()
306314
inputs = self.tokenizer(text, return_tensors="pt", padding=True)
@@ -310,13 +318,21 @@ def __call__(
310318
if custom_generate:
311319
assert do_prefill or use_cache
312320
output_tokens, generate_metrics = self._generate_custom(
313-
inputs, max_new_tokens, use_cache, do_prefill, breakdown_latency, key_length_step, ignore_oom
321+
inputs,
322+
max_new_tokens,
323+
use_cache,
324+
do_prefill,
325+
breakdown_latency,
326+
key_length_step,
327+
ignore_oom,
328+
pad_generated_tokens,
314329
)
315330
else:
316331
assert do_prefill
317332
assert not breakdown_latency
318333
assert not ignore_oom
319334
assert key_length_step == 1
335+
assert pad_generated_tokens == 0
320336
output_tokens = self._generate_hf(inputs, max_new_tokens, use_cache)
321337
generate_metrics = {}
322338
t2 = self._get_time(True)
