diff --git a/README.md b/README.md index d6f0257..36fbcb0 100644 --- a/README.md +++ b/README.md @@ -67,8 +67,8 @@ console.log(prediction.output); // ['https://replicate.delivery/pbxt/RoaxeXqhL0xaYyLm6w3bpGwF5RaNBjADukfFnMbhOyeoWBdhA/out-0.png'] ``` -To run a model that takes a file input, -convert its data into a base64-encoded data URI: +To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can convert file data into a base64-encoded data URI and pass that directly: + ```js import { promises as fs } from "fs"; @@ -174,6 +174,8 @@ const output = await replicate.run(model, { input }); ### `replicate.models.get` +Get metadata for a public model or a private model that you own. + ```js const response = await replicate.models.get(model_owner, model_name); ``` @@ -199,8 +201,45 @@ const response = await replicate.models.get(model_owner, model_name); } ``` +### `replicate.models.list` + +Get a paginated list of all public models. + +```js +const response = await replicate.models.list(); +``` + +```jsonc +{ + "next": null, + "previous": null, + "results": [ + { + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": "https://github.com/replicate/cog-examples", + "paper_url": null, + "license_url": null, + "run_count": 5681081, + "cover_image_url": "...", + "default_example": { + /* ... */ + }, + "latest_version": { + /* ... */ + } + } + ] +} +``` + ### `replicate.models.versions.list` +Get a list of all published versions of a model, including input and output schemas for each version. + ```js const response = await replicate.models.versions.list(model_owner, model_name); ``` @@ -237,6 +276,8 @@ const response = await replicate.models.versions.list(model_owner, model_name); ### `replicate.models.versions.get` +Get metatadata for a specific version of a model. + ```js const response = await replicate.models.versions.get(model_owner, model_name, version_id); ``` @@ -260,6 +301,8 @@ const response = await replicate.models.versions.get(model_owner, model_name, ve ### `replicate.collections.get` +Get a list of curated model collections. See [replicate.com/collections](https://replicate.com/collections). + ```js const response = await replicate.collections.get(collection_slug); ``` @@ -270,6 +313,8 @@ const response = await replicate.collections.get(collection_slug); ### `replicate.predictions.create` +Run a model with inputs you provide. + ```js const response = await replicate.predictions.create(options); ``` @@ -380,6 +425,8 @@ const response = await replicate.predictions.get(prediction_id); ### `replicate.predictions.cancel` +Stop a running prediction before it finishes. + ```js const response = await replicate.predictions.cancel(prediction_id); ``` @@ -412,6 +459,8 @@ const response = await replicate.predictions.cancel(prediction_id); ### `replicate.predictions.list` +Get a paginated list of all the predictions you've created. + ```js const response = await replicate.predictions.list(); ``` @@ -443,7 +492,7 @@ const response = await replicate.predictions.list(); ### `replicate.trainings.create` -Use the training API to fine-tune language models +Use the [training API](https://replicate.com/docs/fine-tuning) to fine-tune language models to make them better at a particular task. To see what **language models** currently support fine-tuning, check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models). @@ -488,6 +537,8 @@ const response = await replicate.trainings.create(model_owner, model_name, versi ### `replicate.trainings.get` +Get metadata and status of a training. + ```js const response = await replicate.trainings.get(training_id); ``` @@ -519,6 +570,8 @@ const response = await replicate.trainings.get(training_id); ### `replicate.trainings.cancel` +Stop a running training job before it finishes. + ```js const response = await replicate.trainings.cancel(training_id); ``` @@ -550,6 +603,8 @@ const response = await replicate.trainings.cancel(training_id); ### `replicate.trainings.list` +Get a paginated list of all the trainings you've run. + ```js const response = await replicate.trainings.list(); ``` @@ -581,6 +636,10 @@ const response = await replicate.trainings.list(); ### `replicate.deployments.predictions.create` +Run a model using your own custom deployment. + +Deployments allow you to run a model with a private, fixed API endpoint. You can configure the version of the model, the hardware it runs on, and how it scales. See the [deployments guide](https://replicate.com/docs/deployments) to learn more and get started. + ```js const response = await replicate.deployments.predictions.create(deployment_owner, deployment_name, options); ``` @@ -620,6 +679,8 @@ const page2 = await paginator.next(); ### `replicate.request` +Low-level method used by the Replicate client to interact with API endpoints. + ```js const response = await replicate.request(route, parameters); ``` diff --git a/index.d.ts b/index.d.ts index 32f279a..601e15b 100644 --- a/index.d.ts +++ b/index.d.ts @@ -117,6 +117,7 @@ declare module 'replicate' { models: { get(model_owner: string, model_name: string): Promise; + list(): Promise>; versions: { list(model_owner: string, model_name: string): Promise; get( diff --git a/index.js b/index.js index c6a2cc2..4f74985 100644 --- a/index.js +++ b/index.js @@ -51,6 +51,7 @@ class Replicate { this.models = { get: models.get.bind(this), + list: models.list.bind(this), versions: { list: models.versions.list.bind(this), get: models.versions.get.bind(this), diff --git a/index.test.ts b/index.test.ts index c35b41f..ab4e9d6 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,5 +1,5 @@ import { expect, jest, test } from '@jest/globals'; -import Replicate, { ApiError, Prediction } from 'replicate'; +import Replicate, { ApiError, Model, Prediction } from 'replicate'; import nock from 'nock'; import fetch from 'cross-fetch'; @@ -131,6 +131,30 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + describe('models.list', () => { + test('Paginates results', async () => { + nock(BASE_URL) + .get('/models') + .reply(200, { + results: [{ url: 'https://replicate.com/some-user/model-1' }], + next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', + }) + .get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw') + .reply(200, { + results: [{ url: 'https://replicate.com/some-user/model-2' }], + next: null, + }); + + const results: Model[] = []; + for await (const batch of client.paginate(client.models.list)) { + results.push(...batch); + } + expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]); + + // Add more tests for error handling, edge cases, etc. + }); + }); + describe('predictions.create', () => { test('Calls the correct API route with the correct payload', async () => { nock(BASE_URL) diff --git a/lib/models.js b/lib/models.js index 373ed23..be05750 100644 --- a/lib/models.js +++ b/lib/models.js @@ -37,17 +37,28 @@ async function listModelVersions(model_owner, model_name) { * @returns {Promise} Resolves with the model version data */ async function getModelVersion(model_owner, model_name, version_id) { - const response = await this.request( - `/models/${model_owner}/${model_name}/versions/${version_id}`, - { - method: 'GET', - } - ); + const response = await this.request(`/models/${model_owner}/${model_name}/versions/${version_id}`, { + method: 'GET', + }); + + return response.json(); +} + +/** + * List all public models + * + * @returns {Promise} Resolves with the model version data + */ +async function listModels() { + const response = await this.request('/models', { + method: 'GET', + }); return response.json(); } module.exports = { get: getModel, + list: listModels, versions: { list: listModelVersions, get: getModelVersion }, }; diff --git a/package-lock.json b/package-lock.json index 989f02f..145bcb1 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "replicate", - "version": "0.19.0", + "version": "0.20.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "replicate", - "version": "0.19.0", + "version": "0.20.0", "license": "Apache-2.0", "devDependencies": { "@types/jest": "^29.5.3", diff --git a/package.json b/package.json index 131663d..c032728 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "replicate", - "version": "0.19.0", + "version": "0.20.0", "description": "JavaScript client for Replicate", "repository": "github:replicate/replicate-javascript", "homepage": "https://github.com/replicate/replicate-javascript#readme",