Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 82 additions & 10 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,19 @@ import {
type ToolResponsePart,
} from './model.js';
import { isExecutablePrompt } from './prompt.js';
import { isDynamicResourceAction, ResourceAction } from './resource.js';
import {
isDynamicResourceAction,
resolveResources,
ResourceAction,
ResourceArgument,
} from './resource.js';
import {
isDynamicTool,
resolveTools,
toToolDefinition,
type ToolArgument,
} from './tool.js';

export { GenerateResponse, GenerateResponseChunk };

/** Specifies how tools should be called by the model. */
Expand Down Expand Up @@ -121,7 +127,7 @@ export interface GenerateOptions<
/** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */
tools?: ToolArgument[];
/** List of dynamic resources to be made available to this generate request. */
resources?: ResourceAction[];
resources?: ResourceArgument[];
/** Specifies how tools should be called by the model. */
toolChoice?: ToolChoice;
/** Configuration for the generation request. */
Expand Down Expand Up @@ -222,6 +228,10 @@ export async function toGenerateRequest(
if (options.tools) {
tools = await resolveTools(registry, options.tools);
}
let resources: ResourceAction[] | undefined;
if (options.resources) {
resources = await resolveResources(registry, options.resources);
}

const resolvedSchema = toJsonSchema({
schema: options.output?.schema,
Expand All @@ -245,6 +255,7 @@ export async function toGenerateRequest(
config: options.config,
docs: options.docs,
tools: tools?.map(toToolDefinition) || [],
resources: resources?.map((a) => a.__action) || [],
output: {
...(resolvedFormat?.config || {}),
...options.output,
Expand Down Expand Up @@ -285,7 +296,8 @@ async function toolsToActionRefs(

for (const t of toolOpt) {
if (typeof t === 'string') {
tools.push(await resolveFullToolName(registry, t));
const names = await resolveFullToolNames(registry, t);
tools.push(...names);
} else if (isAction(t) || isDynamicTool(t)) {
tools.push(`/${t.__action.metadata?.type}/${t.__action.name}`);
} else if (isExecutablePrompt(t)) {
Expand All @@ -298,6 +310,27 @@ async function toolsToActionRefs(
return tools;
}

async function resourcesToActionRefs(
registry: Registry,
resOpt?: ResourceArgument[]
): Promise<string[] | undefined> {
if (!resOpt) return;

const resources: string[] = [];

for (const r of resOpt) {
if (typeof r === 'string') {
const names = await resolveFullResourceNames(registry, r);
resources.push(...names);
} else if (isAction(r)) {
resources.push(`/resource/${r.__action.name}`);
} else {
throw new Error(`Unable to resolve resource: ${JSON.stringify(r)}`);
}
}
return resources;
}

function messagesFromOptions(options: GenerateOptions): MessageData[] {
const messages: MessageData[] = [];
if (options.system) {
Expand Down Expand Up @@ -358,6 +391,10 @@ export async function generate<
const params = await toGenerateActionOptions(registry, resolvedOptions);

const tools = await toolsToActionRefs(registry, resolvedOptions.tools);
const resources = await resourcesToActionRefs(
registry,
resolvedOptions.resources
);
const streamingCallback = stripNoop(
resolvedOptions.onChunk ?? resolvedOptions.streamingCallback
) as StreamingCallback<GenerateResponseChunkData>;
Expand All @@ -372,6 +409,7 @@ export async function generate<
const request = await toGenerateRequest(registry, {
...resolvedOptions,
tools,
resources,
});
return new GenerateResponse<O>(response, {
request: response.request ?? request,
Expand Down Expand Up @@ -458,6 +496,7 @@ export async function toGenerateActionOptions<
): Promise<GenerateActionOptions> {
const resolvedModel = await resolveModel(registry, options.model);
const tools = await toolsToActionRefs(registry, options.tools);
const resources = await resourcesToActionRefs(registry, options.resources);
const messages: MessageData[] = messagesFromOptions(options);

const resolvedSchema = toJsonSchema({
Expand All @@ -478,6 +517,7 @@ export async function toGenerateActionOptions<
docs: options.docs,
messages: messages,
tools,
resources,
toolChoice: options.toolChoice,
config: {
version: resolvedModel.version,
Expand Down Expand Up @@ -530,17 +570,49 @@ function stripUndefinedOptions(input?: any): any {
return copy;
}

async function resolveFullToolName(
async function resolveFullToolNames(
registry: Registry,
name: string
): Promise<string> {
): Promise<string[]> {
let names: string[];
const parts = name.split(':');
if (parts.length > 1) {
// Dynamic Action Provider
names = await registry.resolveActionNames(
`/dynamic-action-provider/${name}`
);
if (names.length) {
return names;
}
}
if (await registry.lookupAction(`/tool/${name}`)) {
return `/tool/${name}`;
} else if (await registry.lookupAction(`/prompt/${name}`)) {
return `/prompt/${name}`;
} else {
throw new Error(`Unable to determine type of of tool: ${name}`);
return [`/tool/${name}`];
}
if (await registry.lookupAction(`/prompt/${name}`)) {
return [`/prompt/${name}`];
}
throw new Error(`Unable to resolve tool: ${name}`);
}

async function resolveFullResourceNames(
registry: Registry,
name: string
): Promise<string[]> {
let names: string[];
const parts = name.split(':');
if (parts.length > 1) {
// Dynamic Action Provider
names = await registry.resolveActionNames(
`/dynamic-action-provider/${name}`
);
if (names.length) {
return names;
}
}
if (await registry.lookupAction(`/resource/${name}`)) {
return [`/resource/${name}`];
}
throw new Error(`Unable to resolve resource: ${name}`);
}

export type GenerateStreamOptions<
Expand Down
25 changes: 18 additions & 7 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ import {
type Part,
type Role,
} from '../model.js';
import { findMatchingResource } from '../resource.js';
import {
findMatchingResource,
resolveResources,
type ResourceAction,
} from '../resource.js';
import { resolveTools, toToolDefinition, type ToolAction } from '../tool.js';
import {
assertValidToolNames,
Expand Down Expand Up @@ -151,14 +155,15 @@ async function resolveParameters(
registry: Registry,
request: GenerateActionOptions
) {
const [model, tools, format] = await Promise.all([
const [model, tools, resources, format] = await Promise.all([
resolveModel(registry, request.model, { warnDeprecated: true }).then(
(r) => r.modelAction
),
resolveTools(registry, request.tools),
resolveResources(registry, request.resources),
resolveFormat(registry, request.output),
]);
return { model, tools, format };
return { model, tools, resources, format };
}

/** Given a raw request and a formatter, apply the formatter's logic and instructions to the request. */
Expand Down Expand Up @@ -246,12 +251,12 @@ async function generate(
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
}
): Promise<GenerateResponseData> {
const { model, tools, format } = await resolveParameters(
const { model, tools, resources, format } = await resolveParameters(
registry,
rawRequest
);
rawRequest = applyFormat(rawRequest, format);
rawRequest = await applyResources(registry, rawRequest);
rawRequest = await applyResources(registry, rawRequest, resources);

// check to make sure we don't have overlapping tool names *before* generation
await assertValidToolNames(tools);
Expand Down Expand Up @@ -481,7 +486,8 @@ function getRoleFromPart(part: Part): Role {

async function applyResources(
registry: Registry,
rawRequest: GenerateActionOptions
rawRequest: GenerateActionOptions,
resources: ResourceAction[]
): Promise<GenerateActionOptions> {
// quick check, if no resources bail.
if (!rawRequest.messages.find((m) => !!m.content.find((c) => c.resource))) {
Expand All @@ -500,7 +506,12 @@ async function applyResources(
updatedContent.push(p);
continue;
}
const resource = await findMatchingResource(registry, p.resource);

const resource = await findMatchingResource(
registry,
resources,
p.resource
);
if (!resource) {
throw new GenkitError({
status: 'NOT_FOUND',
Expand Down
2 changes: 2 additions & 0 deletions js/ai/src/model-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ export const GenerateActionOptionsSchema = z.object({
messages: z.array(MessageSchema),
/** List of registered tool names for this generation if supported by the underlying model. */
tools: z.array(z.string()).optional(),
/** List of registered resource names for this generation if supported by the underlying model. */
resources: z.array(z.string()).optional(),
/** Tool calling mode. `auto` lets the model decide whether to use tools, `required` forces the model to choose a tool, and `none` forces the model not to use any tools. Defaults to `auto`. */
toolChoice: z.enum(['auto', 'required', 'none']).optional(),
/** Configuration for the generation request. */
Expand Down
62 changes: 59 additions & 3 deletions js/ai/src/resource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,45 @@ export interface ResourceAction
matches(input: ResourceInput): boolean;
}

/**
* A reference to a resource in the form of a name or a ResourceAction.
*/
export type ResourceArgument = ResourceAction | string;

export async function resolveResources(
registry: Registry,
resources?: ResourceArgument[]
): Promise<ResourceAction[]> {
if (!resources || resources.length === 0) {
return [];
}

return await Promise.all(
resources.map(async (ref): Promise<ResourceAction> => {
if (typeof ref === 'string') {
return await lookupResourceByName(registry, ref);
} else if (isAction(ref)) {
return ref;
}
throw new Error('Resources must be strings, or actions');
})
);
}

export async function lookupResourceByName(
registry: Registry,
name: string
): Promise<ResourceAction> {
const resource =
(await registry.lookupAction(name)) ||
(await registry.lookupAction(`/resource/${name}`)) ||
(await registry.lookupAction(`/dynamic-action-provider/${name}`));
if (!resource) {
throw new Error(`Resource ${name} not found`);
}
return resource as ResourceAction;
}

/**
* Defines a resource.
*
Expand Down Expand Up @@ -122,11 +161,28 @@ export type DynamicResourceAction = ResourceAction & {
*/
export async function findMatchingResource(
registry: Registry,
resources: ResourceAction[],
input: ResourceInput
): Promise<ResourceAction | undefined> {
for (const actKeys of Object.keys(await registry.listResolvableActions())) {
if (actKeys.startsWith('/resource/')) {
const resource = (await registry.lookupAction(actKeys)) as ResourceAction;
// First look in any resources explicitly listed in the generate request
for (const res of resources) {
if (res.matches(input)) {
return res;
}
}

// Then search the registry
for (const registryKey of Object.keys(
await registry.listResolvableActions()
)) {
// We decided not to look in DAP actions because they might be slow.
// DAP actions with resources will only be found if they are listed in the
// resources section, and then they will be found above.
if (registryKey.startsWith('/resource/')) {
const resource = (await registry.lookupAction(
registryKey
)) as ResourceAction;

if (resource.matches(input)) {
return resource;
}
Expand Down
3 changes: 2 additions & 1 deletion js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ export async function lookupToolByName(
const tool =
(await registry.lookupAction(name)) ||
(await registry.lookupAction(`/tool/${name}`)) ||
(await registry.lookupAction(`/prompt/${name}`));
(await registry.lookupAction(`/prompt/${name}`)) ||
(await registry.lookupAction(`/dynamic-action-provider/${name}`));
if (!tool) {
throw new Error(`Tool ${name} not found`);
}
Expand Down
Loading