From 8a5c7ddf16c3b08dfcf032304efb22616e41bc92 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 14 Apr 2023 17:21:38 -0400 Subject: [PATCH 1/3] Custom generate --- src/main.py | 29 +++++---- src/metrics.py | 10 +++- src/pipeline.py | 153 ++++++++++++++++++++++++++++++++++++++++-------- src/utils.py | 8 ++- transformers | 2 +- 5 files changed, 162 insertions(+), 40 deletions(-) diff --git a/src/main.py b/src/main.py index 4b3287b..9111d35 100644 --- a/src/main.py +++ b/src/main.py @@ -26,16 +26,19 @@ def get_arg_parser() -> ArgumentParser: parser.add_argument("config_args", nargs="*") # Runtime + parser.add_argument("-c", "--custom_generate", action="store_true") parser.add_argument("--pipeline_class", default="HF_Pipeline") parser.add_argument("--device", default="cuda", type=torch.device) parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x)) parser.add_argument("--local_rank", type=int) - parser.add_argument("--no_fast_init", dest="fast_init", action="store_false") + parser.add_argument("--no_fast_init","--nf", dest="fast_init", action="store_false") + parser.add_argument("--no_cache","--nc", dest="use_cache", action="store_false") + parser.add_argument("--no_prefill","--np", dest="do_prefill", action="store_false") # Input and output - parser.add_argument("--batch_size", default=1, type=int) - parser.add_argument("--max_input_length", default=-1, type=int) - parser.add_argument("--max_new_tokens", default=100, type=int) + parser.add_argument("--batch_size","-b", default=1, type=int) + parser.add_argument("--max_input_length","-i", default=-1, type=int) + parser.add_argument("--max_new_tokens","-g", default=100, type=int) # Cleanup parser.add_argument("--clear_every_run", action="store_true") @@ -47,10 +50,11 @@ def get_arg_parser() -> ArgumentParser: # Profiling and logging parser.add_argument("--max_log_outputs", type=int) - parser.add_argument("--profile", action="store_true") - parser.add_argument("--profile_cycles", type=int) - parser.add_argument("--full_trace", action="store_true") - parser.add_argument("--show_op_names", action="store_true") + parser.add_argument("--breakdown_latency","--bl", action="store_true") + parser.add_argument("--profile","-p", action="store_true") + parser.add_argument("--profile_cycles","--pc", type=int) + parser.add_argument("--full_trace","--pt", action="store_true") + parser.add_argument("--show_op_names","--pn", action="store_true") parser.add_argument("--save", type=Path) return parser @@ -61,7 +65,6 @@ def main(argv: Optional[List[str]] = None) -> None: parser = get_arg_parser() args = parser.parse_args(argv) config_args = parse_config_args(args.config_args) - generate_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": False} inputs = get_dummy_batch(args.batch_size, args.max_input_length) separate_profile = args.profile and args.profile_cycles is not None warmup = args.profile if args.warmup is None else args.warmup @@ -88,6 +91,10 @@ def main(argv: Optional[List[str]] = None) -> None: dtype=args.dtype, fast_init=args.fast_init, trust_remote_code=args.trust_remote_code, + custom_generate=args.custom_generate, + use_cache=args.use_cache, + do_prefill=args.do_prefill, + breakdown_latency=args.breakdown_latency, ) all_metrics = [] @@ -104,7 +111,7 @@ def main(argv: Optional[List[str]] = None) -> None: profiler = contextlib.nullcontext() benchmark_metrics = { - **generate_kwargs, + "max_new_tokens": args.max_new_tokens, "Model parameters": pipeline.get_num_parameters(), "Cycles (warmup)": args.skip + 
warmup, "Cycles (benchmark)": args.cycles, @@ -124,7 +131,7 @@ def main(argv: Optional[List[str]] = None) -> None: if step == args.skip + warmup: t2 = time.perf_counter() benchmark_metrics[Metrics.RUNTIME_WARMUP] = t2 - t1 - generated_text, metrics = pipeline(inputs, **generate_kwargs) + generated_text, metrics = pipeline(inputs, args.max_new_tokens) if args.profile: p.step() diff --git a/src/metrics.py b/src/metrics.py index a2d0858..2b7e61b 100644 --- a/src/metrics.py +++ b/src/metrics.py @@ -17,6 +17,10 @@ def format_ms(t: float) -> str: return f"{1000 * t:.2f} ms" +def format_ms_dict(t_dict: Dict[str,float]) -> Dict[str,str]: + return {key:format_ms(value) for key, value in t_dict.items()} + + def format_mib(m: float) -> str: return f"{m/2**20:.0f} MiB" @@ -24,7 +28,9 @@ def format_mib(m: float) -> str: class Metrics: LATENCY_E2E = "Latency (end to end)" LATENCY_TOKEN = "Latency (tokenization)" - LATENCY_MODEL = "Latency (model)" + LATENCY_MODEL = "Latency (generate)" + LATENCY_GENERATE_START = "Latency (prepare for generation)" + LATENCY_GENERATE_BREAKDOWN = "Latency (generate breakdown)" LATENCY_DECODE = "Latency (decode)" LATENCY_MAX = "Latency (max)" LATENCY_MIN = "Latency (min)" @@ -59,6 +65,8 @@ class Metrics: LATENCY_E2E: format_ms, LATENCY_TOKEN: format_ms, LATENCY_MODEL: format_ms, + LATENCY_GENERATE_START: format_ms, + LATENCY_GENERATE_BREAKDOWN: format_ms_dict, LATENCY_DECODE: format_ms, LATENCY_MAX: format_ms, LATENCY_MIN: format_ms, diff --git a/src/pipeline.py b/src/pipeline.py index 7ae3977..8df6cbc 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -18,6 +18,7 @@ AutoTokenizer, PretrainedConfig, PreTrainedModel, + GPTBigCodeConfig,GPTBigCodeForCausalLM ) @@ -37,25 +38,41 @@ def __init__( dtype: torch.dtype, fast_init: bool = True, trust_remote_code: bool = False, + custom_generate:bool=False, + use_cache: bool = True, + do_prefill: bool = True, + breakdown_latency=False, ): self.global_metrics = {} log_rank_n("*** Setting up tokenizer", logger.info) - t0 = time.perf_counter() - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + t0 = self._get_time() + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, padding_side="left") + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token=self.tokenizer.eos_token - self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - t1 = time.perf_counter() + t1 = self._get_time() self.device = device + if self.device==torch.device("cuda"): + self.device=torch.device("cuda:0") + self.dtype = dtype self.is_int8 = self.dtype == torch.int8 self.fast_init = fast_init self.trust_remote_code = trust_remote_code - if self.is_int8 and self.device != torch.device("cuda"): + self.use_cache = use_cache + self.do_prefill = do_prefill + if not self.do_prefill: + assert custom_generate + assert self.use_cache + self.breakdown_latency=breakdown_latency + if self.is_int8 and self.device != torch.device("cuda:0"): raise ValueError(f"Model quantization not supported on device {self.device}") + self._generate=self._generate_custom if custom_generate else self._generate_hf + self.config = self._get_config(model_type, pretrained_config or pretrained_model, config_args) - t2 = time.perf_counter() + t2 = self._get_time() logger.info(f"Model configuration: {self.config}") @@ -67,27 +84,27 @@ def __init__( self.model = self._load_pretrained(pretrained_model) self.model.eval() - t3 = time.perf_counter() + t3 = self._get_time() self.global_metrics[Metrics.INIT_TOKEN] = t1 - t0 self.global_metrics[Metrics.INIT_CONFIG] = t2 - t1 
self.global_metrics[Metrics.INIT_TOTAL] = t3 - t0 def _create_model(self) -> PreTrainedModel: - t0 = time.perf_counter() + t0 = self._get_time() log_rank_n("*** Creating model", logger.info) with fast_init(self.device) if self.fast_init else contextlib.nullcontext(): torch_dtype = torch.float16 if self.is_int8 else self.dtype model = AutoModelForCausalLM.from_config( config=self.config, torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code ) - t1 = time.perf_counter() + t1 = self._get_time() log_rank_n("*** Moving to device", logger.info) model.to(self.device) - t2 = time.perf_counter() + t2 = self._get_time() log_rank_n("*** Initializing weights", logger.info) # Initialization is ~1000x faster on GPU. model.init_weights() - t3 = time.perf_counter() + t3 = self._get_time() self.global_metrics[Metrics.INIT_CREATE] = t1 - t0 self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1 self.global_metrics[Metrics.INIT_WEIGHTS] = t3 - t2 @@ -101,14 +118,14 @@ def _reload_model(self): self.model = self._load_pretrained("tmp") def _save_pretrained(self, pretrained_model: str): - t0 = time.perf_counter() + t0 = self._get_time() log_rank_n(f"*** Saving model to {pretrained_model}", logger.info) - t1 = time.perf_counter() + t1 = self._get_time() self.global_metrics[Metrics.INIT_SAVE] = t1 - t0 self.model.save_pretrained(pretrained_model) def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel: - t0 = time.perf_counter() + t0 = self._get_time() log_rank_n(f"*** Loading model from {pretrained_model}", logger.info) kwargs = {"load_in_8bit": True, "device_map": "auto"} if self.is_int8 else {"torch_dtype": self.dtype} with fast_init(self.device) if self.fast_init else contextlib.nullcontext(): @@ -120,12 +137,12 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel: trust_remote_code=self.trust_remote_code, **kwargs, ) - t1 = time.perf_counter() + t1 = self._get_time() self.global_metrics["load pretrained model"] = t1 - t0 if not self.is_int8: log_rank_n("*** Moving to device", logger.info) model = model.to(self.device) - t2 = time.perf_counter() + t2 = self._get_time() self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1 return model @@ -171,26 +188,103 @@ def _get_config( return config - def __call__(self, text: List[str], **generate_kwargs) -> Tuple[List[str], Dict[str, Any]]: - t0 = time.perf_counter() - inputs = self.tokenizer(text, return_tensors="pt", padding=True) + def _get_time(self, synchronize=False): + if synchronize: + torch.cuda.synchronize() + return time.perf_counter() + + def _generate_custom(self, inputs:Dict, max_new_tokens:int): + t0 = self._get_time(self.breakdown_latency) + batch_size, input_length = inputs["input_ids"].shape + output_length = input_length + max_new_tokens + input_ids = torch.empty([batch_size, output_length], dtype=torch.int64, device=self.device) + input_ids[:, :input_length].copy_(inputs["input_ids"]) + attention_mask = torch.empty([batch_size, output_length], dtype=torch.bool, device=self.device) + attention_mask[:, :input_length].copy_(inputs["attention_mask"]) + attention_mask[:, input_length:].fill_(True) + + position_ids = attention_mask.long().cumsum(-1, dtype=torch.int64) - 1 + # TODO: Useless? 
+ position_ids[:, :input_length].masked_fill_(attention_mask[:, :input_length] == 0, 1) + + if self.do_prefill or input_length<=1: + past_key_values=None + past_key_length=0 + else: + # Generate mock `past_key_values` + past_key_length=input_length-1 + if isinstance(self.config, GPTBigCodeConfig): + if self.config.pre_allocate_kv_cache: + past_key_values=[past_key_length]*self.config.n_layer + for block in self.model.transformer.h: + block.attn.get_kv_cache(batch_size, past_key_length, dtype=self.dtype, device=self.device).normal_() + else: + kv_dim=self.config.n_embd // self.config.n_head if self.config.multi_query else self.config.n_embd + past_key_values=[torch.randn([batch_size, past_key_length, 2*kv_dim], dtype=self.dtype, device=self.device) for _ in range(self.config.n_layer)] + else: + past_key_values = [ + [torch.randn([batch_size, past_key_length, self.config.n_embd], dtype=self.dtype, device=self.device) for _ in range(2)] for _ in + range(self.config.n_layer)] + + t1 = self._get_time(self.breakdown_latency) + last_time=t1 + generate_times={} + for key_length in range(input_length, output_length): + outputs = self.model( + input_ids=input_ids[:, past_key_length:key_length], + past_key_values=past_key_values, + attention_mask=attention_mask[:, :key_length], + position_ids=position_ids[:, past_key_length:key_length], + return_dict=True, + use_cache=self.use_cache, + ) + if self.use_cache: + past_key_values=outputs.past_key_values + past_key_length=key_length + next_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1) + input_ids[:, key_length] = next_tokens + t2 = self._get_time(self.breakdown_latency) + generate_times[key_length]=t2-last_time + last_time=t2 + + metrics={} + if self.breakdown_latency: + metrics[Metrics.LATENCY_GENERATE_START]=t1-t0 + metrics[Metrics.LATENCY_GENERATE_BREAKDOWN]=generate_times + + return input_ids, metrics + + def _generate_hf(self, inputs:Dict, max_new_tokens:int): inputs = {key: value.to(self.device) if torch.is_tensor(value) else value for key, value in inputs.items()} + output = self.model.generate( + **inputs, + return_dict_in_generate=True, + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=self.use_cache, + ) + return output.sequences, {} - t1 = time.perf_counter() - with torch.inference_mode(): - output = self.model.generate(**inputs, return_dict_in_generate=True, **generate_kwargs) - t2 = time.perf_counter() - output_tokens = output.sequences + def __call__(self, text: List[str], max_new_tokens:int) -> Tuple[List[str], Dict[str, Any]]: + t0 = self._get_time() + inputs = self.tokenizer(text, return_tensors="pt", padding=True) + + t1 = self._get_time() + with torch.inference_mode(): + output_tokens, generate_metrics = self._generate(inputs, max_new_tokens) + t2 = self._get_time(True) batch_size, input_length = inputs["input_ids"].shape output_length = output_tokens.size(1) output_text = self.tokenizer.batch_decode(output_tokens.cpu(), skip_special_tokens=True) - t3 = time.perf_counter() + t3 = self._get_time() metrics = { + **generate_metrics, Metrics.BATCH_SIZE: batch_size, Metrics.INPUT_LENGTH: input_length, Metrics.OUTPUT_LENGTH: output_length, @@ -218,14 +312,23 @@ def aggregate_metrics(self, metrics: List[Dict[str, Any]]): Metrics.TOKENS_BATCH, Metrics.LATENCY_TOKEN, Metrics.LATENCY_MODEL, + Metrics.LATENCY_GENERATE_START, + Metrics.LATENCY_GENERATE_BREAKDOWN, Metrics.LATENCY_DECODE, Metrics.LATENCY_E2E, ) } + + breakdown=all_metrics.pop(Metrics.LATENCY_GENERATE_BREAKDOWN, []) + 
mean_metrics = {key: np.mean(value).item() for key, value in all_metrics.items() if len(value) > 0} throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_E2E] model_throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_MODEL] + if len(breakdown) > 0: + mean_metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = { + str(key): np.mean([values[key] for values in breakdown]).item() for key in breakdown[0]} + return { **self.global_metrics, **mean_metrics, diff --git a/src/utils.py b/src/utils.py index 33cfe32..4fbf6d5 100644 --- a/src/utils.py +++ b/src/utils.py @@ -82,9 +82,13 @@ def log_rank_n(msg: str, logger: Callable = logging.info, rank: int = 0): logger(line) -def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0): +def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0, _prefix=""): for key, value in data.items(): - log_rank_n(f"{key}: {value}", logger, rank) + if isinstance(value, dict): + log_rank_n(f"{_prefix}{key}:", logger, rank) + log_dict(value, logger, rank, _prefix+" ") + else: + log_rank_n(f"{_prefix}{key}: {value}", logger, rank) dummy_input_sentences = [ diff --git a/transformers b/transformers index 9c3c548..10f4a98 160000 --- a/transformers +++ b/transformers @@ -1 +1 @@ -Subproject commit 9c3c5484d831484f96e2bcd2961cfac100e52d0b +Subproject commit 10f4a98dbfbbf8e00754267949cf85898e60795a From 75279ff97984a7c80de8f19627dc17f2559080a4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 19 Apr 2023 23:28:41 -0400 Subject: [PATCH 2/3] More options and benchmarking tools --- Dockerfile | 2 +- requirements.txt | 1 + scripts/run_all_benchmark_breakdown.sh | 20 +++ scripts/run_benchmark_breakdown.sh | 56 ++++++++ src/main.py | 46 ++++--- src/metrics.py | 4 +- src/parse_breakdown_results.py | 60 ++++++++ src/pipeline.py | 184 +++++++++++++++---------- src/utils.py | 2 +- transformers | 2 +- 10 files changed, 287 insertions(+), 90 deletions(-) create mode 100755 scripts/run_all_benchmark_breakdown.sh create mode 100755 scripts/run_benchmark_breakdown.sh create mode 100644 src/parse_breakdown_results.py diff --git a/Dockerfile b/Dockerfile index 70739ef..1e35ee0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/pytorch:23.01-py3 +FROM nvcr.io/nvidia/pytorch:23.03-py3 ARG USER=1000 ARG USERNAME=user diff --git a/requirements.txt b/requirements.txt index 68006a1..603fb3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ bitsandbytes safetensors deepspeed==0.7.7 -e ./transformers +flash-attn # TODO: Analysis only py-markdown-table diff --git a/scripts/run_all_benchmark_breakdown.sh b/scripts/run_all_benchmark_breakdown.sh new file mode 100755 index 0000000..d21b0fd --- /dev/null +++ b/scripts/run_all_benchmark_breakdown.sh @@ -0,0 +1,20 @@ + +# Santacoder +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 5 0 +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0 +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 5 0 + +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 11 1 +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1 +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1 + +# Large model +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0 +./scripts/run_benchmark_breakdown.sh large_model 
./data/large-model 8 8190 11 0 +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0 +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 # OOM? + +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1 +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1 +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1 +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 # OOM? diff --git a/scripts/run_benchmark_breakdown.sh b/scripts/run_benchmark_breakdown.sh new file mode 100755 index 0000000..120538b --- /dev/null +++ b/scripts/run_benchmark_breakdown.sh @@ -0,0 +1,56 @@ + +# Santacoder prefill. +# ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0 +# Santacoder decode (fewer data points because slower) +# ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1 +MODEL_NAME=${1:-"santacoder"} +MODEL_PATH=${2:-"bigcode/gpt_bigcode-santacoder"} +BATCH_SIZE=${3:-32} +MAX_NEW_TOKENS=${4:-2040} +# Prime number to see key length padding effect. +TOKEN_STEP=${5:-5} +STEP_ID=${6:-""} + +SAVE_DIR=data/benchmarks/v2 +#BATCH_SIZES="1 2 4 8 16 24 32 48 64 96 128 160 224 256" +RUN="python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda --custom_generate --breakdown_latency --ignore_oom" + + +RUNTIME=("" "pre_allocate_kv_cache=True" "pre_allocate_kv_cache=True inference_runner=3") +RUNTIME_NAMES=("base" "pre_allocate" "graph") + +ATTN_NAME=("jit" "flash" "torch" "torchflash" "torchmem" "torchcpp") + + +STEP=("--no_prefill" "--no_cache") +STEP_NAME=("decode" "prefill") + +COMMON="--pretrained_model=$MODEL_PATH --tokenizer=$MODEL_PATH --cycles=10 --max_input_length=1 --max_new_tokens=$MAX_NEW_TOKENS --key_length_step=$TOKEN_STEP --batch_size=$BATCH_SIZE" + +run () { # run(step, runtime, attn) + FILE_NAME="$SAVE_DIR"/"$MODEL_NAME"_bs_"$BATCH_SIZE"_tok_"$MAX_NEW_TOKENS"_step_"$TOKEN_STEP"_"${STEP_NAME[$1]}"/"${RUNTIME_NAMES[$2]}"_"${ATTN_NAME[$3]}".json + if [ -f "$FILE_NAME" ]; + then + echo "Skipping existing $FILE_NAME" + else + $RUN $COMMON ${STEP[$1]} ${RUNTIME[$2]} "attention_implementation=$3" --save="$FILE_NAME" + fi +} + +if [ "${STEP_ID}" -eq "0" ] +then + # Decode + for runtime in {0..2} + do + for attn in {0..5} + do + run 0 $runtime $attn + done + done +else + # Prefill (all runtimes are the same) + for attn in {0..5} + do + run 1 0 $attn + done +fi diff --git a/src/main.py b/src/main.py index 9111d35..294298e 100644 --- a/src/main.py +++ b/src/main.py @@ -31,14 +31,16 @@ def get_arg_parser() -> ArgumentParser: parser.add_argument("--device", default="cuda", type=torch.device) parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x)) parser.add_argument("--local_rank", type=int) - parser.add_argument("--no_fast_init","--nf", dest="fast_init", action="store_false") - parser.add_argument("--no_cache","--nc", dest="use_cache", action="store_false") - parser.add_argument("--no_prefill","--np", dest="do_prefill", action="store_false") + parser.add_argument("--no_fast_init", "--nf", dest="fast_init", action="store_false") + parser.add_argument("--no_cache", "--nc", dest="use_cache", action="store_false") + parser.add_argument("--no_prefill", "--np", dest="do_prefill", action="store_false") + parser.add_argument("--key_length_step", "--ks", default=1, type=int) + parser.add_argument("--ignore_oom", "--oom", action="store_true") 
# Input and output - parser.add_argument("--batch_size","-b", default=1, type=int) - parser.add_argument("--max_input_length","-i", default=-1, type=int) - parser.add_argument("--max_new_tokens","-g", default=100, type=int) + parser.add_argument("--batch_size", "-b", default=1, type=int) + parser.add_argument("--max_input_length", "-i", default=-1, type=int) + parser.add_argument("--max_new_tokens", "-g", default=100, type=int) # Cleanup parser.add_argument("--clear_every_run", action="store_true") @@ -50,11 +52,11 @@ def get_arg_parser() -> ArgumentParser: # Profiling and logging parser.add_argument("--max_log_outputs", type=int) - parser.add_argument("--breakdown_latency","--bl", action="store_true") - parser.add_argument("--profile","-p", action="store_true") - parser.add_argument("--profile_cycles","--pc", type=int) - parser.add_argument("--full_trace","--pt", action="store_true") - parser.add_argument("--show_op_names","--pn", action="store_true") + parser.add_argument("--breakdown_latency", "--bl", action="store_true") + parser.add_argument("--profile", "-p", action="store_true") + parser.add_argument("--profile_cycles", "--pc", type=int) + parser.add_argument("--full_trace", "--pt", action="store_true") + parser.add_argument("--show_op_names", "--pn", action="store_true") parser.add_argument("--save", type=Path) return parser @@ -91,10 +93,6 @@ def main(argv: Optional[List[str]] = None) -> None: dtype=args.dtype, fast_init=args.fast_init, trust_remote_code=args.trust_remote_code, - custom_generate=args.custom_generate, - use_cache=args.use_cache, - do_prefill=args.do_prefill, - breakdown_latency=args.breakdown_latency, ) all_metrics = [] @@ -128,10 +126,26 @@ def main(argv: Optional[List[str]] = None) -> None: t1 = time.perf_counter() with profiler as p: for step in range(args.skip + warmup + args.cycles): + log_rank_n( + ( + f"*** Running generation step {step} " + f"({'skip' if step str: return f"{1000 * t:.2f} ms" -def format_ms_dict(t_dict: Dict[str,float]) -> Dict[str,str]: - return {key:format_ms(value) for key, value in t_dict.items()} +def format_ms_dict(t_dict: Dict[str, float]) -> Dict[str, str]: + return {key: format_ms(value) for key, value in t_dict.items()} def format_mib(m: float) -> str: diff --git a/src/parse_breakdown_results.py b/src/parse_breakdown_results.py new file mode 100644 index 0000000..82c3bc8 --- /dev/null +++ b/src/parse_breakdown_results.py @@ -0,0 +1,60 @@ +import json +from argparse import ArgumentParser +from pathlib import Path +from typing import List, Optional + + +def get_arg_parser() -> ArgumentParser: + parser = ArgumentParser() + parser.add_argument("input_dir", type=Path) + parser.add_argument("--title") + return parser + + +def read_data(input_file: Path): + try: + with input_file.open("r") as f: + data = json.load(f) + data = {**data["config"], **data["results"]} + except (ValueError, OSError) as e: + raise ValueError(f"Cannot parse file {input_file} ({e})") + data["Setting"] = input_file.stem + return data + + +def plot(data, title=None): + import matplotlib.pyplot as plt + + fig = plt.figure() + ax = fig.add_subplot() + + for dat in data: + latency_data = dat["Latency (generate breakdown)"] + ax.plot( + [int(k) for k in latency_data.keys()], + [v * 1000 for v in latency_data.values()], + label=dat["Setting"], + linewidth=1, + ) # , linestyle=":")#, markersize=1, marker="o") + + ax.set_title(title) + ax.set_xlabel("Sequence length") + ax.set_ylabel("Latency (ms)") + ax.legend() + fig.show() + input("Press enter to continue") + + +def 
main(argv: Optional[List[str]] = None) -> None: + parser = get_arg_parser() + args = parser.parse_args(argv) + data = [read_data(input_file) for input_file in args.input_dir.iterdir()] + + if len(data) == 0: + raise RuntimeError(f"No data to show.") + + plot(data, args.title) + + +if __name__ == "__main__": + main() diff --git a/src/pipeline.py b/src/pipeline.py index 8df6cbc..0242a3d 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -18,7 +18,7 @@ AutoTokenizer, PretrainedConfig, PreTrainedModel, - GPTBigCodeConfig,GPTBigCodeForCausalLM + GPTBigCodeConfig, ) @@ -38,39 +38,27 @@ def __init__( dtype: torch.dtype, fast_init: bool = True, trust_remote_code: bool = False, - custom_generate:bool=False, - use_cache: bool = True, - do_prefill: bool = True, - breakdown_latency=False, ): self.global_metrics = {} log_rank_n("*** Setting up tokenizer", logger.info) t0 = self._get_time() self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, padding_side="left") if self.tokenizer.pad_token is None: - self.tokenizer.pad_token=self.tokenizer.eos_token + self.tokenizer.pad_token = self.tokenizer.eos_token t1 = self._get_time() self.device = device - if self.device==torch.device("cuda"): - self.device=torch.device("cuda:0") + if self.device == torch.device("cuda"): + self.device = torch.device("cuda:0") self.dtype = dtype self.is_int8 = self.dtype == torch.int8 self.fast_init = fast_init self.trust_remote_code = trust_remote_code - self.use_cache = use_cache - self.do_prefill = do_prefill - if not self.do_prefill: - assert custom_generate - assert self.use_cache - self.breakdown_latency=breakdown_latency if self.is_int8 and self.device != torch.device("cuda:0"): raise ValueError(f"Model quantization not supported on device {self.device}") - self._generate=self._generate_custom if custom_generate else self._generate_hf - self.config = self._get_config(model_type, pretrained_config or pretrained_model, config_args) t2 = self._get_time() @@ -193,12 +181,49 @@ def _get_time(self, synchronize=False): torch.cuda.synchronize() return time.perf_counter() - def _generate_custom(self, inputs:Dict, max_new_tokens:int): - t0 = self._get_time(self.breakdown_latency) + def _allocate_mock_cache(self, past_key_length: int, batch_size: int): + if isinstance(self.config, GPTBigCodeConfig): + if self.config.pre_allocate_kv_cache: + past_key_values = [past_key_length] * self.config.n_layer + for block in self.model.transformer.h: + block.attn.get_kv_cache( + batch_size, past_key_length, dtype=self.dtype, device=self.device + ).normal_() + else: + kv_dim = self.config.n_embd // self.config.n_head if self.config.multi_query else self.config.n_embd + past_key_values = [ + torch.randn([batch_size, past_key_length, 2 * kv_dim], dtype=self.dtype, device=self.device) + for _ in range(self.config.n_layer) + ] + else: + past_key_values = [ + [ + torch.randn( + [batch_size, past_key_length, self.config.n_embd], dtype=self.dtype, device=self.device + ) + for _ in range(2) + ] + for _ in range(self.config.n_layer) + ] + return past_key_values + + def _generate_custom( + self, + inputs: Dict, + max_new_tokens: int, + use_cache: bool = True, + do_prefill: bool = True, + breakdown_latency: bool = False, + key_length_step: int = 1, + ignore_oom: bool = False, + ): + t0 = self._get_time(breakdown_latency) batch_size, input_length = inputs["input_ids"].shape output_length = input_length + max_new_tokens input_ids = torch.empty([batch_size, output_length], dtype=torch.int64, device=self.device) input_ids[:, 
:input_length].copy_(inputs["input_ids"]) + if key_length_step > 1: + input_ids[:, input_length:].fill_(self.tokenizer.pad_token_id) attention_mask = torch.empty([batch_size, output_length], dtype=torch.bool, device=self.device) attention_mask[:, :input_length].copy_(inputs["attention_mask"]) @@ -208,54 +233,53 @@ def _generate_custom(self, inputs:Dict, max_new_tokens:int): # TODO: Useless? position_ids[:, :input_length].masked_fill_(attention_mask[:, :input_length] == 0, 1) - if self.do_prefill or input_length<=1: - past_key_values=None - past_key_length=0 - else: - # Generate mock `past_key_values` - past_key_length=input_length-1 - if isinstance(self.config, GPTBigCodeConfig): - if self.config.pre_allocate_kv_cache: - past_key_values=[past_key_length]*self.config.n_layer - for block in self.model.transformer.h: - block.attn.get_kv_cache(batch_size, past_key_length, dtype=self.dtype, device=self.device).normal_() + t1 = self._get_time(breakdown_latency) + last_time = t1 + past_key_length = 0 + past_key_values = None + generate_times = {} + for key_length in range(input_length, output_length, key_length_step): + try: + if ( + use_cache + and (past_key_values is None and not do_prefill) + or (past_key_values is not None and key_length_step > 1) + ): + past_key_length = key_length - 1 + past_key_values = self._allocate_mock_cache(past_key_length, batch_size) + # Exclude cache creation from timing + last_time = self._get_time(breakdown_latency) + outputs = self.model( + input_ids=input_ids[:, past_key_length:key_length], + past_key_values=past_key_values, + attention_mask=attention_mask[:, :key_length], + position_ids=position_ids[:, past_key_length:key_length], + return_dict=True, + use_cache=use_cache, + ) + if use_cache: + past_key_values = outputs.past_key_values + past_key_length = key_length + next_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1) + input_ids[:, key_length] = next_tokens + t2 = self._get_time(breakdown_latency) + generate_times[key_length] = t2 - last_time + last_time = t2 + except torch.cuda.OutOfMemoryError: + if ignore_oom: + logger.warning(f"Out of memory at key length {key_length}") + break else: - kv_dim=self.config.n_embd // self.config.n_head if self.config.multi_query else self.config.n_embd - past_key_values=[torch.randn([batch_size, past_key_length, 2*kv_dim], dtype=self.dtype, device=self.device) for _ in range(self.config.n_layer)] - else: - past_key_values = [ - [torch.randn([batch_size, past_key_length, self.config.n_embd], dtype=self.dtype, device=self.device) for _ in range(2)] for _ in - range(self.config.n_layer)] - - t1 = self._get_time(self.breakdown_latency) - last_time=t1 - generate_times={} - for key_length in range(input_length, output_length): - outputs = self.model( - input_ids=input_ids[:, past_key_length:key_length], - past_key_values=past_key_values, - attention_mask=attention_mask[:, :key_length], - position_ids=position_ids[:, past_key_length:key_length], - return_dict=True, - use_cache=self.use_cache, - ) - if self.use_cache: - past_key_values=outputs.past_key_values - past_key_length=key_length - next_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1) - input_ids[:, key_length] = next_tokens - t2 = self._get_time(self.breakdown_latency) - generate_times[key_length]=t2-last_time - last_time=t2 - - metrics={} - if self.breakdown_latency: - metrics[Metrics.LATENCY_GENERATE_START]=t1-t0 - metrics[Metrics.LATENCY_GENERATE_BREAKDOWN]=generate_times + raise + + metrics = {} + if breakdown_latency: + 
metrics[Metrics.LATENCY_GENERATE_START] = t1 - t0 + metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = generate_times return input_ids, metrics - def _generate_hf(self, inputs:Dict, max_new_tokens:int): + def _generate_hf(self, inputs: Dict, max_new_tokens: int, use_cache: bool): inputs = {key: value.to(self.device) if torch.is_tensor(value) else value for key, value in inputs.items()} output = self.model.generate( **inputs, @@ -263,18 +287,38 @@ def _generate_hf(self, inputs:Dict, max_new_tokens:int): max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.pad_token_id, - use_cache=self.use_cache, + use_cache=use_cache, ) - return output.sequences, {} - + return output.sequences - def __call__(self, text: List[str], max_new_tokens:int) -> Tuple[List[str], Dict[str, Any]]: + def __call__( + self, + text: List[str], + max_new_tokens: int, + custom_generate: bool = False, + use_cache: bool = True, + do_prefill: bool = True, + breakdown_latency=False, + key_length_step: int = 1, + ignore_oom: bool = False, + ) -> Tuple[List[str], Dict[str, Any]]: t0 = self._get_time() inputs = self.tokenizer(text, return_tensors="pt", padding=True) t1 = self._get_time() with torch.inference_mode(): - output_tokens, generate_metrics = self._generate(inputs, max_new_tokens) + if custom_generate: + assert do_prefill or use_cache + output_tokens, generate_metrics = self._generate_custom( + inputs, max_new_tokens, use_cache, do_prefill, breakdown_latency, key_length_step, ignore_oom + ) + else: + assert do_prefill + assert not breakdown_latency + assert not ignore_oom + assert key_length_step == 1 + output_tokens = self._generate_hf(inputs, max_new_tokens, use_cache) + generate_metrics = {} t2 = self._get_time(True) batch_size, input_length = inputs["input_ids"].shape @@ -319,7 +363,7 @@ def aggregate_metrics(self, metrics: List[Dict[str, Any]]): ) } - breakdown=all_metrics.pop(Metrics.LATENCY_GENERATE_BREAKDOWN, []) + breakdown = all_metrics.pop(Metrics.LATENCY_GENERATE_BREAKDOWN, []) mean_metrics = {key: np.mean(value).item() for key, value in all_metrics.items() if len(value) > 0} throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_E2E] @@ -327,7 +371,9 @@ def aggregate_metrics(self, metrics: List[Dict[str, Any]]): if len(breakdown) > 0: mean_metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = { - str(key): np.mean([values[key] for values in breakdown]).item() for key in breakdown[0]} + str(key): np.mean([values[key] for values in breakdown if key in values]).item() + for key in breakdown[0] + } return { **self.global_metrics, diff --git a/src/utils.py b/src/utils.py index 4fbf6d5..6331c56 100644 --- a/src/utils.py +++ b/src/utils.py @@ -86,7 +86,7 @@ def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0, _prefix for key, value in data.items(): if isinstance(value, dict): log_rank_n(f"{_prefix}{key}:", logger, rank) - log_dict(value, logger, rank, _prefix+" ") + log_dict(value, logger, rank, _prefix + " ") else: log_rank_n(f"{_prefix}{key}: {value}", logger, rank) diff --git a/transformers b/transformers index 10f4a98..d08ce17 160000 --- a/transformers +++ b/transformers @@ -1 +1 @@ -Subproject commit 10f4a98dbfbbf8e00754267949cf85898e60795a +Subproject commit d08ce174360daf9d49425859882720b4c50b6813 From fc5033e99cf5376b9365a0603d352bb4cc13b243 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 24 Apr 2023 10:49:35 -0400 Subject: [PATCH 3/3] Better inputs, experiments --- requirements.txt | 1 + scripts/run_all_benchmark_breakdown.sh | 28 +++---- 
scripts/run_benchmark_breakdown.sh | 40 ++++++--- src/main.py | 16 +++- src/parse_breakdown_results.py | 33 ++++++-- src/pipeline.py | 20 ++++- src/utils.py | 110 ++++++++++++++++++++++--- transformers | 2 +- 8 files changed, 204 insertions(+), 46 deletions(-) diff --git a/requirements.txt b/requirements.txt index 603fb3a..5144b94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ safetensors deepspeed==0.7.7 -e ./transformers flash-attn +einops # TODO: Analysis only py-markdown-table diff --git a/scripts/run_all_benchmark_breakdown.sh b/scripts/run_all_benchmark_breakdown.sh index d21b0fd..818ddc4 100755 --- a/scripts/run_all_benchmark_breakdown.sh +++ b/scripts/run_all_benchmark_breakdown.sh @@ -1,20 +1,20 @@ # Santacoder -./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 5 0 -./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0 -./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 5 0 +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 5 0 v2_ +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0 v2_ +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 5 0 v2_ -./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 11 1 -./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1 -./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1 +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 11 1 v2_ +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1 v2_ +./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1 v2_ # Large model -./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0 -./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0 -./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0 -./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 # OOM? +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0 v2_ +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0 v2_ +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0 v2_ +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 v2_# OOM? -./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1 -./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1 -./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1 -./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 # OOM? +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1 v2_ 1 +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1 v2_ 1 +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1 v2_ 1 +./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 v2_ 1 # OOM? 
diff --git a/scripts/run_benchmark_breakdown.sh b/scripts/run_benchmark_breakdown.sh index 120538b..5781a13 100755 --- a/scripts/run_benchmark_breakdown.sh +++ b/scripts/run_benchmark_breakdown.sh @@ -10,6 +10,8 @@ MAX_NEW_TOKENS=${4:-2040} # Prime number to see key length padding effect. TOKEN_STEP=${5:-5} STEP_ID=${6:-""} +FILE_PREFIX=${7:-""} +CYCLES=${8:-10} SAVE_DIR=data/benchmarks/v2 #BATCH_SIZES="1 2 4 8 16 24 32 48 64 96 128 160 224 256" @@ -19,37 +21,55 @@ RUN="python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda --cu RUNTIME=("" "pre_allocate_kv_cache=True" "pre_allocate_kv_cache=True inference_runner=3") RUNTIME_NAMES=("base" "pre_allocate" "graph") -ATTN_NAME=("jit" "flash" "torch" "torchflash" "torchmem" "torchcpp") +ATTN=( \ + "attention_implementation=0" \ + "attention_implementation=1" \ + "attention_implementation=1 --pad_generated_tokens=0.5" \ + "attention_implementation=2" \ + "attention_implementation=0 fused_softmax=False" \ + "attention_implementation=0 fused_softmax=True" \ + "attention_implementation=3" \ + "attention_implementation=4" \ + "attention_implementation=5" \ + ) +ATTN_NAME=( \ + "default" \ + "flash" \ + "flash_unpad_50" \ + "torch" \ + "no_jit" \ + "jit" \ + "torchflash" \ + "torchmem" \ + "torchcpp" \ + ) STEP=("--no_prefill" "--no_cache") STEP_NAME=("decode" "prefill") -COMMON="--pretrained_model=$MODEL_PATH --tokenizer=$MODEL_PATH --cycles=10 --max_input_length=1 --max_new_tokens=$MAX_NEW_TOKENS --key_length_step=$TOKEN_STEP --batch_size=$BATCH_SIZE" +COMMON="--pretrained_model=$MODEL_PATH --tokenizer=$MODEL_PATH --cycles=$CYCLES --max_input_length=1 --max_new_tokens=$MAX_NEW_TOKENS --key_length_step=$TOKEN_STEP --batch_size=$BATCH_SIZE predict_last_token=True" run () { # run(step, runtime, attn) - FILE_NAME="$SAVE_DIR"/"$MODEL_NAME"_bs_"$BATCH_SIZE"_tok_"$MAX_NEW_TOKENS"_step_"$TOKEN_STEP"_"${STEP_NAME[$1]}"/"${RUNTIME_NAMES[$2]}"_"${ATTN_NAME[$3]}".json + FILE_NAME="$SAVE_DIR"/"$MODEL_NAME"_bs_"$BATCH_SIZE"_tok_"$MAX_NEW_TOKENS"_step_"$TOKEN_STEP"_"${STEP_NAME[$1]}"/"$FILE_PREFIX""${RUNTIME_NAMES[$2]}"_"${ATTN_NAME[$3]}".json if [ -f "$FILE_NAME" ]; then echo "Skipping existing $FILE_NAME" else - $RUN $COMMON ${STEP[$1]} ${RUNTIME[$2]} "attention_implementation=$3" --save="$FILE_NAME" + $RUN $COMMON ${RUNTIME[$2]} ${ATTN[$3]} ${STEP[$1]} --save="$FILE_NAME" fi } if [ "${STEP_ID}" -eq "0" ] then - # Decode + # Decode (default attn only) for runtime in {0..2} do - for attn in {0..5} - do - run 0 $runtime $attn - done + run 0 $runtime 0 done else # Prefill (all runtimes are the same) - for attn in {0..5} + for attn in {0..2} do run 1 0 $attn done diff --git a/src/main.py b/src/main.py index 294298e..e42b929 100644 --- a/src/main.py +++ b/src/main.py @@ -11,7 +11,7 @@ from src.metrics import Metrics from src.pipeline import Pipeline, get_pipeline_class from src.profile import get_profiler, logger -from src.utils import configure_logging, get_dummy_batch, log_dict, log_rank_n, parse_config_args +from src.utils import configure_logging, get_input_batch, log_dict, log_rank_n, parse_config_args def get_arg_parser() -> ArgumentParser: @@ -40,6 +40,10 @@ def get_arg_parser() -> ArgumentParser: # Input and output parser.add_argument("--batch_size", "-b", default=1, type=int) parser.add_argument("--max_input_length", "-i", default=-1, type=int) + parser.add_argument("--sample_dir", "-d") + parser.add_argument("--input_pad_ratio", "--pad", default=0, type=float) + parser.add_argument("--pad_generated_tokens", "--pad_g", default=0, type=float) 
+ parser.add_argument("--input_seed", "--seed", default=0, type=int) parser.add_argument("--max_new_tokens", "-g", default=100, type=int) # Cleanup @@ -67,7 +71,6 @@ def main(argv: Optional[List[str]] = None) -> None: parser = get_arg_parser() args = parser.parse_args(argv) config_args = parse_config_args(args.config_args) - inputs = get_dummy_batch(args.batch_size, args.max_input_length) separate_profile = args.profile and args.profile_cycles is not None warmup = args.profile if args.warmup is None else args.warmup if separate_profile: @@ -94,6 +97,14 @@ def main(argv: Optional[List[str]] = None) -> None: fast_init=args.fast_init, trust_remote_code=args.trust_remote_code, ) + inputs = get_input_batch( + args.batch_size, + args.max_input_length, + pipeline.tokenizer, + args.input_pad_ratio, + args.input_seed, + args.sample_dir, + ) all_metrics = [] @@ -145,6 +156,7 @@ def main(argv: Optional[List[str]] = None) -> None: breakdown_latency=args.breakdown_latency, key_length_step=args.key_length_step, ignore_oom=args.ignore_oom, + pad_generated_tokens=args.pad_generated_tokens, ) if args.profile: p.step() diff --git a/src/parse_breakdown_results.py b/src/parse_breakdown_results.py index 82c3bc8..4c281cf 100644 --- a/src/parse_breakdown_results.py +++ b/src/parse_breakdown_results.py @@ -8,6 +8,8 @@ def get_arg_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument("input_dir", type=Path) parser.add_argument("--title") + parser.add_argument("--size", nargs=2, type=float) + parser.add_argument("--save_dir", "--save", type=Path) return parser @@ -22,27 +24,30 @@ def read_data(input_file: Path): return data -def plot(data, title=None): +def plot(data, title=None, size=None): import matplotlib.pyplot as plt - fig = plt.figure() + fig = plt.figure(figsize=size) ax = fig.add_subplot() - for dat in data: + cmap = plt.get_cmap("tab20").colors + cmap = cmap[::2] + cmap[1::2] + + for i, dat in enumerate(data): latency_data = dat["Latency (generate breakdown)"] ax.plot( [int(k) for k in latency_data.keys()], [v * 1000 for v in latency_data.values()], label=dat["Setting"], linewidth=1, + color=cmap[i], ) # , linestyle=":")#, markersize=1, marker="o") ax.set_title(title) ax.set_xlabel("Sequence length") ax.set_ylabel("Latency (ms)") ax.legend() - fig.show() - input("Press enter to continue") + return fig def main(argv: Optional[List[str]] = None) -> None: @@ -53,7 +58,23 @@ def main(argv: Optional[List[str]] = None) -> None: if len(data) == 0: raise RuntimeError(f"No data to show.") - plot(data, args.title) + title = args.title + dirname = args.input_dir.stem + if title is None: + try: + name, _, bs, _, _, _, _, step = dirname.rsplit("_", 7) + title = f"{name} {step}, bs = {bs}" + except ValueError: + title = dirname + + fig = plot(data, title, args.size) + fig.show() + if args.save_dir: + save_path = (args.save_dir / dirname).with_suffix(".jpg") + fig.savefig(save_path) + print(f"Figure saved to {save_path}") + + input("Press enter to continue") if __name__ == "__main__": diff --git a/src/pipeline.py b/src/pipeline.py index 0242a3d..03f8c0d 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -216,6 +216,7 @@ def _generate_custom( breakdown_latency: bool = False, key_length_step: int = 1, ignore_oom: bool = False, + pad_generated_tokens: float = 0, ): t0 = self._get_time(breakdown_latency) batch_size, input_length = inputs["input_ids"].shape @@ -227,7 +228,13 @@ def _generate_custom( attention_mask = torch.empty([batch_size, output_length], dtype=torch.bool, device=self.device) 
attention_mask[:, :input_length].copy_(inputs["attention_mask"]) - attention_mask[:, input_length:].fill_(True) + if pad_generated_tokens > 0: + attention_mask[:, input_length:].copy_( + torch.empty_like(attention_mask[:, input_length:], dtype=torch.float32).uniform_() + > pad_generated_tokens + ) + else: + attention_mask[:, input_length:].fill_(True) position_ids = attention_mask.long().cumsum(-1, dtype=torch.int64) - 1 # TODO: Useless? @@ -301,6 +308,7 @@ def __call__( breakdown_latency=False, key_length_step: int = 1, ignore_oom: bool = False, + pad_generated_tokens: float = 0, ) -> Tuple[List[str], Dict[str, Any]]: t0 = self._get_time() inputs = self.tokenizer(text, return_tensors="pt", padding=True) @@ -310,13 +318,21 @@ def __call__( if custom_generate: assert do_prefill or use_cache output_tokens, generate_metrics = self._generate_custom( - inputs, max_new_tokens, use_cache, do_prefill, breakdown_latency, key_length_step, ignore_oom + inputs, + max_new_tokens, + use_cache, + do_prefill, + breakdown_latency, + key_length_step, + ignore_oom, + pad_generated_tokens, ) else: assert do_prefill assert not breakdown_latency assert not ignore_oom assert key_length_step == 1 + assert pad_generated_tokens == 0 output_tokens = self._generate_hf(inputs, max_new_tokens, use_cache) generate_metrics = {} t2 = self._get_time(True) diff --git a/src/utils.py b/src/utils.py index 6331c56..9abc913 100644 --- a/src/utils.py +++ b/src/utils.py @@ -3,8 +3,10 @@ import logging.config import math import typing -from typing import Any, Callable, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union +import numpy as np from torch import distributed as dist @@ -91,7 +93,7 @@ def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0, _prefix log_rank_n(f"{_prefix}{key}: {value}", logger, rank) -dummy_input_sentences = [ +dummy_inputs = [ "DeepSpeed is a machine learning framework", "He is working on", "He has a", @@ -103,14 +105,100 @@ def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0, _prefix ] -def get_dummy_batch(batch_size: int, max_input_length: int = -1) -> List[str]: +def get_input_lengths(batch_size, max_input_length, padding_ratio, random_state): + """ + Generate a random set of input lengths with the desired padding ratio and at least one of the specified max length. 
+    """
+    if padding_ratio == 0:
+        return batch_size * [max_input_length]
+    assert batch_size >= 2
+    total_tokens = batch_size * max_input_length
+    pad_tokens = round(padding_ratio * total_tokens)
+    input_tokens = total_tokens - pad_tokens
+    # First length is deterministic
+    required_tokens = input_tokens - max_input_length
+    average_length = required_tokens / (batch_size - 1)
+    smin = 1
+    smax = round(2 * average_length - smin)
+    if smax > max_input_length:
+        smax = max_input_length
+        smin = round(2 * average_length - smax)
+    assert smax >= smin >= 1, "Cannot obtain desired padding ratio"
+    print("AA", batch_size, max_input_length, padding_ratio, smin, smax)
+    assert abs(smax + smin - 2 * average_length) < 1
+    for i in range(100):
+        lengths = random_state.randint(smin, smax, batch_size - 2)
+        remaining = required_tokens - lengths.sum()
+        if 1 <= remaining <= max_input_length:
+            lengths = [max_input_length, *lengths.tolist(), remaining]
+            random_state.shuffle(lengths)
+            assert sum(lengths) == input_tokens
+            return lengths
+    raise RuntimeError("Failed to get desired padding ratio")
+
+
+def get_inputs_from_tokens(tokens, length, tokenizer):
+    for _ in range(10):
+        assert len(tokens) == length
+        inputs = tokenizer.decode(tokens)
+        # We often get more tokens than we started with, less in some rare cases.
+        tokens = tokenizer(inputs)["input_ids"]
+        if len(tokens) == length:
+            return inputs
+        tokens = tokens[:length] + max(length - len(tokens), 0) * [tokens[-1]]
+    raise RuntimeError("Failed to generate stable input sequences")
+
+
+def get_random_inputs(length, tokenizer, random_state):
+    return get_inputs_from_tokens(random_state.randint(0, tokenizer.vocab_size, length).tolist(), length, tokenizer)
+
+
+def get_inputs_from_files(files: List[Path], lengths, tokenizer, random_state):
+    file_tokens = [tokenizer(f.open().read())["input_ids"] for f in files]
+    max_len = max(len(t) for t in file_tokens)
+    batch_size = len(lengths)
+    inputs = []
+    while len(inputs) < batch_size:
+        length = lengths[len(inputs)]
+        if length > max_len:
+            # No file works, pick at random instead.
+            inputs.append(get_random_inputs(length, tokenizer, random_state))
+        else:
+            tokens = file_tokens[random_state.randint(len(file_tokens))]
+            if length > len(tokens):
+                # Try another file.
+                continue
+            start_index = random_state.randint(len(tokens) - length)
+            inputs.append(get_inputs_from_tokens(tokens[start_index : start_index + length], length, tokenizer))
+    return inputs
+
+
+def get_input_batch(
+    batch_size: int,
+    max_input_length: int = -1,
+    tokenizer=None,
+    padding_ratio: float = 0,
+    seed: int = 0,
+    sample_dir: Optional[Union[Path, List[Path]]] = None,
+) -> List[str]:
     if max_input_length == -1:
-        input_sentences = copy.deepcopy(dummy_input_sentences)
+        inputs = copy.deepcopy(dummy_inputs)
+        if batch_size > len(inputs):
+            inputs *= math.ceil(batch_size / len(inputs))
+        return inputs[:batch_size]
     else:
-        input_sentences = batch_size * [" Hello" * max_input_length]
-
-    if batch_size > len(input_sentences):
-        input_sentences *= math.ceil(batch_size / len(input_sentences))
-    input_sentences = input_sentences[:batch_size]
-
-    return input_sentences
+        random_state = np.random.RandomState(seed)
+        lengths = get_input_lengths(batch_size, max_input_length, padding_ratio, random_state)
+        if isinstance(sample_dir, Path):
+            if sample_dir.is_dir():
+                sample_dir = [f for f in sample_dir.iterdir() if f.is_file() and f.suffix == ".py"]
+            elif sample_dir.is_file():
+                sample_dir = [sample_dir]
+            else:
+                raise FileNotFoundError(sample_dir)
+        if sample_dir is None:
+            return get_random_inputs(lengths, tokenizer, random_state)
+        else:
+            assert isinstance(sample_dir, List)
+            assert len(sample_dir) > 0
+            return get_inputs_from_files(sample_dir, lengths, tokenizer, random_state)
diff --git a/transformers b/transformers
index d08ce17..a2efad2 160000
--- a/transformers
+++ b/transformers
@@ -1 +1 @@
-Subproject commit d08ce174360daf9d49425859882720b4c50b6813
+Subproject commit a2efad2c96e6da982f102eea53918c7b8431da80