diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ab1d859..0b5e894 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,3 +32,26 @@ This will: - Push the commit and tag to GitHub - Publish the package to npm - Create a GitHub release + +## Vendored Dependencies + +We have a few dependencies that have been bundled into the vendor directory rather than adding external npm dependencies. + +These have been generated using bundlejs.com and copied into the appropriate directory along with the license and repository information. + +* [eventsource-parser/stream](https://bundlejs.com/?bundle&q=eventsource-parser%40latest%2Fstream&config=%7B%22esbuild%22%3A%7B%22format%22%3A%22cjs%22%2C%22minify%22%3Afalse%2C%22platform%22%3A%22neutral%22%7D%7D) +* [streams-text-encoding/text-decoder-stream](https://bundlejs.com/?q=%40stardazed%2Fstreams-text-encoding&treeshake=%5B%7B+TextDecoderStream+%7D%5D&config=%7B%22esbuild%22%3A%7B%22format%22%3A%22cjs%22%2C%22minify%22%3Afalse%7D%7D) + +> [!NOTE] +> The vendored implementation of `TextDecoderStream` requires +> the following patch to be applied to the output of bundlejs.com: +> +> ```diff +> constructor(label, options) { +> - this[decDecoder] = new TextDecoder(label, options); +> - this[decTransform] = new TransformStream(new TextDecodeTransformer(this[decDecoder])); +> + const decoder = new TextDecoder(label || "utf-8", options || {}); +> + this[decDecoder] = decoder; +> + this[decTransform] = new TransformStream(new TextDecodeTransformer(decoder)); +> } +> ``` diff --git a/README.md b/README.md index f6d798c..29c31c6 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ and everything else you can do with ## Installation +This library requires Node.js >= 18. + Install it from npm: ```bash @@ -20,7 +22,7 @@ npm install replicate ## Usage -Create the client: +Import or require the package: ```js // CommonJS (default or using .cjs extension) @@ -30,9 +32,11 @@ const Replicate = require("replicate"); import Replicate from "replicate"; ``` -``` +Instantiate the client: + +```js const replicate = new Replicate({ - // get your token from https://replicate.com/account + // get your token from https://replicate.com/account/api-tokens auth: "my api token", // defaults to process.env.REPLICATE_API_TOKEN }); ``` @@ -148,8 +152,53 @@ await replicate.predictions.create({ // => {"id": "xyz", "status": "successful", ... } ``` +## Verifying webhooks + +To prevent unauthorized requests, Replicate signs every webhook and its metadata with a unique key for each user or organization. You can use this signature to verify the webhook indeed comes from Replicate before you process it. + +This client includes a `validateWebhook` convenience function that you can use to validate webhooks. + +To validate webhooks: + +1. Check out the [webhooks guide](https://replicate.com/docs/webhooks) to get started. +1. [Retrieve your webhook signing secret](https://replicate.com/docs/webhooks#retrieving-the-webhook-signing-key) and store it in your enviroment. +1. Update your webhook handler to call `validateWebhook(request, secret)`, where `request` is an instance of a [web-standard `Request` object](https://developer.mozilla.org/en-US/docs/Web/API/object, and `secret` is the signing secret for your environment. + +Here's an example of how to validate webhooks using Next.js: + +```js +import { NextResponse } from 'next/server'; +import { validateWebhook } from 'replicate'; + +export async function POST(request) { + const secret = process.env.REPLICATE_WEBHOOK_SIGNING_SECRET; + + if (!secret) { + console.log("Skipping webhook validation. To validate webhooks, set REPLICATE_WEBHOOK_SIGNING_SECRET") + const body = await request.json(); + console.log(body); + return NextResponse.json({ detail: "Webhook received (but not validated)" }, { status: 200 }); + } + + const webhookIsValid = await validateWebhook(request.clone(), secret); + + if (!webhookIsValid) { + return NextResponse.json({ detail: "Webhook is invalid" }, { status: 401 }); + } + + // process validated webhook here... + console.log("Webhook is valid!"); + const body = await request.json(); + console.log(body); + + return NextResponse.json({ detail: "Webhook is valid" }, { status: 200 }); +} +``` + ## TypeScript +The `Replicate` constructor and all `replicate.*` methods are fully typed. + Currently in order to support the module format used by `replicate` you'll need to set `esModuleInterop` to `true` in your tsconfig.json. ## API @@ -973,29 +1022,17 @@ The `replicate.request()` method is used by the other methods to interact with the Replicate API. You can call this method directly to make other requests to the API. -## TypeScript - -The `Replicate` constructor and all `replicate.*` methods are fully typed. - -## Vendored Dependencies +## Troubleshooting -We have a few dependencies that have been bundled into the vendor directory rather than adding external npm dependencies. +### Predictions hanging in Next.js -These have been generated using bundlejs.com and copied into the appropriate directory along with the license and repository information. +Next.js App Router adds some extensions to `fetch` to make it cache responses. To disable this behavior, set the `cache` option to `"no-store"` on the Replicate client's fetch object: -* [eventsource-parser/stream](https://bundlejs.com/?bundle&q=eventsource-parser%40latest%2Fstream&config=%7B%22esbuild%22%3A%7B%22format%22%3A%22cjs%22%2C%22minify%22%3Afalse%2C%22platform%22%3A%22neutral%22%7D%7D) -* [streams-text-encoding/text-decoder-stream](https://bundlejs.com/?q=%40stardazed%2Fstreams-text-encoding&treeshake=%5B%7B+TextDecoderStream+%7D%5D&config=%7B%22esbuild%22%3A%7B%22format%22%3A%22cjs%22%2C%22minify%22%3Afalse%7D%7D) +```js +replicate = new Replicate({/*...*/}) +replicate.fetch = (url, options) => { + return fetch(url, { ...options, cache: "no-store" }); +}; +``` -> [!NOTE] -> The vendored implementation of `TextDecoderStream` requires -> the following patch to be applied to the output of bundlejs.com: -> -> ```diff -> constructor(label, options) { -> - this[decDecoder] = new TextDecoder(label, options); -> - this[decTransform] = new TransformStream(new TextDecodeTransformer(this[decDecoder])); -> + const decoder = new TextDecoder(label || "utf-8", options || {}); -> + this[decDecoder] = decoder; -> + this[decTransform] = new TransformStream(new TextDecodeTransformer(decoder)); -> } -> ``` +Alternatively you can use Next.js [`noStore`](https://github.com/replicate/replicate-javascript/issues/136#issuecomment-1847442879) to opt out of caching for your component. diff --git a/biome.json b/biome.json index ecb665f..094cf0e 100644 --- a/biome.json +++ b/biome.json @@ -1,7 +1,11 @@ { "$schema": "https://biomejs.dev/schemas/1.0.0/schema.json", "files": { - "ignore": [".wrangler", "vendor/*"] + "ignore": [ + ".wrangler", + "node_modules", + "vendor/*" + ] }, "formatter": { "indentStyle": "space", diff --git a/index.d.ts b/index.d.ts index 31a2325..1ef9e89 100644 --- a/index.d.ts +++ b/index.d.ts @@ -39,6 +39,21 @@ declare module "replicate" { }; } + export interface FileObject { + id: string; + name: string; + content_type: string; + size: number; + etag: string; + checksum: string; + metadata: Record; + created_at: string; + expires_at: string | null; + urls: { + get: string; + }; + } + export interface Hardware { sku: string; name: string; @@ -93,6 +108,8 @@ declare module "replicate" { export type Training = Prediction; + export type FileEncodingStrategy = "default" | "upload" | "data-uri"; + export interface Page { previous?: string; next?: string; @@ -119,12 +136,14 @@ declare module "replicate" { input: Request | string, init?: RequestInit ) => Promise; + fileEncodingStrategy?: FileEncodingStrategy; }); auth: string; userAgent?: string; baseUrl?: string; fetch: (input: Request | string, init?: RequestInit) => Promise; + fileEncodingStrategy: FileEncodingStrategy; run( identifier: `${string}/${string}` | `${string}/${string}:${string}`, diff --git a/index.js b/index.js index dd60cec..bf53449 100644 --- a/index.js +++ b/index.js @@ -46,6 +46,7 @@ class Replicate { * @param {string} options.userAgent - Identifier of your app * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` + * @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use */ constructor(options = {}) { this.auth = @@ -55,6 +56,7 @@ class Replicate { options.userAgent || `replicate-javascript/${packageJSON.version}`; this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; this.fetch = options.fetch || globalThis.fetch; + this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default"; this.accounts = { current: accounts.current.bind(this), @@ -218,22 +220,32 @@ class Replicate { url.searchParams.append(key, value); } - const headers = {}; + const headers = { + "Content-Type": "application/json", + "User-Agent": userAgent, + }; if (auth) { headers["Authorization"] = `Bearer ${auth}`; } - headers["Content-Type"] = "application/json"; - headers["User-Agent"] = userAgent; if (options.headers) { for (const [key, value] of Object.entries(options.headers)) { headers[key] = value; } } + let body = undefined; + if (data instanceof FormData) { + body = data; + // biome-ignore lint/performance/noDelete: + delete headers["Content-Type"]; // Use automatic content type header + } else if (data) { + body = JSON.stringify(data); + } + const init = { method, headers, - body: data ? JSON.stringify(data) : undefined, + body, }; const shouldRetry = diff --git a/index.test.ts b/index.test.ts index 53737e0..8ecb5ae 100644 --- a/index.test.ts +++ b/index.test.ts @@ -185,50 +185,92 @@ describe("Replicate client", () => { }); describe("predictions.create", () => { - test("Calls the correct API route with the correct payload", async () => { - nock(BASE_URL) - .post("/predictions") - .reply(200, { - id: "ufawqhfynnddngldkgtslldrkq", - model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - urls: { - get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", - }, - created_at: "2022-04-26T22:13:06.224088Z", - started_at: null, - completed_at: null, - status: "starting", - input: { - text: "Alice", + const predictionTestCases = [ + { + description: "String input", + input: { + text: "Alice", + }, + }, + { + description: "Number input", + input: { + text: 123, + }, + }, + { + description: "Boolean input", + input: { + text: true, + }, + }, + { + description: "Array input", + input: { + text: ["Alice", "Bob", "Charlie"], + }, + }, + { + description: "Object input", + input: { + text: { + name: "Alice", }, - output: null, - error: null, - logs: null, - metrics: {}, - }); - const prediction = await client.predictions.create({ + }, + }, + ].map((testCase) => ({ + ...testCase, + expectedResponse: { + id: "ufawqhfynnddngldkgtslldrkq", + model: "replicate/hello-world", version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: { - text: "Alice", + urls: { + get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + cancel: + "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - }); - expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq"); - }); + input: testCase.input, + created_at: "2022-04-26T22:13:06.224088Z", + started_at: null, + completed_at: null, + status: "starting", + }, + })); - test.each([ + test.each(predictionTestCases)( + "$description", + async ({ input, expectedResponse }) => { + nock(BASE_URL) + .post("/predictions", { + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }) + .reply(200, expectedResponse); + + const response = await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }); + + expect(response.input).toEqual(input); + expect(response.status).toBe(expectedResponse.status); + } + ); + + const fileTestCases = [ // Skip test case if File type is not available ...(typeof File !== "undefined" ? [ { type: "file", - value: new File(["hello world"], "hello.txt", { + value: new File(["hello world"], "file_hello.txt", { type: "text/plain", }), expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=", @@ -245,11 +287,63 @@ describe("Replicate client", () => { value: Buffer.from("hello world"), expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=", }, - ])( + ]; + + test.each(fileTestCases)( + "converts a $type input into a Replicate file URL", + async ({ value: data, type }) => { + const mockedFetch = jest.spyOn(client, "fetch"); + + nock(BASE_URL) + .post("/files") + .reply(201, { + urls: { + get: "https://replicate.com/api/files/123", + }, + }) + .post( + "/predictions", + (body) => body.input.data === "https://replicate.com/api/files/123" + ) + .reply(201, (_uri: string, body: Record) => { + return body; + }); + + const prediction = await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { + prompt: "Tell me a story", + data, + }, + }); + + expect(client.fetch).toHaveBeenCalledWith( + new URL("https://api.replicate.com/v1/files"), + { + method: "POST", + body: expect.any(FormData), + headers: expect.any(Object), + } + ); + const form = mockedFetch.mock.calls[0][1]?.body as FormData; + // @ts-ignore + expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`)); + + expect(prediction.input).toEqual({ + prompt: "Tell me a story", + data: "https://replicate.com/api/files/123", + }); + } + ); + + test.each(fileTestCases)( "converts a $type input into a base64 encoded string", async ({ value: data, expected }) => { let actual: Record | undefined; nock(BASE_URL) + .post("/files") + .reply(503, "Service Unavailable") .post("/predictions") .reply(201, (_uri: string, body: Record) => { actual = body; diff --git a/lib/deployments.js b/lib/deployments.js index 4f6f3c6..27a2f6a 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -30,7 +30,11 @@ async function createPrediction(deployment_owner, deployment_name, options) { method: "POST", data: { ...data, - input: await transformFileInputs(input), + input: await transformFileInputs( + this, + input, + this.fileEncodingStrategy + ), stream, }, } diff --git a/lib/files.js b/lib/files.js new file mode 100644 index 0000000..520f2aa --- /dev/null +++ b/lib/files.js @@ -0,0 +1,90 @@ +/** + * Create a file + * + * @param {object} file - Required. The file object. + * @param {object} metadata - Optional. User-provided metadata associated with the file. + * @returns {Promise} - Resolves with the file data + */ +async function createFile(file, metadata = {}) { + const form = new FormData(); + + let filename; + let blob; + if (file instanceof Blob) { + filename = file.name || `blob_${Date.now()}`; + blob = file; + } else if (Buffer.isBuffer(file)) { + filename = `buffer_${Date.now()}`; + const bytes = new Uint8Array(file); + blob = new Blob([bytes], { + type: "application/octet-stream", + name: filename, + }); + } else { + throw new Error("Invalid file argument, must be a Blob, File or Buffer"); + } + + form.append("content", blob, filename); + form.append( + "metadata", + new Blob([JSON.stringify(metadata)], { type: "application/json" }) + ); + + const response = await this.request("/files", { + method: "POST", + data: form, + headers: { + "Content-Type": "multipart/form-data", + }, + }); + + return response.json(); +} + +/** + * List all files + * + * @returns {Promise} - Resolves with the files data + */ +async function listFiles() { + const response = await this.request("/files", { + method: "GET", + }); + + return response.json(); +} + +/** + * Get a file + * + * @param {string} file_id - Required. The ID of the file. + * @returns {Promise} - Resolves with the file data + */ +async function getFile(file_id) { + const response = await this.request(`/files/${file_id}`, { + method: "GET", + }); + + return response.json(); +} + +/** + * Delete a file + * + * @param {string} file_id - Required. The ID of the file. + * @returns {Promise} - Resolves with the deletion confirmation + */ +async function deleteFile(file_id) { + const response = await this.request(`/files/${file_id}`, { + method: "DELETE", + }); + + return response.json(); +} + +module.exports = { + create: createFile, + list: listFiles, + get: getFile, + delete: deleteFile, +}; diff --git a/lib/predictions.js b/lib/predictions.js index 5b0370e..c290d40 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -30,7 +30,11 @@ async function createPrediction(options) { method: "POST", data: { ...data, - input: await transformFileInputs(input), + input: await transformFileInputs( + this, + input, + this.fileEncodingStrategy + ), version, stream, }, @@ -40,7 +44,11 @@ async function createPrediction(options) { method: "POST", data: { ...data, - input: await transformFileInputs(input), + input: await transformFileInputs( + this, + input, + this.fileEncodingStrategy + ), stream, }, }); diff --git a/lib/util.js b/lib/util.js index 68b1d9d..3745d9f 100644 --- a/lib/util.js +++ b/lib/util.js @@ -1,4 +1,5 @@ const ApiError = require("./error"); +const { create: createFile } = require("./files"); /** * @see {@link validateWebhook} @@ -209,12 +210,58 @@ async function withAutomaticRetries(request, options = {}) { } attempts += 1; } - /* eslint-enable no-await-in-loop */ } while (attempts < maxRetries); return request(); } +/** + * Walks the inputs and, for any File or Blob, tries to upload it to Replicate + * and replaces the input with the URL of the uploaded file. + * + * @param {Replicate} client - The client used to upload the file + * @param {object} inputs - The inputs to transform + * @param {"default" | "upload" | "data-uri"} strategy - Whether to upload files to Replicate, encode as dataURIs or try both. + * @returns {object} - The transformed inputs + * @throws {ApiError} If the request to upload the file fails + */ +async function transformFileInputs(client, inputs, strategy) { + switch (strategy) { + case "data-uri": + return await transformFileInputsToBase64EncodedDataURIs(client, inputs); + case "upload": + return await transformFileInputsToReplicateFileURLs(client, inputs); + case "default": + try { + return await transformFileInputsToReplicateFileURLs(client, inputs); + } catch (error) { + return await transformFileInputsToBase64EncodedDataURIs(inputs); + } + default: + throw new Error(`Unexpected file upload strategy: ${strategy}`); + } +} + +/** + * Walks the inputs and, for any File or Blob, tries to upload it to Replicate + * and replaces the input with the URL of the uploaded file. + * + * @param {Replicate} client - The client used to upload the file + * @param {object} inputs - The inputs to transform + * @returns {object} - The transformed inputs + * @throws {ApiError} If the request to upload the file fails + */ +async function transformFileInputsToReplicateFileURLs(client, inputs) { + return await transform(inputs, async (value) => { + if (value instanceof Blob || value instanceof Buffer) { + const file = await createFile.call(client, value); + return file.urls.get; + } + + return value; + }); +} + const MAX_DATA_URI_SIZE = 10_000_000; /** @@ -225,9 +272,9 @@ const MAX_DATA_URI_SIZE = 10_000_000; * @returns {object} - The transformed inputs * @throws {Error} If the size of inputs exceeds a given threshould set by MAX_DATA_URI_SIZE */ -async function transformFileInputs(inputs) { +async function transformFileInputsToBase64EncodedDataURIs(inputs) { let totalBytes = 0; - const result = await transform(inputs, async (value) => { + return await transform(inputs, async (value) => { let buffer; let mime; @@ -258,16 +305,15 @@ async function transformFileInputs(inputs) { return `data:${mime};base64,${data}`; }); - - return result; } // Walk a JavaScript object and transform the leaf values. async function transform(value, mapper) { if (Array.isArray(value)) { - let copy = []; + const copy = []; for (const val of value) { - copy = await transform(val, mapper); + const transformed = await transform(val, mapper); + copy.push(transformed); } return copy; } diff --git a/package-lock.json b/package-lock.json index 955a42d..710a1e9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "replicate", - "version": "0.29.4", + "version": "0.30.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "replicate", - "version": "0.29.4", + "version": "0.30.2", "license": "Apache-2.0", "devDependencies": { "@biomejs/biome": "^1.4.1", diff --git a/package.json b/package.json index 3550449..f31fb22 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "replicate", - "version": "0.29.4", + "version": "0.30.2", "description": "JavaScript client for Replicate", "repository": "github:replicate/replicate-javascript", "homepage": "https://github.com/replicate/replicate-javascript#readme",