Dynamic Dispatch for Kernels + Support MKL-based kernels w/ Fallback #122

Draft · wants to merge 3 commits into main

Conversation


balbit commented Dec 5, 2024

Problem

  • All kernel implementations are consolidated in matmul.h under a monolithic MatmulOperator class
    • Contains a messy list of matmul implementations, forcing #ifdef handling in downstream clients
    • Makes individual kernels difficult to migrate
    • Fallback implementations require duplicating all matmul code (currently the case for CUDA, Metal, and MKL)
    • Only one kernel can be compiled at a time
  • The Metal kernel is missing operations, so it currently fails to build (see Metal Changes below)

Changes

Migrate MatmulOperator Structure

  • Change MatmulOperator to a virtual base class
    • Consolidate kernel selection in a factory (see the sketch after this list)
  • Corresponding Makefile modifications
  • Add a subclass for each kernel
  • Migrate callsites (in llm/ops) to use references to the dispatched operator
  • Test fallback behavior
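
A minimal sketch of the intended structure, assuming one virtual method per matmul variant and kernel selection consolidated in a single factory. The subclass names, the factory name CreateMatmulOperator, and the feature macro are illustrative placeholders, not identifiers taken from this PR; matmul_params and the bias matmul method come from the existing matmul.h.

// matmul.h (sketch): MatmulOperator as an abstract interface.
struct matmul_params;  // existing parameter struct

class MatmulOperator {
 public:
    virtual ~MatmulOperator() = default;
    // One virtual entry point per matmul variant; only one shown here.
    virtual void mat_mul_accelerator_transposed_fastover_column_bias(const matmul_params *params) = 0;
};

// One subclass per kernel, e.g. the portable AVX implementation.
class MatmulOperatorAVX : public MatmulOperator {
 public:
    void mat_mul_accelerator_transposed_fastover_column_bias(const matmul_params *params) override;
};

// Factory consolidating the kernel selection that was previously spread
// across callsites as #ifdefs. Macro and class names are placeholders.
MatmulOperator &CreateMatmulOperator() {
#ifdef USE_CUDA_KERNELS            // hypothetical macro; each kernel would get its own branch
    static MatmulOperatorCUDA op;  // hypothetical CUDA subclass, declared elsewhere
#else
    static MatmulOperatorAVX op;   // portable default / fallback
#endif
    return op;
}

Callsites in llm/ops would then obtain a reference once (MatmulOperator &op = CreateMatmulOperator();) and call the virtual methods on it, instead of hard-coding a particular implementation.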

Per-kernel migration

  • AVX
  • MKL with AVX fallback (see the fallback sketch after this list)
  • CUDA
  • Neon
  • Metal
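
A sketch of the MKL-with-AVX-fallback idea, building on the interface sketched above. Inheriting from the AVX subclass is just one possible fallback mechanism, and the names are illustrative rather than the ones used in this PR; cblas_sgemm is the standard MKL/CBLAS single-precision GEMM routine.

#include <mkl.h>  // provides cblas_sgemm

// Hypothetical MKL kernel that overrides only the operations MKL can
// accelerate; everything else resolves to the inherited AVX code, so the
// fallback does not require duplicating any matmul implementations.
class MatmulOperatorMKL : public MatmulOperatorAVX {
 public:
    void mat_mul_accelerator_transposed_fastover_column_bias(const matmul_params *params) override {
        // ... unpack params and call cblas_sgemm(...) here ...
        // (quantized / INT4 paths are not overridden and keep using AVX)
    }
};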

Metal Changes

Metal on the main branch fails to build due to missing operations. This PR:

  • Automatically downloads metal-cpp
  • Fixes the missing operations, which previously caused the link error below (a reference sketch follows the error output):
Undefined symbols for architecture arm64:
  "matmul::MatmulOperator::mat_mul_accelerator_transposed_fastover_column_bias(matmul_params const*)", referenced from:
      Linear_FP::forward(Matrix3D<float> const&, Matrix3D<float>&) in linear.o
ld: symbol(s) not found for architecture arm64
clang++: error: linker command failed with exit code 1 (use -v to see invocation)
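
For reference, one way the missing symbol could be satisfied is with a naive portable loop, so the arm64 target links even without a native Metal kernel. This is purely illustrative: the class name, the field names (A, B, C, bias, row, column, data_ptr), and the assumed semantics (C = A·Bᵀ plus a per-column bias) are guesses about matmul_params, not taken from this PR, whose actual fix may be a proper Metal implementation.

// Hypothetical fallback body for the operation the linker reports as missing.
// MatmulOperatorMetal and all matmul_params field names below are assumptions.
void MatmulOperatorMetal::mat_mul_accelerator_transposed_fastover_column_bias(
        const matmul_params *params) {
    const auto &A = params->A, &B = params->B, &C = params->C;
    for (int i = 0; i < C.row; i++) {
        for (int j = 0; j < C.column; j++) {
            float acc = params->bias.data_ptr[j];  // per-output-column bias
            // B is stored transposed, so take dot products of rows of A with rows of B.
            for (int k = 0; k < A.column; k++)
                acc += A.data_ptr[i * A.column + k] * B.data_ptr[j * B.column + k];
            C.data_ptr[i * C.column + j] = acc;
        }
    }
}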

Instructions for MKL Setup

Testing

CUDA

(TinyChatEngine) elliotliu@hanlab-MSI-4090:~/TinyChatEngineMain/llm$ ./chat LLaMA2_7B_chat INT4 8
TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine
Using model: LLaMA2_7B_chat
Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq
Loading model... Finished!

USER: Hi
ASSISTANT:  Hello! How may I assist you today?


Inference latency, Total time: 8.8 s, 735.3 ms/token, 1.4 token/s, 12 tokens
(TinyChatEngine) (base) elliotliu@hanlab-MSI-4090:~/TinyChatEngine/llm$ ./chat LLaMA2_7B_chat INT4 8
TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine
Using model: LLaMA2_7B_chat
Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq
Loading model... Finished!

USER: Hi
ASSISTANT:  Hello! How may I assist you today?


Inference latency, Total time: 8.7 s, 723.3 ms/token, 1.4 token/s, 12 tokens
  • No slowdown or performance degradation for CUDA with dynamic dispatch

Neon

> ./chat
TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine
Using model: LLaMA_3_8B_Instruct
Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq
Loading model... Finished!

USER: Hi
ASSISTANT:  Hello! How can I assist you today?


Inference latency, Total time: 0.7 s, 75.1 ms/token, 13.3 token/s, 9 tokens
  • No slowdown or degradation for Neon with dynamic dispatch
