Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add support for AbortSignal to all API methods
  • Loading branch information
aron committed Mar 26, 2025
commit 38bc74fa9525b81a6c8fd20b197999285a1121f7
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1215,13 +1215,19 @@ const response = await replicate.request(route, parameters);

| name | type | description |
| -------------------- | ------ | ------------------------------------------------------------ |
| `options.route` | string | Required. REST API endpoint path. |
| `options.parameters` | object | URL, query, and request body parameters for the given route. |
| `options.route` | `string` | Required. REST API endpoint path. |
| `options.params` | `object` | URL query parameters for the given route. |
| `options.method` | `string` | HTTP method for the given route. |
| `options.headers` | `object` | Additional HTTP headers for the given route. |
| `options.data` | `object | FormData` | Request body. |
| `options.signal` | `AbortSignal` | Optional `AbortSignal`. |

The `replicate.request()` method is used by the other methods
to interact with the Replicate API.
You can call this method directly to make other requests to the API.

The method accepts an `AbortSignal` which can be used to cancel the request in flight.

### `FileOutput`

`FileOutput` is a `ReadableStream` instance that represents a model file output. It can be used to stream file data to disk or as a `Response` body to an HTTP request.
Expand Down
108 changes: 76 additions & 32 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,15 @@ declare module "replicate" {
): Promise<Prediction>;

accounts: {
current(): Promise<Account>;
current(options?: { signal?: AbortSignal }): Promise<Account>;
};

collections: {
list(): Promise<Page<Collection>>;
get(collection_slug: string): Promise<Collection>;
list(options?: { signal?: AbortSignal }): Promise<Page<Collection>>;
get(
collection_slug: string,
options?: { signal?: AbortSignal }
): Promise<Collection>;
};

deployments: {
Expand All @@ -221,21 +224,26 @@ declare module "replicate" {
webhook?: string;
webhook_events_filter?: WebhookEventType[];
wait?: number | boolean;
signal?: AbortSignal;
}
): Promise<Prediction>;
};
get(
deployment_owner: string,
deployment_name: string
deployment_name: string,
options?: { signal?: AbortSignal }
): Promise<Deployment>;
create(
deployment_config: {
name: string;
model: string;
version: string;
hardware: string;
min_instances: number;
max_instances: number;
},
options?: { signal?: AbortSignal }
): Promise<Deployment>;
create(deployment_config: {
name: string;
model: string;
version: string;
hardware: string;
min_instances: number;
max_instances: number;
}): Promise<Deployment>;
update(
deployment_owner: string,
deployment_name: string,
Expand All @@ -249,32 +257,45 @@ declare module "replicate" {
| { hardware: string }
| { min_instances: number }
| { max_instances: number }
)
),
options?: { signal?: AbortSignal }
): Promise<Deployment>;
delete(
deployment_owner: string,
deployment_name: string
deployment_name: string,
options?: { signal?: AbortSignal }
): Promise<boolean>;
list(): Promise<Page<Deployment>>;
list(options?: { signal?: AbortSignal }): Promise<Page<Deployment>>;
};

files: {
create(
file: Blob | Buffer,
metadata?: Record<string, unknown>
metadata?: Record<string, unknown>,
options?: { signal?: AbortSignal }
): Promise<FileObject>;
list(): Promise<Page<FileObject>>;
get(file_id: string): Promise<FileObject>;
delete(file_id: string): Promise<boolean>;
list(options?: { signal?: AbortSignal }): Promise<Page<FileObject>>;
get(
file_id: string,
options?: { signal?: AbortSignal }
): Promise<FileObject>;
delete(
file_id: string,
options?: { signal?: AbortSignal }
): Promise<boolean>;
};

hardware: {
list(): Promise<Hardware[]>;
list(options?: { signal?: AbortSignal }): Promise<Hardware[]>;
};

models: {
get(model_owner: string, model_name: string): Promise<Model>;
list(): Promise<Page<Model>>;
get(
model_owner: string,
model_name: string,
options?: { signal?: AbortSignal }
): Promise<Model>;
list(options?: { signal?: AbortSignal }): Promise<Page<Model>>;
create(
model_owner: string,
model_name: string,
Expand All @@ -286,17 +307,26 @@ declare module "replicate" {
paper_url?: string;
license_url?: string;
cover_image_url?: string;
signal?: AbortSignal;
}
): Promise<Model>;
versions: {
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
list(
model_owner: string,
model_name: string,
options?: { signal?: AbortSignal }
): Promise<ModelVersion[]>;
get(
model_owner: string,
model_name: string,
version_id: string
version_id: string,
options?: { signal?: AbortSignal }
): Promise<ModelVersion>;
};
search(query: string): Promise<Page<Model>>;
search(
query: string,
options?: { signal?: AbortSignal }
): Promise<Page<Model>>;
};

predictions: {
Expand All @@ -310,11 +340,18 @@ declare module "replicate" {
webhook?: string;
webhook_events_filter?: WebhookEventType[];
wait?: boolean | number;
signal?: AbortSignal;
} & ({ version: string } | { model: string })
): Promise<Prediction>;
get(prediction_id: string): Promise<Prediction>;
cancel(prediction_id: string): Promise<Prediction>;
list(): Promise<Page<Prediction>>;
get(
prediction_id: string,
options?: { signal?: AbortSignal }
): Promise<Prediction>;
cancel(
prediction_id: string,
options?: { signal?: AbortSignal }
): Promise<Prediction>;
list(options?: { signal?: AbortSignal }): Promise<Page<Prediction>>;
};

