Commit: Align API

vishniakov-nikolai committed Dec 12, 2024
1 parent b3eba5e commit 6f0e0b4
Showing 3 changed files with 22 additions and 78 deletions.
56 changes: 0 additions & 56 deletions samples/js/interactive.js

This file was deleted.

20 changes: 10 additions & 10 deletions src/js/lib/module.js
@@ -55,7 +55,14 @@ class LLMPipeline {
     return result;
   }
 
-  async generate(prompt, generationCallback, options = {}) {
+  async generate(prompt, generationCallbackOrOptions, generationCallback) {
+    let options = {};
+
+    if (!generationCallback)
+      generationCallback = generationCallbackOrOptions;
+    else
+      options = generationCallbackOrOptions;
+
     if (!this.isInitialized)
       throw new Error('Pipeline is not initialized');
 
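For reference, the two call shapes this overload handling is meant to accept look roughly as follows. This is a sketch, assuming pipeline is an already initialized LLMPipeline; the prompt and option values are placeholders:

    // No options: the second argument is taken as the generation callback.
    await pipeline.generate('Say hello', (chunk) => process.stdout.write(chunk));

    // With options: they come second, and the callback moves to third place.
    await pipeline.generate(
      'Say hello',
      { temperature: '0', max_new_tokens: '4' },
      (chunk) => process.stdout.write(chunk),
    );
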
@@ -95,24 +102,17 @@ class LLMPipeline {
   }
 }
 
-const availablePipelines = { LLMPipeline: LLMPipeline };
-
 class Pipeline {
-  static async create(pipelineType, modelPath, device = 'CPU') {
-    if (!Object.keys(availablePipelines).includes(pipelineType))
-      throw new Error(`Pipeline type: '${pipelineType}' doesn't support`);
-
-    const pipeline = new availablePipelines[pipelineType](modelPath, device);
+  static async LLMPipeline(modelPath, device = 'CPU') {
+    const pipeline = new LLMPipeline(modelPath, device);
     await pipeline.init();
 
     return pipeline;
   }
 }
 
-const availablePipelinesKeys = Object.keys(availablePipelines);
-
 export {
   addon,
   Pipeline,
-  availablePipelinesKeys as availablePipelines,
 };
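The Pipeline factory is aligned in the same way: instead of selecting the pipeline type with a string key passed to a generic create() method, each pipeline type now gets its own static factory on Pipeline. A rough before/after sketch (the model path and device values are placeholders):

    // Before this commit, the type was selected by a string key:
    //   const pipeline = await Pipeline.create('LLMPipeline', '/path/to/model', 'CPU');
    // After this commit, each pipeline type has a dedicated static factory:
    const pipeline = await Pipeline.LLMPipeline('/path/to/model', 'CPU');
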
24 changes: 12 additions & 12 deletions src/js/tests/module.test.js
@@ -11,7 +11,7 @@ describe('module', async () => {
   let pipeline = null;
 
   await before(async () => {
-    pipeline = await Pipeline.create('LLMPipeline', MODEL_PATH, 'AUTO');
+    pipeline = await Pipeline.LLMPipeline(MODEL_PATH, 'AUTO');
 
     await pipeline.startChat();
   });
@@ -23,8 +23,8 @@ describe('module', async () => {
   await it('should generate "Hello world"', async () => {
     const result = await pipeline.generate(
       'Type "Hello world!" in English',
+      { temperature: '0', max_new_tokens: '4' },
       () => {},
-      { temperature: '0', max_new_tokens: '4' }
     );
 
     assert.strictEqual(result, 'Hello world!');
@@ -33,7 +33,7 @@
 
 describe('corner cases', async () => {
   it('should throw an error if pipeline is already initialized', async () => {
-    const pipeline = await Pipeline.create('LLMPipeline', MODEL_PATH, 'AUTO');
+    const pipeline = await Pipeline.LLMPipeline(MODEL_PATH, 'AUTO');
 
     await assert.rejects(
       async () => await pipeline.init(),
@@ -45,7 +45,7 @@ describe('corner cases', async () => {
   });
 
   it('should throw an error if chat is already started', async () => {
-    const pipeline = await Pipeline.create('LLMPipeline', MODEL_PATH, 'AUTO');
+    const pipeline = await Pipeline.LLMPipeline(MODEL_PATH, 'AUTO');
 
     await pipeline.startChat();
 
@@ -59,7 +59,7 @@ describe('corner cases', async () => {
   });
 
   it('should throw an error if chat is not started', async () => {
-    const pipeline = await Pipeline.create('LLMPipeline', MODEL_PATH, 'AUTO');
+    const pipeline = await Pipeline.LLMPipeline(MODEL_PATH, 'AUTO');
 
     await assert.rejects(
       () => pipeline.finishChat(),
@@ -75,7 +75,7 @@ describe('generation parameters validation', () => {
   let pipeline = null;
 
   before(async () => {
-    pipeline = await Pipeline.create('LLMPipeline', MODEL_PATH, 'AUTO');
+    pipeline = await Pipeline.LLMPipeline(MODEL_PATH, 'AUTO');
 
     await pipeline.startChat();
   });
@@ -95,7 +95,7 @@ describe('generation parameters validation', () => {
   });
 
   it('should throw an error if generationCallback is not a function', async () => {
-    const pipeline = await Pipeline.create('LLMPipeline', MODEL_PATH, 'AUTO');
+    const pipeline = await Pipeline.LLMPipeline(MODEL_PATH, 'AUTO');
 
     await pipeline.startChat();
 
@@ -110,7 +110,7 @@
 
   it('should throw an error if options specified but not an object', async () => {
     await assert.rejects(
-      async () => await pipeline.generate('prompt', () => {}, 'options'),
+      async () => await pipeline.generate('prompt', 'options', () => {}),
       {
         name: 'Error',
         message: 'Options must be an object',
@@ -120,7 +120,7 @@
 
   it('should perform generation with default options', async () => {
     try {
-      await pipeline.generate('prompt', () => {}, { max_new_tokens: 1 });
+      await pipeline.generate('prompt', { max_new_tokens: 1 }, () => {});
     } catch (error) {
       assert.fail(error);
     }
@@ -129,14 +129,14 @@
   });
 
   it('should return a string as generation result', async () => {
-    const reply = await pipeline.generate('prompt', () => {}, { max_new_tokens: 1 });
+    const reply = await pipeline.generate('prompt', { max_new_tokens: 1 }, () => {});
 
     assert.strictEqual(typeof reply, 'string');
   });
 
   it('should call generationCallback with string chunk', async () => {
-    await pipeline.generate('prompt', (chunk) => {
+    await pipeline.generate('prompt', { max_new_tokens: 1 }, (chunk) => {
       assert.strictEqual(typeof chunk, 'string');
-    }, { max_new_tokens: 1 });
+    });
   });
 });
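Taken together, the updated tests imply end-to-end usage along these lines. This is a sketch only; the import path, model path, and option values are illustrative assumptions, not taken from this diff:

    import { Pipeline } from '../lib/module.js'; // assumed relative path to the library entry point

    const MODEL_PATH = '/path/to/exported/model'; // placeholder

    const pipeline = await Pipeline.LLMPipeline(MODEL_PATH, 'AUTO');
    await pipeline.startChat();

    const result = await pipeline.generate(
      'Type "Hello world!" in English',
      { temperature: '0', max_new_tokens: '4' },
      (chunk) => process.stdout.write(chunk), // receives string chunks as they are generated
    );

    await pipeline.finishChat();
    console.log('\n' + result); // the full reply is also returned as a string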
