Patch llama.py for internal build (#7239)

Summary:

As title: minor changes so we can build with Buck.

```
buck run mode/dev-nosan //executorch/examples/qualcomm/oss_scripts/llama3_2:llama_qnn -- --compile_only --ptq 16a4w --checkpoint /home/chenlai/local/models/consolidated.00.pth --params /home/chenlai/local/models/params.json --tokenizer_model /home/chenlai/local/models/tokenizer.model --prompt "Once" -m SM8650 --model_size 1B --model_mode kv  2>&1 | tee static_llama.log
```

Differential Revision: D66947240
cccclai authored Dec 10, 2024
1 parent f22d1a3 commit 5161d70
Showing 2 changed files with 17 additions and 7 deletions.
examples/qualcomm/oss_scripts/llama2/llama.py (8 additions, 3 deletions)

```diff
@@ -492,6 +492,7 @@ def inference(args, pre_gen_pte=""):
             f"model_mode {args.model_mode} is not implemented yet."
         )
 
+    assert args.tokenizer_bin is not None, "Need tokenizer model for inference"
     runner_args = " ".join(
         [
             f"--model_path {pte_filename}.pte",
@@ -562,8 +563,7 @@ def post_process():
         print(f"Results[{idx}]:\n{output}")
 
 
-# flake8: noqa: C901
-if __name__ == "__main__":
+def main():
     parser = setup_common_args_and_variables()
     parser.add_argument(
         "-a",
@@ -597,7 +597,7 @@ def post_process():
     parser.add_argument(
         "--tokenizer_bin",
         help="Pass llama2 tokenizer binary.",
-        required=True,
+        required=False,
         type=str,
     )
 
@@ -680,3 +680,8 @@ def post_process():
                 conn.send(json.dumps({"Error": str(e)}))
         else:
             raise Exception(e)
+
+
+# flake8: noqa: C901
+if __name__ == "__main__":
+    main()
```
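Taken together, the llama2 changes relax the tokenizer requirement and make the module importable without side effects. Below is a minimal, self-contained sketch of that pattern, assuming simplified control flow (the real script builds its parser via setup_common_args_and_variables and supports several modes; only the argument names and the assert mirror the diff):

```python
import argparse


def inference(args):
    # The check moves here from argparse: the tokenizer is only needed once
    # inference actually runs on the compiled artifacts.
    assert args.tokenizer_bin is not None, "Need tokenizer model for inference"


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--compile_only", action="store_true")
    # No longer required at parse time, so --compile_only runs can omit it.
    parser.add_argument("--tokenizer_bin", required=False, type=str)
    args = parser.parse_args()
    if not args.compile_only:
        inference(args)


# Wrapping the CLI in main() keeps `import llama` free of side effects,
# which is what a Buck-built target wants.
if __name__ == "__main__":
    main()
```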
examples/qualcomm/oss_scripts/llama3_2/llama.py (9 additions, 4 deletions)

```diff
@@ -218,7 +218,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
             ] + [n.target]:
                 n.meta[QCOM_QUANTIZED_IO] = sharding_type
 
-    def quantize(self, quant_dtype, custom_annotations=()):
+    def quantize(self, quant_dtype, args, custom_annotations=()):
         self.quant_dtype = quant_dtype
         quantizer = make_quantizer(
             quant_dtype=quant_dtype,
@@ -386,7 +386,8 @@ def compile(args):
     if args.ptq != None:
         start_quantize_ts = time.time()
         single_llama.quantize(
-            quant_dtype,
+            quant_dtype=quant_dtype,
+            args=args,
             custom_annotations=(
                 custom_annotate_llama_last_conv_16a8w,
                 matmul_annotate_func,
@@ -486,8 +487,7 @@ def post_process():
         logging.info(f"Results[{idx}]:\n{output}")
 
 
-# flake8: noqa: C901
-if __name__ == "__main__":
+def main():
     parser = setup_common_args_and_variables()
     parser.add_argument(
         "-a",
@@ -605,3 +605,8 @@ def post_process():
                 conn.send(json.dumps({"Error": str(e)}))
         else:
             raise Exception(e)
+
+
+# flake8: noqa: C901
+if __name__ == "__main__":
+    main()
```
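On the llama3_2 side, quantize now also receives the parsed CLI args, and the call site switches to keyword arguments since args sits between quant_dtype and custom_annotations in the new signature. A runnable sketch with a stand-in class (only quantize's signature mirrors the diff; the body and values are illustrative):

```python
class SingleLlama:
    """Stand-in for the real model wrapper; only quantize's signature matches."""

    def quantize(self, quant_dtype, args, custom_annotations=()):
        # Threading `args` through gives quantization access to the CLI
        # options without reaching for module-level state.
        print(f"quantizing with {quant_dtype}, {len(custom_annotations)} annotations")


single_llama = SingleLlama()

# Keyword arguments keep the call unambiguous: a positional
# quantize("16a4w", annotations) would now bind annotations to `args`.
single_llama.quantize(
    quant_dtype="16a4w",
    args={"checkpoint": "consolidated.00.pth"},  # illustrative stand-in for parsed args
    custom_annotations=(),
)
```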
