Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/heavy-towns-beam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@storybook/mcp': patch
---

Allow undefined request in server context when using custom manifestProvider
187 changes: 187 additions & 0 deletions packages/mcp/bin.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/**
* Integration tests for the stdio MCP server in bin.ts
*
* These tests spawn the bin.ts process as a child process and communicate
* with it via stdin/stdout, simulating how an MCP client would interact
* with the server in production.
*/
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
import { x } from 'tinyexec';
import { resolve, dirname } from 'node:path';
import { fileURLToPath } from 'node:url';
import type { ChildProcess } from 'node:child_process';

/**
* Helper to send a JSON-RPC request and wait for the response
*/
async function sendRequest(
child: ChildProcess,
stdoutData: string[],
request: unknown,
requestId: number,
timeoutMs = 10_000,
): Promise<unknown> {
// Send request
child.stdin?.write(JSON.stringify(request) + '\n');

// Wait for response with timeout
const { promise, resolve, reject } = Promise.withResolvers<void>();
const timeout = setTimeout(() => {
reject(new Error(`Timeout waiting for response to request ${requestId}`));
}, timeoutMs);

const checkResponse = () => {
const allData = stdoutData.join('');
if (allData.includes(`"id":${requestId}`)) {
clearTimeout(timeout);
resolve();
} else {
setTimeout(checkResponse, 50);
}
};
checkResponse();

await promise;

// Parse and return the response
const allData = stdoutData.join('');
const lines = allData.split('\n').filter((line) => line.trim());
const responseLine = lines.find((line) => {
try {
const parsed = JSON.parse(line);
return parsed.id === requestId;
} catch {
return false;
}
});

if (!responseLine) {
throw new Error(`No response found for request ${requestId}`);
}

return JSON.parse(responseLine);
}

