Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
shaper committed Dec 19, 2024
1 parent 2730e61 commit cdaa658
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 150 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
---
'@ai-sdk/google-vertex': patch
'ai': patch
---

feat (provider/google-vertex): Add imagen support.
19 changes: 12 additions & 7 deletions examples/ai-core/src/e2e/google-vertex.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { vertex as vertexEdge } from '@ai-sdk/google-vertex/edge';
import { vertex as vertexNode } from '@ai-sdk/google-vertex';
import { z } from 'zod';
import {
detectImageMimeType,
generateText,
generateObject,
streamText,
Expand Down Expand Up @@ -462,7 +463,11 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
const { image } = await generateImage({
model,
prompt: 'A burrito launched through a tunnel',
size: '1024x1024',
providerOptions: {
vertex: {
aspectRatio: '3:4',
},
},
});

// Verify we got a Uint8Array back
Expand All @@ -472,10 +477,9 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
expect(image.uint8Array.length).toBeGreaterThan(10 * 1024);
expect(image.uint8Array.length).toBeLessThan(10 * 1024 * 1024);

// Verify PNG format by checking magic numbers
const pngSignature = [137, 80, 78, 71, 13, 10, 26, 10];
const actualSignature = Array.from(image.uint8Array.slice(0, 8));
expect(actualSignature).toEqual(pngSignature);
// Verify PNG format
const mimeType = detectImageMimeType(image.uint8Array);
expect(mimeType).toBe('image/png');

// Create a temporary buffer to verify image dimensions
const tempBuffer = Buffer.from(image.uint8Array);
Expand All @@ -484,8 +488,9 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
const width = tempBuffer.readUInt32BE(16);
const height = tempBuffer.readUInt32BE(20);

expect(width).toBe(1024);
expect(height).toBe(1024);
// https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#performance-limits
expect(width).toBe(896);
expect(height).toBe(1280);
});
});
},
Expand Down
3 changes: 1 addition & 2 deletions examples/ai-core/src/generate-image/google-vertex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ async function main() {
const { image } = await generateImage({
model: vertex.image('imagen-3.0-generate-001'),
prompt: 'A burrito launched through a tunnel',
size: '1024x1024',
providerOptions: {
vertex: {
// Vertex AI specific options if needed
aspectRatio: '16:9',
},
},
});
Expand Down
1 change: 1 addition & 0 deletions packages/ai/core/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ export * from './registry';
export * from './tool';
export * from './types';
export { cosineSimilarity } from './util/cosine-similarity';
export { detectImageMimeType } from './util/detect-image-mimetype';
190 changes: 107 additions & 83 deletions packages/google-vertex/src/google-vertex-image-model.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { JsonTestServer } from '@ai-sdk/provider-utils/test';
import { GoogleVertexImageModel } from './google-vertex-image-model';
import { describe, it, expect, vi } from 'vitest';

const prompt = 'A cute baby sea otter';

Expand All @@ -9,104 +10,127 @@ const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
headers: { 'api-key': 'test-key' },
});

