-
Notifications
You must be signed in to change notification settings - Fork 412
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Test only] multimodal android binding
ghstack-source-id: 2044005547881d324b65a7a6880f5ed541242a92 Pull Request resolved: #4426
- Loading branch information
1 parent
c937658
commit 0820e9c
Showing
6 changed files
with
347 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <cassert> | ||
#include <chrono> | ||
#include <iostream> | ||
#include <memory> | ||
#include <sstream> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
|
||
#include <executorch/examples/models/llava/runner/multimodal_runner.h> | ||
#include <executorch/runtime/platform/log.h> | ||
#include <executorch/runtime/platform/platform.h> | ||
#include <executorch/runtime/platform/runtime.h> | ||
|
||
#if defined(ET_USE_THREADPOOL) | ||
#include <executorch/backends/xnnpack/threadpool/cpuinfo_utils.h> | ||
#include <executorch/backends/xnnpack/threadpool/threadpool.h> | ||
#endif | ||
|
||
#include <fbjni/ByteBuffer.h> | ||
#include <fbjni/fbjni.h> | ||
|
||
#ifdef __ANDROID__ | ||
#include <android/log.h> | ||
|
||
// For Android, write to logcat | ||
void et_pal_emit_log_message( | ||
et_timestamp_t timestamp, | ||
et_pal_log_level_t level, | ||
const char* filename, | ||
const char* function, | ||
size_t line, | ||
const char* message, | ||
size_t length) { | ||
int android_log_level = ANDROID_LOG_UNKNOWN; | ||
if (level == 'D') { | ||
android_log_level = ANDROID_LOG_DEBUG; | ||
} else if (level == 'I') { | ||
android_log_level = ANDROID_LOG_INFO; | ||
} else if (level == 'E') { | ||
android_log_level = ANDROID_LOG_ERROR; | ||
} else if (level == 'F') { | ||
android_log_level = ANDROID_LOG_FATAL; | ||
} | ||
|
||
__android_log_print(android_log_level, "MULTIMODAL", "%s", message); | ||
} | ||
#endif | ||
|
||
using namespace torch::executor; | ||
|
||
namespace executorch_jni { | ||
|
||
class ExecuTorchMultiModalCallbackJni | ||
: public facebook::jni::JavaClass<ExecuTorchMultiModalCallbackJni> { | ||
public: | ||
constexpr static const char* kJavaDescriptor = | ||
"Lorg/pytorch/executorch/MultiModalCallback;"; | ||
|
||
void onResult(std::string result) const { | ||
static auto cls = ExecuTorchMultiModalCallbackJni::javaClassStatic(); | ||
static const auto method = | ||
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult"); | ||
facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result); | ||
method(self(), s); | ||
} | ||
|
||
void onStats(const MultiModalRunner::Stats& result) const { | ||
static auto cls = ExecuTorchMultiModalCallbackJni::javaClassStatic(); | ||
static const auto method = cls->getMethod<void(jfloat)>("onStats"); | ||
double eval_time = | ||
(double)(result.inference_end_ms - result.prompt_eval_end_ms); | ||
|
||
float tps = result.num_generated_tokens / eval_time * | ||
result.SCALING_FACTOR_UNITS_PER_SECOND; | ||
|
||
method(self(), tps); | ||
} | ||
}; | ||
|
||
class ExecuTorchMultiModalJni | ||
: public facebook::jni::HybridClass<ExecuTorchMultiModalJni> { | ||
private: | ||
friend HybridBase; | ||
std::unique_ptr<MultiModalRunner> runner_; | ||
|
||
public: | ||
constexpr static auto kJavaDescriptor = | ||
"Lorg/pytorch/executorch/MultiModalModule;"; | ||
|
||
static facebook::jni::local_ref<jhybriddata> initHybrid( | ||
facebook::jni::alias_ref<jclass>, | ||
facebook::jni::alias_ref<jstring> model_path, | ||
facebook::jni::alias_ref<jstring> tokenizer_path, | ||
jfloat temperature) { | ||
return makeCxxInstance(model_path, tokenizer_path, temperature); | ||
} | ||
|
||
ExecuTorchMultiModalJni( | ||
facebook::jni::alias_ref<jstring> model_path, | ||
facebook::jni::alias_ref<jstring> tokenizer_path, | ||
jfloat temperature) { | ||
#if defined(ET_USE_THREADPOOL) | ||
// Reserve 1 thread for the main thread. | ||
uint32_t num_performant_cores = | ||
torch::executorch::cpuinfo::get_num_performant_cores() - 1; | ||
if (num_performant_cores > 0) { | ||
ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); | ||
torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool( | ||
num_performant_cores); | ||
} | ||
#endif | ||
|
||
runner_ = std::make_unique<MultiModalRunner>( | ||
model_path->toStdString().c_str(), | ||
tokenizer_path->toStdString().c_str(), | ||
temperature); | ||
} | ||
|
||
jint generate( | ||
facebook::jni::alias_ref<jintArray> image, | ||
jint width, | ||
jint height, | ||
jint channels, | ||
facebook::jni::alias_ref<jstring> prompt, | ||
jint startPos, | ||
facebook::jni::alias_ref<ExecuTorchMultiModalCallbackJni> callback) { | ||
auto image_size = image->size(); | ||
std::vector<Image> images; | ||
if (image_size != 0) { | ||
std::vector<jint> image_data_jint(image_size); | ||
std::vector<uint8_t> image_data(image_size); | ||
image->getRegion(0, image_size, image_data_jint.data()); | ||
for (int i = 0; i < image_size; i++) { | ||
image_data[i] = image_data_jint[i]; | ||
} | ||
Image image_runner{image_data, width, height, channels}; | ||
images.push_back(image_runner); | ||
} | ||
runner_->generate( | ||
images, | ||
prompt->toStdString(), | ||
1024, | ||
[callback](std::string result) { callback->onResult(result); }, | ||
[callback](const MultiModalRunner::Stats& result) { | ||
callback->onStats(result); | ||
}); | ||
return 0; | ||
} | ||
|
||
void stop() { | ||
runner_->stop(); | ||
} | ||
|
||
jint load() { | ||
return static_cast<jint>(runner_->load()); | ||
} | ||
|
||
static void registerNatives() { | ||
registerHybrid({ | ||
makeNativeMethod("initHybrid", ExecuTorchMultiModalJni::initHybrid), | ||
makeNativeMethod("generate", ExecuTorchMultiModalJni::generate), | ||
makeNativeMethod("stop", ExecuTorchMultiModalJni::stop), | ||
makeNativeMethod("load", ExecuTorchMultiModalJni::load), | ||
}); | ||
} | ||
}; | ||
|
||
} // namespace executorch_jni | ||
|
||
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { | ||
return facebook::jni::initialize( | ||
vm, [] { executorch_jni::ExecuTorchMultiModalJni::registerNatives(); }); | ||
} |
30 changes: 30 additions & 0 deletions
30
extension/android/src/main/java/org/pytorch/executorch/MultiModalCallback.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
package org.pytorch.executorch; | ||
|
||
import com.facebook.jni.annotations.DoNotStrip; | ||
|
||
public interface MultiModalCallback { | ||
/** | ||
* Called when a new result is available from JNI. Users will keep getting onResult() invocations | ||
* until generate() finishes. | ||
* | ||
* @param result Last generated token | ||
*/ | ||
@DoNotStrip | ||
public void onResult(String result); | ||
|
||
/** | ||
* Called when the statistics for the generate() is available. | ||
* | ||
* @param tps Tokens/second for generated tokens. | ||
*/ | ||
@DoNotStrip | ||
public void onStats(float tps); | ||
} |
55 changes: 55 additions & 0 deletions
55
extension/android/src/main/java/org/pytorch/executorch/MultiModalModule.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
package org.pytorch.executorch; | ||
|
||
import com.facebook.jni.HybridData; | ||
import com.facebook.jni.annotations.DoNotStrip; | ||
import com.facebook.soloader.nativeloader.NativeLoader; | ||
import com.facebook.soloader.nativeloader.SystemDelegate; | ||
|
||
public class MultiModalModule { | ||
static { | ||
if (!NativeLoader.isInitialized()) { | ||
NativeLoader.init(new SystemDelegate()); | ||
} | ||
NativeLoader.loadLibrary("executorch_multimodal_jni"); | ||
} | ||
|
||
private final HybridData mHybridData; | ||
|
||
@DoNotStrip | ||
private static native HybridData initHybrid( | ||
String modulePath, String tokenizerPath, float temperature); | ||
|
||
/** Constructs a MultiModal Module for a model with given path, tokenizer, and temperature. */ | ||
public MultiModalModule(String modulePath, String tokenizerPath, float temperature) { | ||
mHybridData = initHybrid(modulePath, tokenizerPath, temperature); | ||
} | ||
|
||
public void resetNative() { | ||
mHybridData.resetNative(); | ||
} | ||
|
||
/** | ||
* Start generating tokens from the module. | ||
* | ||
* @param prompt Input prompt | ||
* @param MultiModalCallback callback object to receive results. | ||
*/ | ||
@DoNotStrip | ||
public native int generate(int[] image, int width, int height, int channels, String prompt, int startPos, MultiModalCallback MultiModalCallback); | ||
|
||
/** Stop current generate() before it finishes. */ | ||
@DoNotStrip | ||
public native void stop(); | ||
|
||
/** Force loading the module. Otherwise the model is loaded during first generate(). */ | ||
@DoNotStrip | ||
public native int load(); | ||
} |