trainings: {
Expand All @@ -327,17 +364,24 @@ declare module "replicate" {
input: object;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
signal?: AbortSignal;
}
): Promise<Training>;
get(training_id: string): Promise<Training>;
cancel(training_id: string): Promise<Training>;
list(): Promise<Page<Training>>;
get(
training_id: string,
options?: { signal?: AbortSignal }
): Promise<Training>;
cancel(
training_id: string,
options?: { signal?: AbortSignal }
): Promise<Training>;
list(options?: { signal?: AbortSignal }): Promise<Page<Training>>;
};

webhooks: {
default: {
secret: {
get(): Promise<WebhookSecret>;
get(options?: { signal?: AbortSignal }): Promise<WebhookSecret>;
};
};
};
Expand Down
5 changes: 4 additions & 1 deletion lib/accounts.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
/**
* Get the current account
*
* @param {object} [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} Resolves with the current account
*/
async function getCurrentAccount() {
async function getCurrentAccount({ signal } = {}) {
const response = await this.request("/account", {
method: "GET",
signal,
});

return response.json();
Expand Down
10 changes: 8 additions & 2 deletions lib/collections.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
* Fetch a model collection
*
* @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections
* @param {object} [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} - Resolves with the collection data
*/
async function getCollection(collection_slug) {
async function getCollection(collection_slug, { signal } = {}) {
const response = await this.request(`/collections/${collection_slug}`, {
method: "GET",
signal,
});

return response.json();
Expand All @@ -15,11 +18,14 @@ async function getCollection(collection_slug) {
/**
* Fetch a list of model collections
*
* @param {object} [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} - Resolves with the collections data
*/
async function listCollections() {
async function listCollections({ signal } = {}) {
const response = await this.request("/collections", {
method: "GET",
signal,
});

return response.json();
Expand Down
40 changes: 33 additions & 7 deletions lib/deployments.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ const { transformFileInputs } = require("./util");
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {boolean|integer} [options.wait] - Whether to wait until the prediction is completed before returning. If an integer is provided, it will wait for that many seconds. Defaults to false
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} Resolves with the created prediction data
*/
async function createPrediction(deployment_owner, deployment_name, options) {
const { input, wait, ...data } = options;
const { input, wait, signal, ...data } = options;

if (data.webhook) {
try {
Expand Down Expand Up @@ -47,6 +48,7 @@ async function createPrediction(deployment_owner, deployment_name, options) {
this.fileEncodingStrategy
),
},
signal,
}
);

Expand All @@ -58,13 +60,20 @@ async function createPrediction(deployment_owner, deployment_name, options) {
*
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
* @param {string} deployment_name - Required. The name of the deployment
* @param {object] [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} Resolves with the deployment data
*/
async function getDeployment(deployment_owner, deployment_name) {
async function getDeployment(
deployment_owner,
deployment_name,
{ signal } = {}
) {
const response = await this.request(
`/deployments/${deployment_owner}/${deployment_name}`,
{
method: "GET",
signal,
}
);

Expand All @@ -84,13 +93,16 @@ async function getDeployment(deployment_owner, deployment_name) {
/**
* Create a deployment
*
* @param {DeploymentCreateRequest} config - Required. The deployment config.
* @param {DeploymentCreateRequest} deployment_config - Required. The deployment config.
* @param {object} [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} Resolves with the deployment data
*/
async function createDeployment(deployment_config) {
async function createDeployment(deployment_config, { signal } = {}) {
const response = await this.request("/deployments", {
method: "POST",
data: deployment_config,
signal,
});

return response.json();
Expand All @@ -110,18 +122,22 @@ async function createDeployment(deployment_config) {
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
* @param {string} deployment_name - Required. The name of the deployment
* @param {DeploymentUpdateRequest} deployment_config - Required. The deployment changes.
* @param {object} [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} Resolves with the deployment data
*/
async function updateDeployment(
deployment_owner,
deployment_name,
deployment_config
deployment_config,
{ signal } = {}
) {
const response = await this.request(
`/deployments/${deployment_owner}/${deployment_name}`,
{
method: "PATCH",
data: deployment_config,
signal,
}
);

Expand All @@ -133,13 +149,20 @@ async function updateDeployment(
*
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
* @param {string} deployment_name - Required. The name of the deployment
* @param {object} [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<boolean>} Resolves with true if the deployment was deleted
*/
async function deleteDeployment(deployment_owner, deployment_name) {
async function deleteDeployment(
deployment_owner,
deployment_name,
{ signal } = {}
) {
const response = await this.request(
`/deployments/${deployment_owner}/${deployment_name}`,
{
method: "DELETE",
signal,
}
);

Expand All @@ -149,11 +172,14 @@ async function deleteDeployment(deployment_owner, deployment_name) {
/**
* List all deployments
*
* @param {object} [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} - Resolves with a page of deployments
*/
async function listDeployments() {
async function listDeployments({ signal } = {}) {
const response = await this.request("/deployments", {
method: "GET",
signal,
});

return response.json();
Expand Down
Loading