describe('doGenerate', () => {
const server = new JsonTestServer(
'https://api.example.com/models/imagen-3.0-generate-001:predict',
);

server.setupTestEnvironment();

function prepareJsonResponse() {
server.responseBodyJson = {
predictions: [
{ bytesBase64Encoded: 'base64-image-1' },
{ bytesBase64Encoded: 'base64-image-2' },
],
};
}

it('should pass the correct parameters', async () => {
prepareJsonResponse();

await model.doGenerate({
prompt,
n: 2,
size: '1024x1024',
providerOptions: { customOption: { value: 123 } },
});
describe('GoogleVertexImageModel', () => {
describe('doGenerate', () => {
const server = new JsonTestServer(
'https://api.example.com/models/imagen-3.0-generate-001:predict',
);

expect(await server.getRequestBodyJson()).toStrictEqual({
instances: [{ prompt }],
parameters: {
sampleCount: 2,
aspectRatio: '1:1',
customOption: { value: 123 },
},
server.setupTestEnvironment();

function prepareJsonResponse() {
server.responseBodyJson = {
predictions: [
{ bytesBase64Encoded: 'base64-image-1' },
{ bytesBase64Encoded: 'base64-image-2' },
],
};
}

it('should pass the correct parameters', async () => {
prepareJsonResponse();

await model.doGenerate({
prompt,
n: 2,
size: undefined,
providerOptions: { vertex: { aspectRatio: '1:1' } },
});

expect(await server.getRequestBodyJson()).toStrictEqual({
instances: [{ prompt }],
parameters: {
sampleCount: 2,
aspectRatio: '1:1',
},
});
});
});

it('should pass headers', async () => {
prepareJsonResponse();
it('should pass headers', async () => {
prepareJsonResponse();

const modelWithHeaders = new GoogleVertexImageModel(
'imagen-3.0-generate-001',
{
provider: 'google-vertex',
baseURL: 'https://api.example.com',
headers: {
'Custom-Provider-Header': 'provider-header-value',
},
},
);

const modelWithHeaders = new GoogleVertexImageModel(
'imagen-3.0-generate-001',
{
provider: 'google-vertex',
baseURL: 'https://api.example.com',
await modelWithHeaders.doGenerate({
prompt,
n: 2,
size: undefined,
providerOptions: {},
headers: {
'Custom-Provider-Header': 'provider-header-value',
'Custom-Request-Header': 'request-header-value',
},
},
);

await modelWithHeaders.doGenerate({
prompt,
n: 2,
size: '1024x1024',
providerOptions: {},
headers: {
'Custom-Request-Header': 'request-header-value',
},
});
});

const requestHeaders = await server.getRequestHeaders();
const requestHeaders = await server.getRequestHeaders();

expect(requestHeaders).toStrictEqual({
'content-type': 'application/json',
'custom-provider-header': 'provider-header-value',
'custom-request-header': 'request-header-value',
expect(requestHeaders).toStrictEqual({
'content-type': 'application/json',
'custom-provider-header': 'provider-header-value',
'custom-request-header': 'request-header-value',
});
});
});

it('should extract the generated images', async () => {
prepareJsonResponse();

const result = await model.doGenerate({
prompt,
n: 2,
size: undefined,
providerOptions: {},
});
it('should extract the generated images', async () => {
prepareJsonResponse();

expect(result.images).toStrictEqual(['base64-image-1', 'base64-image-2']);
});
const result = await model.doGenerate({
prompt,
n: 2,
size: undefined,
providerOptions: {},
});

it('should handle different aspect ratios', async () => {
prepareJsonResponse();
expect(result.images).toStrictEqual(['base64-image-1', 'base64-image-2']);
});

await model.doGenerate({
prompt,
n: 1,
size: '1280x896',
providerOptions: {},
it('throws when size is specified', async () => {
const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
provider: 'vertex',
baseURL: 'https://example.com',
});

await expect(
model.doGenerate({
prompt: 'test prompt',
n: 1,
size: '1024x1024',
providerOptions: {},
}),
).rejects.toThrow(
'Google Vertex does not support the `size` option. Use `providerOptions.aspectRatio` instead.',
);
});

expect(await server.getRequestBodyJson()).toStrictEqual({
instances: [{ prompt }],
parameters: {
sampleCount: 1,
aspectRatio: '4:3',
},
it('sends aspect ratio in the request', async () => {
prepareJsonResponse();

await model.doGenerate({
prompt: 'test prompt',
n: 1,
size: undefined,
providerOptions: {
vertex: {
aspectRatio: '16:9',
},
},
});

expect(await server.getRequestBodyJson()).toStrictEqual({
instances: [{ prompt: 'test prompt' }],
parameters: {
sampleCount: 1,
aspectRatio: '16:9',
},
});
});
});
});
Loading

0 comments on commit cdaa658

Please sign in to comment.