-
Notifications
You must be signed in to change notification settings - Fork 134
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add gemma.cpp bindings / llm-chain-gemma. (#281)
* Add gemma.cpp bindings / llm-chain-gemma. * style fixes * fix macos build * possibly fix windows build * minor fixes * style fix again * potential fix for windows * Uprev gemma.cpp version * exclude windows support * style fixes * update mio (unrelated but make the CI happy)
- Loading branch information
Showing
15 changed files
with
1,042 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
[submodule "crates/llm-chain-llama/sys/llama.cpp"] | ||
path = crates/llm-chain-llama-sys/llama.cpp | ||
url = https://github.com/ggerganov/llama.cpp.git | ||
[submodule "crates/llm-chain-gemma-sys/gemma.cpp"] | ||
path = crates/llm-chain-gemma-sys/gemma.cpp | ||
url = https://github.com/google/gemma.cpp.git |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
[package] | ||
name = "llm-chain-gemma-sys" | ||
description = "A library with bindings for gemma.cpp" | ||
version = "0.1.0" | ||
edition = "2021" | ||
license = "MIT" | ||
keywords = ["llm", "langchain", "gemma", "chain"] | ||
categories = ["science"] | ||
authors = [ | ||
"Jun Mukai <[email protected]>", | ||
] | ||
repository = "https://github.com/sobelio/llm-chain/" | ||
readme = "README.md" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
|
||
[build-dependencies] | ||
cc = "1.0.87" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
#![allow(clippy::uninlined_format_args)] | ||
|
||
extern crate cc; | ||
|
||
use std::env; | ||
|
||
fn main() { | ||
#[cfg(target_os = "windows")] | ||
{ | ||
// Gemma.cpp does not support MSBuild at this point -- | ||
// it does support clang-cl though. At this time, Windows | ||
// is out of the support because of this. | ||
// See: https://github.com/google/gemma.cpp/pull/6 | ||
cc::Build::new() | ||
.cpp(true) | ||
.file("src/bindings_win.cc") | ||
.std("c++17") | ||
.compile("bindings"); | ||
return; | ||
} | ||
let target = env::var("TARGET").unwrap(); | ||
// Link C++ standard library | ||
if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) { | ||
println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib); | ||
println!("cargo:rustc-link-arg=-l{}", cpp_stdlib); | ||
} | ||
// Link macOS Accelerate framework for matrix calculations | ||
if target.contains("apple") { | ||
println!("cargo:rustc-link-lib=framework=Accelerate"); | ||
} | ||
println!("cargo:rustc-link-search={}", env::var("OUT_DIR").unwrap()); | ||
println!("cargo:rustc-link-lib=static=gemma"); | ||
println!("cargo:rustc-link-lib=static=hwy"); | ||
println!("cargo:rustc-link-lib=static=hwy_contrib"); | ||
println!("cargo:rustc-link-lib=static=sentencepiece"); | ||
println!("cargo:rustc-link-lib=static=bindings"); | ||
println!("cargo:rerun-if-changed=wrapper.h"); | ||
|
||
// stop if we're on docs.rs | ||
if env::var("DOCS_RS").is_ok() { | ||
return; | ||
} | ||
|
||
// Run cmake to generate build files. | ||
env::set_current_dir("gemma.cpp").expect("Unable to change directory to gemma.cpp"); | ||
env::set_current_dir("build").expect("Unable to change directory to gemma.cpp build"); | ||
|
||
env::set_var("CXXFLAGS", "-fPIC"); | ||
env::set_var("CFLAGS", "-fPIC"); | ||
|
||
let mut code = std::process::Command::new("cmake"); | ||
let code = code | ||
.arg("..") | ||
.arg("-DCMAKE_BUILD_TYPE=Release") | ||
.arg("-DBUILD_SHARED_LIBS=OFF") | ||
.arg("-DWEIGHT_TYPE=hwy::bfloat16_t") | ||
.arg("-DSPM_ENABLE_SHARED=OFF"); | ||
let code = code.status().expect("Failed to generate build script"); | ||
if code.code() != Some(0) { | ||
panic!("Failed to generate build script"); | ||
} | ||
|
||
// Build binary. | ||
#[allow(clippy::suspicious_command_arg_space)] | ||
let code = std::process::Command::new("cmake") | ||
.arg("--build") | ||
.arg(".") | ||
.arg("--config") | ||
.arg("Release") | ||
.arg("--target") | ||
.arg("libgemma") | ||
.status() | ||
.expect("Failed to build lib"); | ||
if code.code() != Some(0) { | ||
panic!("Failed to build lib"); | ||
} | ||
|
||
// move libllama.a to where Cargo expects it (OUT_DIR) | ||
std::fs::copy( | ||
"libgemma.a", | ||
format!("{}/libgemma.a", env::var("OUT_DIR").unwrap()), | ||
) | ||
.expect("Failed to copy lib"); | ||
|
||
std::fs::copy( | ||
"_deps/highway-build/libhwy.a", | ||
format!("{}/libhwy.a", env::var("OUT_DIR").unwrap()), | ||
) | ||
.expect("Failed to copy libhwy.a"); | ||
|
||
std::fs::copy( | ||
"_deps/highway-build/libhwy_contrib.a", | ||
format!("{}/libhwy_contrib.a", env::var("OUT_DIR").unwrap()), | ||
) | ||
.expect("Failed to copy libhwy_contrib.a"); | ||
|
||
std::fs::copy( | ||
"_deps/sentencepiece-build/src/libsentencepiece.a", | ||
format!("{}/libsentencepiece.a", env::var("OUT_DIR").unwrap()), | ||
) | ||
.expect("Failed to copy libsentencepiece.a"); | ||
|
||
// Finally, build bindings.cc to allow access for gemma.cpp. | ||
// So far, bindgen does not correctly generate buildable rust file, | ||
// so I manually wrote bindings.rs for hand-written src/bindings.cc file. | ||
env::set_current_dir("..").expect("Unlable to change directory back to gemma.cpp"); | ||
env::set_current_dir("..").expect("Unlable to change directory back to crate top"); | ||
|
||
cc::Build::new() | ||
.cpp(true) | ||
.file("src/bindings.cc") | ||
.std("c++17") | ||
.include("./gemma.cpp") | ||
.include("./gemma.cpp/build/_deps/highway-src") | ||
.include("./gemma.cpp/build/_deps/sentencepiece-src") | ||
.compile("bindings"); | ||
} | ||
|
||
// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462 | ||
fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> { | ||
if target.contains("msvc") { | ||
None | ||
} else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") { | ||
Some("c++") | ||
} else if target.contains("android") { | ||
Some("c++_shared") | ||
} else { | ||
Some("stdc++") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#include <gemma.h> | ||
|
||
extern "C" { | ||
|
||
hwy::ThreadPool* hwy_ThreadPool_ThreadPool(size_t num_threads) { | ||
return new hwy::ThreadPool(num_threads); | ||
} | ||
|
||
void hwy_ThreadPool_destructor(hwy::ThreadPool* pool) { | ||
delete pool; | ||
} | ||
|
||
gcpp::Gemma* gcpp_Gemma_Gemma( | ||
const char* tokenizer_path, size_t tokenizer_path_len, | ||
const char* compressed_weights_path, size_t compressed_weights_path_len, | ||
const char* weights_path, size_t weights_path_len, | ||
gcpp::Model model_type, hwy::ThreadPool* pool) { | ||
gcpp::Path tpath; | ||
tpath.path = std::string(tokenizer_path, tokenizer_path_len); | ||
gcpp::Path cwpath; | ||
cwpath.path = std::string(compressed_weights_path, compressed_weights_path_len); | ||
gcpp::Path wpath; | ||
wpath.path = std::string(weights_path, weights_path_len); | ||
return new gcpp::Gemma(tpath, cwpath, wpath, model_type, *pool); | ||
} | ||
|
||
void gcpp_Gemma_destructor(gcpp::Gemma* gemma) { | ||
delete gemma; | ||
} | ||
|
||
void gcpp_Gemma_SetModelTraining(gcpp::Gemma* gemma, gcpp::ModelTraining training) { | ||
gemma->model_training = training; | ||
} | ||
|
||
gcpp::KVCache* gcpp_CreateKVCache(gcpp::Model model_type) { | ||
gcpp::KVCache* cache = new gcpp::KVCache{}; | ||
*cache = gcpp::CreateKVCache(model_type); | ||
return cache; | ||
} | ||
|
||
void gcpp_KVCache_destructor(gcpp::KVCache* kvcache) { | ||
delete kvcache; | ||
} | ||
|
||
std::vector<int>* std_vector_int_vector() { | ||
return new std::vector<int>(); | ||
} | ||
|
||
void std_vector_int_destructor(std::vector<int>* v) { | ||
delete v; | ||
} | ||
|
||
size_t std_vector_int_size(const std::vector<int>* v) { | ||
return v->size(); | ||
} | ||
|
||
int std_vector_int_at(const std::vector<int>* v, size_t i) { | ||
return v->at(i); | ||
} | ||
|
||
std::string* std_string_string() { | ||
return new std::string(); | ||
} | ||
|
||
void std_string_destructor(std::string* s) { | ||
delete s; | ||
} | ||
|
||
const char* std_string_c_str(const std::string* s) { | ||
return s->c_str(); | ||
} | ||
|
||
bool gcpp_Gemma_Encode(gcpp::Gemma* gemma, const char* input, size_t len, std::vector<int>* out) { | ||
return gemma->Tokenizer()->Encode(std::string(input, len), out).ok(); | ||
} | ||
|
||
bool gcpp_Gemma_Decode(gcpp::Gemma* gemma, int token, std::string* out) { | ||
return gemma->Tokenizer()->Decode(std::vector<int>{token}, out).ok(); | ||
} | ||
|
||
bool gcpp_Gemma_Decodes(gcpp::Gemma* gemma, const int* tokens, int num_tokens, std::string* out) { | ||
std::vector<int> v; | ||
v.reserve(num_tokens); | ||
for (int i = 0; i < num_tokens; i++) { | ||
v.push_back(tokens[i]); | ||
} | ||
return gemma->Tokenizer()->Decode(v, out).ok(); | ||
} | ||
|
||
std::mt19937* std_mt19937_mt19937() { | ||
return new std::mt19937(); | ||
} | ||
|
||
void std_mt19937_destructor(std::mt19937* gen) { | ||
delete gen; | ||
} | ||
|
||
void std_mt19937_seed(std::mt19937* gen, int seed) { | ||
gen->seed(seed); | ||
} | ||
|
||
void std_mt19937_random_seed(std::mt19937* gen) { | ||
std::random_device rd; | ||
gen->seed(rd()); | ||
} | ||
|
||
typedef bool (*stream_callback)(void*, int, float); | ||
typedef bool (*accept_callback)(void*, int); | ||
|
||
void gcpp_GenerateGemma( | ||
gcpp::Gemma* gemma, const gcpp::RuntimeConfig* config, | ||
const std::vector<int>* prompt, size_t start_pos, | ||
gcpp::KVCache* kvcache, hwy::ThreadPool* pool, | ||
void* stream_context, stream_callback stream_token, | ||
std::mt19937* gen) { | ||
gcpp::GenerateGemma( | ||
*gemma, *config, *prompt, start_pos, *kvcache, *pool, | ||
[&stream_context, &stream_token](int token, float value) { | ||
return stream_token(stream_context, token, value); | ||
}, | ||
*gen); | ||
} | ||
|
||
} |
Oops, something went wrong.