From 5161d70ac64ae511c68004472c5b5c06cd410efa Mon Sep 17 00:00:00 2001
From: cccclai
Date: Mon, 9 Dec 2024 16:47:13 -0800
Subject: [PATCH] Patch llama.py for internal build (#7239)

Patch llama.py for internal build (#7239)

Summary:
As title, minor changes so we can use buck to build:

```
buck run mode/dev-nosan //executorch/examples/qualcomm/oss_scripts/llama3_2:llama_qnn -- --compile_only --ptq 16a4w --checkpoint /home/chenlai/local/models/consolidated.00.pth --params /home/chenlai/local/models/params.json --tokenizer_model /home/chenlai/local/models/tokenizer.model --prompt "Once" -m SM8650 --model_size 1B --model_mode kv 2>&1 | tee static_llama.log
```

Differential Revision: D66947240
---
 examples/qualcomm/oss_scripts/llama2/llama.py   | 11 ++++++++---
 examples/qualcomm/oss_scripts/llama3_2/llama.py | 13 +++++++++----
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py
index ae291f3659..ebabfc5ca6 100755
--- a/examples/qualcomm/oss_scripts/llama2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama2/llama.py
@@ -492,6 +492,7 @@ def inference(args, pre_gen_pte=""):
             f"model_mode {args.model_mode} is not implemented yet."
         )
 
+    assert args.tokenizer_bin is not None, "Need tokenizer model for inference"
     runner_args = " ".join(
         [
             f"--model_path {pte_filename}.pte",
@@ -562,8 +563,7 @@ def post_process():
             print(f"Results[{idx}]:\n{output}")
 
 
-# flake8: noqa: C901
-if __name__ == "__main__":
+def main():
     parser = setup_common_args_and_variables()
     parser.add_argument(
         "-a",
@@ -597,7 +597,7 @@ def post_process():
     parser.add_argument(
         "--tokenizer_bin",
         help="Pass llama2 tokenizer binary.",
-        required=True,
+        required=False,
         type=str,
     )
 
@@ -680,3 +680,8 @@ def post_process():
             conn.send(json.dumps({"Error": str(e)}))
         else:
             raise Exception(e)
+
+
+# flake8: noqa: C901
+if __name__ == "__main__":
+    main()
diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py
index 75c0bb0ff0..50b43e86fb 100755
--- a/examples/qualcomm/oss_scripts/llama3_2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py
@@ -218,7 +218,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
             ] + [n.target]:
                 n.meta[QCOM_QUANTIZED_IO] = sharding_type
 
-    def quantize(self, quant_dtype, custom_annotations=()):
+    def quantize(self, quant_dtype, args, custom_annotations=()):
         self.quant_dtype = quant_dtype
         quantizer = make_quantizer(
             quant_dtype=quant_dtype,
@@ -386,7 +386,8 @@ def compile(args):
     if args.ptq != None:
         start_quantize_ts = time.time()
         single_llama.quantize(
-            quant_dtype,
+            quant_dtype=quant_dtype,
+            args=args,
             custom_annotations=(
                 custom_annotate_llama_last_conv_16a8w,
                 matmul_annotate_func,
@@ -486,8 +487,7 @@ def post_process():
             logging.info(f"Results[{idx}]:\n{output}")
 
 
-# flake8: noqa: C901
-if __name__ == "__main__":
+def main():
     parser = setup_common_args_and_variables()
     parser.add_argument(
         "-a",
@@ -605,3 +605,8 @@ def post_process():
             conn.send(json.dumps({"Error": str(e)}))
         else:
             raise Exception(e)
+
+
+# flake8: noqa: C901
+if __name__ == "__main__":
+    main()