diff --git a/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java b/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java index 5fb644b08f..cbd9db84a6 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java @@ -9,12 +9,14 @@ package com.example.executorchllamademo; import static junit.framework.TestCase.assertTrue; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import android.os.Bundle; import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.platform.app.InstrumentationRegistry; -import android.os.Bundle; + import java.io.File; import java.util.ArrayList; import java.util.Arrays; @@ -39,23 +41,24 @@ public void testTokensPerSecond() { // Find out the model name File directory = new File(RESOURCE_PATH); Arrays.stream(directory.listFiles()) - .filter(file -> file.getName().endsWith(".pte") || file.getName().endsWith(".pt")) - .forEach(model -> { - LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f); - // Print the model name because there might be more than one of them - report("ModelName", model.getName()); - - int loadResult = mModule.load(); - // Check that the model can be load successfully - assertEquals(0, loadResult); - - // Run a testing prompt - mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this); - assertFalse(tokensPerSecond.isEmpty()); - - final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); - reportMetric("TPS", tps); - }); + .filter(file -> file.getName().endsWith(".pte") || file.getName().endsWith(".pt")) + .forEach( + model -> { + LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f); + // Print the model name because there might be more than one of them + report("ModelName", model.getName()); + + int loadResult = mModule.load(); + // Check that the model can be load successfully + assertEquals(0, loadResult); + + // Run a testing prompt + mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this); + assertFalse(tokensPerSecond.isEmpty()); + + final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); + reportMetric("TPS", tps); + }); } @Override