describe('bin.ts stdio MCP server', () => {
let child: ChildProcess;
let stdoutData: string[] = [];
let stderrData: string[] = [];

beforeAll(() => {
const currentDir = dirname(fileURLToPath(import.meta.url));
const binPath = resolve(currentDir, './bin.ts');
const fixturePath = resolve(
currentDir,
'./fixtures/full-manifest.fixture.json',
);

const proc = x('node', [binPath, '--manifestPath', fixturePath]);

child = proc.process as ChildProcess;

// Collect stdout for later assertions
child.stdout?.on('data', (chunk) => {
stdoutData.push(chunk.toString());
});

// Collect stderr for debugging
child.stderr?.on('data', (chunk) => {
stderrData.push(chunk.toString());
});

child.on('error', (err) => {
console.error('Process error:', err);
});
});

afterAll(() => {
child.kill();
});

it('should respond to initialize request', async () => {
const request = {
jsonrpc: '2.0',
id: 1,
method: 'initialize',
params: {
protocolVersion: '2024-11-05',
capabilities: {},
clientInfo: {
name: 'test-client',
version: '1.0.0',
},
},
};

const response = await sendRequest(child, stdoutData, request, 1);

expect(response).toMatchObject({
jsonrpc: '2.0',
id: 1,
result: {
protocolVersion: '2024-11-05',
capabilities: {
tools: {
listChanged: true,
},
},
serverInfo: {
name: '@storybook/mcp',
},
},
});
}, 15000);

it('should list available tools', async () => {
const request = {
jsonrpc: '2.0',
id: 2,
method: 'tools/list',
params: {},
};

const response = await sendRequest(child, stdoutData, request, 2);

expect(response).toMatchObject({
jsonrpc: '2.0',
id: 2,
result: {
tools: expect.arrayContaining([
expect.objectContaining({
name: 'list-all-components',
}),
expect.objectContaining({
name: 'get-component-documentation',
}),
]),
},
});
}, 15000);

it('should execute list-all-components tool', async () => {
const request = {
jsonrpc: '2.0',
id: 3,
method: 'tools/call',
params: {
name: 'list-all-components',
arguments: {},
},
};

const response = await sendRequest(child, stdoutData, request, 3);

expect(response).toMatchObject({
jsonrpc: '2.0',
id: 3,
result: {
content: [
{
type: 'text',
text: expect.stringContaining('<components>'),
},
],
},
});
}, 15000);
});
3 changes: 2 additions & 1 deletion packages/mcp/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"devDependencies": {
"@tmcp/transport-stdio": "catalog:",
"react-docgen": "^8.0.2",
"srvx": "^0.8.16"
"srvx": "^0.8.16",
"tinyexec": "^1.0.2"
}
}
10 changes: 7 additions & 3 deletions packages/mcp/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ export interface StorybookContext extends Record<string, unknown> {
* If provided, this function will be called instead of the default fetch-based provider.
* The function receives the request object and a path to the manifest file,
* and should return the manifest as a string.
* The default provider constructs the manifest URL from the request origin,
* replacing /mcp with /manifests/components.json
* The default provider requires a request object and constructs the manifest URL from the request origin,
* replacing /mcp with /manifests/components.json.
* Custom providers can use the request parameter to determine the manifest source, or ignore it entirely.
*/
manifestProvider?: (request: Request, path: string) => Promise<string>;
manifestProvider?: (
request: Request | undefined,
path: string,
) => Promise<string>;
/**
* Optional handler called when list-all-components tool is invoked.
* Receives the context and the component manifest.
Expand Down
40 changes: 35 additions & 5 deletions packages/mcp/src/utils/get-manifest.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@ describe('getManifest', () => {
});

describe('error cases', () => {
it('should throw ManifestGetError when request is not provided', async () => {
it('should throw ManifestGetError when request is not provided and using default provider', async () => {
await expect(getManifest()).rejects.toThrow(ManifestGetError);
await expect(getManifest()).rejects.toThrow(
'The request is required but was not provided in the context',
"You must either pass the original request forward to the server context, or set a custom manifestProvider that doesn't need the request",
);
});

it('should throw ManifestGetError when request is undefined', async () => {
it('should throw ManifestGetError when request is undefined and using default provider', async () => {
await expect(getManifest(undefined)).rejects.toThrow(ManifestGetError);
await expect(getManifest(undefined)).rejects.toThrow(
"You must either pass the original request forward to the server context, or set a custom manifestProvider that doesn't need the request",
);
});

it('should throw ManifestGetError when fetch fails with 404', async () => {
global.fetch = vi.fn().mockResolvedValue({
ok: false,
Expand Down Expand Up @@ -225,6 +226,35 @@ describe('getManifest', () => {
expect(global.fetch).not.toHaveBeenCalled();
});

it('should allow manifestProvider to work without request', async () => {
const validManifest: ComponentManifestMap = {
v: 1,
components: {
button: {
id: 'button',
path: 'src/components/Button.tsx',
name: 'Button',
description: 'A button component',
},
},
};

// Custom provider that doesn't need the request
const manifestProvider = vi
.fn()
.mockResolvedValue(JSON.stringify(validManifest));

const result = await getManifest(undefined, manifestProvider);

expect(result).toEqual(validManifest);
expect(manifestProvider).toHaveBeenCalledExactlyOnceWith(
undefined,
'./manifests/components.json',
);
// fetch should not be called when manifestProvider is used
expect(global.fetch).not.toHaveBeenCalled();
});

it('should fallback to fetch when manifestProvider is not provided', async () => {
const validManifest: ComponentManifestMap = {
v: 1,
Expand Down
30 changes: 19 additions & 11 deletions packages/mcp/src/utils/get-manifest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,18 @@ export const errorToMCPContent = (error: unknown): MCPErrorResult => {
/**
* Gets a component manifest from a request or using a custom provider
*
* @param request - The HTTP request to get the manifest for
* @param request - The HTTP request to get the manifest for (optional when using custom manifestProvider)
* @param manifestProvider - Optional custom function to get the manifest
* @returns A promise that resolves to the parsed ComponentManifestMap
* @throws {ManifestGetError} If getting the manifest fails or the response is invalid
*/
export async function getManifest(
request?: Request,
manifestProvider?: (request: Request, path: string) => Promise<string>,
manifestProvider?: (
request: Request | undefined,
path: string,
) => Promise<string>,
): Promise<ComponentManifestMap> {
if (!request) {
throw new ManifestGetError(
'The request is required but was not provided in the context',
);
}
try {
// Use custom manifestProvider if provided, otherwise fallback to default
const manifestString = await (manifestProvider ?? defaultManifestProvider)(
Expand All @@ -89,7 +87,9 @@ export async function getManifest(
const manifest = v.parse(ComponentManifestMap, manifestData);

if (Object.keys(manifest.components).length === 0) {
const url = getManifestUrlFromRequest(request, MANIFEST_PATH);
const url = request
? getManifestUrlFromRequest(request, MANIFEST_PATH)
: 'Unknown manifest source';
throw new ManifestGetError(`No components found in the manifest`, url);
}

Expand All @@ -100,9 +100,12 @@ export async function getManifest(
}

// Wrap network errors and other unexpected errors
const url = request
? getManifestUrlFromRequest(request, MANIFEST_PATH)
: 'Unknown manifest source';
throw new ManifestGetError(
`Failed to get manifest: ${error instanceof Error ? error.message : String(error)}`,
getManifestUrlFromRequest(request, MANIFEST_PATH),
url,
error instanceof Error ? error : undefined,
);
}
Expand All @@ -125,9 +128,14 @@ function getManifestUrlFromRequest(request: Request, path: string): string {
* replacing /mcp with the provided path
*/
async function defaultManifestProvider(
request: Request,
path: string,
request?: Request,
path: string = './manifests/components.json',
): Promise<string> {
if (!request) {
throw new ManifestGetError(
"Request is required when using the default manifest provider. You must either pass the original request forward to the server context, or set a custom manifestProvider that doesn't need the request.",
);
}
const manifestUrl = getManifestUrlFromRequest(request, path);
const response = await fetch(manifestUrl);

Expand Down
Loading
Loading