Skip to content

Commit

Permalink
Move the stream subscription stuff to the API client, expose it throu…
Browse files Browse the repository at this point in the history
…gh `runs.fetchStream`
  • Loading branch information
ericallam committed Dec 19, 2024
1 parent 01c2ca9 commit 85b3605
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 94 deletions.
18 changes: 18 additions & 0 deletions packages/core/src/v3/apiClient/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import {
RunSubscription,
TaskRunShape,
runShapeStream,
SSEStreamSubscriptionFactory,
} from "./runStream.js";
import {
CreateEnvironmentVariableParams,
Expand Down Expand Up @@ -681,6 +682,23 @@ export class ApiClient {
});
}

async fetchStream<T>(
runId: string,
streamKey: string,
options?: { signal?: AbortSignal; baseUrl?: string }
): Promise<AsyncIterableStream<T>> {
const streamFactory = new SSEStreamSubscriptionFactory(options?.baseUrl ?? this.baseUrl, {
headers: this.getHeaders(),
signal: options?.signal,
});

const subscription = streamFactory.createSubscription(runId, streamKey);

const stream = await subscription.subscribe();

return stream as AsyncIterableStream<T>;
}

async generateJWTClaims(requestOptions?: ZodFetchOptions): Promise<Record<string, any>> {
return zodfetch(
z.record(z.any()),
Expand Down
84 changes: 4 additions & 80 deletions packages/core/src/v3/apiClient/runStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,7 @@ export function runShapeStream<TRunTypes extends AnyRunTypes>(
): RunSubscription<TRunTypes> {
const abortController = new AbortController();

const version1 = new SSEStreamSubscriptionFactory(
getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev",
{
headers: options?.headers,
signal: abortController.signal,
}
);

const version2 = new ElectricStreamSubscriptionFactory(
const streamFactory = new SSEStreamSubscriptionFactory(
getEnvVar("TRIGGER_STREAM_URL", getEnvVar("TRIGGER_API_URL")) ?? "https://api.trigger.dev",
{
headers: options?.headers,
Expand All @@ -124,7 +116,7 @@ export function runShapeStream<TRunTypes extends AnyRunTypes>(
const $options: RunSubscriptionOptions = {
runShapeStream: runStreamInstance.stream,
stopRunShapeStream: () => runStreamInstance.stop(30 * 1000),
streamFactory: new VersionedStreamSubscriptionFactory(version1, version2),
streamFactory: streamFactory,
abortController,
...options,
};
Expand All @@ -138,12 +130,7 @@ export interface StreamSubscription {
}

export interface StreamSubscriptionFactory {
createSubscription(
metadata: Record<string, unknown>,
runId: string,
streamKey: string,
baseUrl?: string
): StreamSubscription;
createSubscription(runId: string, streamKey: string, baseUrl?: string): StreamSubscription;
}

// Real implementation for production
Expand Down Expand Up @@ -194,12 +181,7 @@ export class SSEStreamSubscriptionFactory implements StreamSubscriptionFactory {
private options: { headers?: Record<string, string>; signal?: AbortSignal }
) {}

createSubscription(
metadata: Record<string, unknown>,
runId: string,
streamKey: string,
baseUrl?: string
): StreamSubscription {
createSubscription(runId: string, streamKey: string, baseUrl?: string): StreamSubscription {
if (!runId || !streamKey) {
throw new Error("runId and streamKey are required");
}
Expand Down Expand Up @@ -238,63 +220,6 @@ export class ElectricStreamSubscription implements StreamSubscription {
}
}

export class ElectricStreamSubscriptionFactory implements StreamSubscriptionFactory {
constructor(
private baseUrl: string,
private options: { headers?: Record<string, string>; signal?: AbortSignal }
) {}

createSubscription(
metadata: Record<string, unknown>,
runId: string,
streamKey: string,
baseUrl?: string
): StreamSubscription {
if (!runId || !streamKey) {
throw new Error("runId and streamKey are required");
}

return new ElectricStreamSubscription(
`${baseUrl ?? this.baseUrl}/realtime/v2/streams/${runId}/${streamKey}`,
this.options
);
}
}

export class VersionedStreamSubscriptionFactory implements StreamSubscriptionFactory {
constructor(
private version1: StreamSubscriptionFactory,
private version2: StreamSubscriptionFactory
) {}

createSubscription(
metadata: Record<string, unknown>,
runId: string,
streamKey: string,
baseUrl?: string
): StreamSubscription {
if (!runId || !streamKey) {
throw new Error("runId and streamKey are required");
}

const version =
typeof metadata.$$streamsVersion === "string" ? metadata.$$streamsVersion : "v1";

const $baseUrl =
typeof metadata.$$streamsBaseUrl === "string" ? metadata.$$streamsBaseUrl : baseUrl;

if (version === "v1") {
return this.version1.createSubscription(metadata, runId, streamKey, $baseUrl);
}

if (version === "v2") {
return this.version2.createSubscription(metadata, runId, streamKey, $baseUrl);
}

throw new Error(`Unknown stream version: ${version}`);
}
}

export interface RunShapeProvider {
onShape(callback: (shape: SubscribeRunRawShape) => Promise<void>): Promise<() => void>;
}
Expand Down Expand Up @@ -385,7 +310,6 @@ export class RunSubscription<TRunTypes extends AnyRunTypes> {
activeStreams.add(streamKey);

const subscription = this.options.streamFactory.createSubscription(
run.metadata,
run.id,
streamKey,
this.options.client?.baseUrl
Expand Down
19 changes: 5 additions & 14 deletions packages/core/src/v3/runMetadata/manager.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import { JSONHeroPath } from "@jsonhero/path";
import { dequal } from "dequal/lite";
import { DeserializedJson } from "../../schemas/json.js";
import { ApiRequestOptions } from "../zodfetch.js";
import { RunMetadataManager, RunMetadataUpdater } from "./types.js";
import { MetadataStream } from "./metadataStream.js";
import { ApiClient } from "../apiClient/index.js";
import { AsyncIterableStream } from "../apiClient/stream.js";
import { FlushedRunMetadata, RunMetadataChangeOperation } from "../schemas/common.js";
import { ApiRequestOptions } from "../zodfetch.js";
import { MetadataStream } from "./metadataStream.js";
import { applyMetadataOperations } from "./operations.js";
import { SSEStreamSubscriptionFactory } from "../apiClient/runStream.js";
import { AsyncIterableStream } from "../apiClient/stream.js";
import { RunMetadataManager, RunMetadataUpdater } from "./types.js";

const MAXIMUM_ACTIVE_STREAMS = 5;
const MAXIMUM_TOTAL_STREAMS = 10;
Expand Down Expand Up @@ -208,14 +206,7 @@ export class StandardMetadataManager implements RunMetadataManager {

const $baseUrl = typeof baseUrl === "string" ? baseUrl : this.streamsBaseUrl;

const streamFactory = new SSEStreamSubscriptionFactory($baseUrl, {
headers: this.apiClient.getHeaders(),
signal,
});

const subscription = streamFactory.createSubscription(this.store ?? {}, this.runId, key);

return (await subscription.subscribe()) as AsyncIterableStream<T>;
return this.apiClient.fetchStream<T>(this.runId, key, { baseUrl: $baseUrl, signal });
}

private async doStream<T>(
Expand Down
11 changes: 11 additions & 0 deletions packages/trigger-sdk/src/v3/runs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type {
RunSubscription,
TaskRunShape,
AnyBatchedRunHandle,
AsyncIterableStream,
} from "@trigger.dev/core/v3";
import {
ApiPromise,
Expand Down Expand Up @@ -51,6 +52,7 @@ export const runs = {
subscribeToRun,
subscribeToRunsWithTag,
subscribeToBatch: subscribeToRunsInBatch,
fetchStream,
};

export type ListRunsItem = ListRunResponseItem;
Expand Down Expand Up @@ -465,3 +467,12 @@ function subscribeToRunsInBatch<TTasks extends AnyTask>(

return apiClient.subscribeToBatch<InferRunTypes<TTasks>>(batchId);
}

/**
* Fetches a stream of data from a run's stream key.
*/
async function fetchStream<T>(runId: string, streamKey: string): Promise<AsyncIterableStream<T>> {
const apiClient = apiClientManager.clientOrThrow();

return await apiClient.fetchStream(runId, streamKey);
}

0 comments on commit 85b3605

Please sign in to comment.