diff --git a/packages/compass-e2e-tests/tests/collection-ai-query.test.ts b/packages/compass-e2e-tests/tests/collection-ai-query.test.ts index 40325cd81bc..cb1dfa10f00 100644 --- a/packages/compass-e2e-tests/tests/collection-ai-query.test.ts +++ b/packages/compass-e2e-tests/tests/collection-ai-query.test.ts @@ -1,8 +1,6 @@ import { expect } from 'chai'; import type { CompassBrowser } from '../helpers/compass-browser'; -import { startTelemetryServer } from '../helpers/telemetry'; -import type { Telemetry } from '../helpers/telemetry'; import { init, cleanup, @@ -13,167 +11,8 @@ import { import type { Compass } from '../helpers/compass'; import * as Selectors from '../helpers/selectors'; import { createNumbersCollection } from '../helpers/insert-data'; -import { startMockAtlasServiceServer } from '../helpers/mock-atlas-service'; -import type { MockAtlasServerResponse } from '../helpers/mock-atlas-service'; import { startMockAssistantServer } from '../helpers/assistant-service'; -describe('Collection ai query (with mocked backend)', function () { - let compass: Compass; - let browser: CompassBrowser; - let telemetry: Telemetry; - let setMockAtlasServerResponse: (response: MockAtlasServerResponse) => void; - let stopMockAtlasServer: () => Promise; - let getRequests: () => any[]; - let clearRequests: () => void; - - before(async function () { - // Start a mock server to pass an ai response. - const { - endpoint, - getRequests: _getRequests, - clearRequests: _clearRequests, - setMockAtlasServerResponse: _setMockAtlasServerResponse, - stop, - } = await startMockAtlasServiceServer(); - - stopMockAtlasServer = stop; - getRequests = _getRequests; - clearRequests = _clearRequests; - setMockAtlasServerResponse = _setMockAtlasServerResponse; - - telemetry = await startTelemetryServer(); - compass = await init(this.test?.fullTitle()); - browser = compass.browser; - - await browser.setEnv( - 'COMPASS_ATLAS_SERVICE_UNAUTH_BASE_URL_OVERRIDE', - endpoint - ); - - await browser.setFeature('enableGenAIFeatures', true); - await browser.setFeature('enableGenAISampleDocumentPassing', true); - await browser.setFeature('optInGenAIFeatures', true); - - await browser.setupDefaultConnections(); - }); - - beforeEach(async function () { - await createNumbersCollection(); - await browser.disconnectAll(); - await browser.connectToDefaults(); - await browser.navigateToCollectionTab( - DEFAULT_CONNECTION_NAME_1, - 'test', - 'numbers', - 'Documents' - ); - }); - - after(async function () { - await stopMockAtlasServer(); - - await cleanup(compass); - await telemetry.stop(); - }); - - afterEach(async function () { - clearRequests(); - await screenshotIfFailed(compass, this.currentTest); - }); - - describe('when the ai model response is valid', function () { - beforeEach(function () { - setMockAtlasServerResponse({ - status: 200, - body: { - content: { - query: { - filter: '{i: {$gt: 50}}', - }, - }, - }, - }); - }); - - it('makes request to the server and updates the query bar with the response', async function () { - // Click the ai entry button. - await browser.clickVisible(Selectors.GenAIEntryButton); - - // Enter the ai prompt. - await browser.clickVisible(Selectors.GenAITextInput); - - const testUserInput = 'find all documents where i is greater than 50'; - await browser.setValueVisible(Selectors.GenAITextInput, testUserInput); - - // Click generate. - await browser.clickVisible(Selectors.GenAIGenerateQueryButton); - - // Wait for the ipc events to succeed. - await browser.waitUntil(async function () { - // Make sure the query bar was updated. - const queryBarFilterContent = await browser.getCodemirrorEditorText( - Selectors.queryBarOptionInputFilter('Documents') - ); - return queryBarFilterContent === '{i: {$gt: 50}}'; - }); - - // Check that the request was made with the correct parameters. - const requests = getRequests(); - expect(requests.length).to.equal(1); - - const queryRequest = requests[0]; - const queryURL = new URL( - queryRequest.req.url, - `http://${queryRequest.req.headers.host}` - ); - expect([...new Set(queryURL.searchParams.keys())].length).to.equal(1); - const requestId = queryURL.searchParams.get('request_id'); - expect((requestId?.match(/-/g) || []).length).to.equal(4); // Is uuid like. - expect(queryRequest.content.userInput).to.equal(testUserInput); - expect(queryRequest.content.collectionName).to.equal('numbers'); - expect(queryRequest.content.databaseName).to.equal('test'); - expect(queryRequest.content.schema).to.exist; - - // Run it and check that the correct documents are shown. - await browser.runFind('Documents', true); - const modifiedResult = await browser.getFirstListDocument(); - expect(modifiedResult.i).to.be.equal('51'); - }); - }); - - describe('when the Atlas service request errors', function () { - beforeEach(function () { - setMockAtlasServerResponse({ - status: 500, - body: { - content: 'error', - }, - }); - }); - - it('the error is shown to the user', async function () { - // Click the ai entry button. - await browser.clickVisible(Selectors.GenAIEntryButton); - - // Enter the ai prompt. - await browser.clickVisible(Selectors.GenAITextInput); - - const testUserInput = 'find all documents where i is greater than 50'; - await browser.setValueVisible(Selectors.GenAITextInput, testUserInput); - - // Click generate. - await browser.clickVisible(Selectors.GenAIGenerateQueryButton); - - // Check that the error is shown. - const errorBanner = browser.$(Selectors.GenAIErrorMessageBanner); - await errorBanner.waitForDisplayed(); - expect(await errorBanner.getText()).to.equal( - 'Sorry, we were unable to generate the query, please try again. If the error persists, try changing your prompt.' - ); - }); - }); -}); - async function setup( browser: CompassBrowser, dbName: string, @@ -189,13 +28,12 @@ async function setup( 'Documents' ); - await browser.setFeature('enableChatbotEndpointForGenAI', true); await browser.setFeature('enableGenAIFeatures', true); await browser.setFeature('enableGenAISampleDocumentPassing', true); await browser.setFeature('optInGenAIFeatures', true); } -describe('Collection ai query with chatbot (with mocked backend)', function () { +describe('Collection ai query (with mocked backend)', function () { const dbName = 'test'; const collName = 'numbers'; let compass: Compass; @@ -266,9 +104,8 @@ describe('Collection ai query with chatbot (with mocked backend)', function () { const queryRequest = requests[0]; expect(queryRequest.req.headers).to.have.property('x-client-request-id'); - // TODO(COMPASS-10125): Switch the model to `mongodb-slim-latest` when - // enabling this feature. - expect(queryRequest.content.model).to.equal('mongodb-chat-latest'); + expect(queryRequest.req.headers).to.have.property('entrypoint'); + expect(queryRequest.content.model).to.equal('mongodb-slim-latest'); expect(queryRequest.content.instructions).to.be.string; expect(queryRequest.content.metadata).to.have.property('userId'); expect(queryRequest.content.metadata.store).to.have.equal('true'); diff --git a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts index 1c16fca38ce..5c05c322c67 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts @@ -86,27 +86,9 @@ describe('AtlasAiService', function () { global.fetch = initialFetch; }); - const endpointBasepathTests = [ - { - apiURLPreset: 'admin-api', - expectedEndpoints: { - 'mql-aggregation': `http://example.com/unauth/ai/api/v1/mql-aggregation?request_id=abc`, - 'mql-query': `http://example.com/unauth/ai/api/v1/mql-query?request_id=abc`, - }, - }, - { - apiURLPreset: 'cloud', - expectedEndpoints: { - 'mql-aggregation': - '/cloud/ai/v1/groups/testProject/mql-aggregation?request_id=abc', - 'mql-query': '/cloud/ai/v1/groups/testProject/mql-query?request_id=abc', - 'mock-data-schema': - '/cloud/ai/v1/groups/testProject/mock-data-schema?request_id=abc', - }, - }, - ] as const; + const endpointBasepathTests = ['admin-api', 'cloud'] as const; - for (const { apiURLPreset, expectedEndpoints } of endpointBasepathTests) { + for (const apiURLPreset of endpointBasepathTests) { const describeName = apiURLPreset === 'admin-api' ? 'connection WITHOUT atlas metadata' @@ -130,189 +112,6 @@ describe('AtlasAiService', function () { }); }); - describe('getQueryFromUserInput and getAggregationFromUserInput', function () { - beforeEach(async function () { - // Enable the AI feature - const fetchStub = sandbox.stub().resolves( - makeResponse({ - features: { - GEN_AI_COMPASS: { - enabled: true, - }, - }, - }) - ); - global.fetch = fetchStub; - await atlasAiService['setupAIAccess'](); - global.fetch = initialFetch; - }); - - const atlasAIServiceTests = [ - { - functionName: 'getQueryFromUserInput', - aiEndpoint: 'mql-query', - responses: { - success: { - content: { query: { filter: "{ test: 'pineapple' }" } }, - }, - invalid: [ - [undefined, 'internal server error'], - [{}, 'unexpected response'], - [{ countent: {} }, 'unexpected response'], - [{ content: { qooery: {} } }, 'unexpected keys'], - [ - { content: { query: { filter: { foo: 1 } } } }, - 'unexpected response', - ], - ], - }, - }, - { - functionName: 'getAggregationFromUserInput', - aiEndpoint: 'mql-aggregation', - responses: { - success: { - content: { - aggregation: { pipeline: "[{ test: 'pineapple' }]" }, - }, - }, - invalid: [ - [undefined, 'internal server error'], - [{}, 'unexpected response'], - [{ content: { aggregation: {} } }, 'unexpected response'], - [{ content: { aggrogation: {} } }, 'unexpected keys'], - [ - { content: { aggregation: { pipeline: true } } }, - 'unexpected response', - ], - ], - }, - }, - ] as const; - - for (const { - functionName, - aiEndpoint, - responses, - } of atlasAIServiceTests) { - describe(functionName, function () { - it('makes a post request with the user input to the endpoint in the environment', async function () { - const fetchStub = sandbox - .stub() - .resolves(makeResponse(responses.success)); - global.fetch = fetchStub; - - const res = await atlasAiService[functionName]( - { - userInput: 'test', - signal: new AbortController().signal, - collectionName: 'jam', - databaseName: 'peanut', - schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, - sampleDocuments: [ - { _id: new ObjectId('642d766b7300158b1f22e972') }, - ], - requestId: 'abc', - enableStorage: false, - }, - mockConnectionInfo - ); - - expect(fetchStub).to.have.been.calledOnce; - - const { args } = fetchStub.firstCall; - - expect(args[0]).to.eq(expectedEndpoints[aiEndpoint]); - expect(args[1].body).to.eq( - '{"userInput":"test","collectionName":"jam","databaseName":"peanut","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":{"$oid":"642d766b7300158b1f22e972"}}],"enableStorage":false}' - ); - expect(res).to.deep.eq(responses.success); - }); - - it('should fail when response is not matching expected schema', async function () { - for (const [res, error] of responses.invalid) { - const fetchStub = sandbox.stub().resolves(makeResponse(res)); - global.fetch = fetchStub; - - try { - await atlasAiService[functionName]( - { - userInput: 'test', - collectionName: 'test', - databaseName: 'peanut', - requestId: 'abc', - signal: new AbortController().signal, - enableStorage: false, - }, - mockConnectionInfo - ); - expect.fail(`Expected ${functionName} to throw`); - } catch (err) { - expect((err as Error).message).to.match( - new RegExp(error, 'i') - ); - } - } - }); - - it('throws if the request would be too much for the ai', async function () { - try { - await atlasAiService[functionName]( - { - userInput: 'test', - collectionName: 'test', - databaseName: 'peanut', - sampleDocuments: [{ test: '4'.repeat(5120001) }], - requestId: 'abc', - signal: new AbortController().signal, - enableStorage: false, - }, - mockConnectionInfo - ); - expect.fail(`Expected ${functionName} to throw`); - } catch (err) { - expect(err).to.have.property( - 'message', - 'Sorry, your request is too large. Please use a smaller prompt or try using this feature on a collection with smaller documents.' - ); - } - }); - - it('passes fewer documents if the request would be too much for the ai with all of the documents', async function () { - const fetchStub = sandbox - .stub() - .resolves(makeResponse(responses.success)); - global.fetch = fetchStub; - - await atlasAiService[functionName]( - { - userInput: 'test', - collectionName: 'test.test', - databaseName: 'peanut', - sampleDocuments: [ - { a: '1' }, - { a: '2' }, - { a: '3' }, - { a: '4'.repeat(5120001) }, - ], - requestId: 'abc', - signal: new AbortController().signal, - enableStorage: false, - }, - mockConnectionInfo - ); - - const { args } = fetchStub.firstCall; - - expect(fetchStub).to.have.been.calledOnce; - expect(args[1].body).to.eq( - '{"userInput":"test","collectionName":"test.test","databaseName":"peanut","sampleDocuments":[{"a":"1"}],"enableStorage":false}' - ); - }); - }); - } - }); - describe('setupAIAccess', function () { beforeEach(async function () { await preferences.savePreferences({ @@ -741,377 +540,370 @@ describe('AtlasAiService', function () { }); }); } - - describe('with chatbot api', function () { - describe('getQueryFromUserInput and getAggregationFromUserInput', function () { - type Chunk = { type: 'text' | 'error'; content: string }; - let atlasAiService: AtlasAiService; - const mockConnectionInfo = getMockConnectionInfo(); - - function streamChunkResponse( - readableStreamController: ReadableStreamController, - chunks: Chunk[] - ) { - const responseId = `resp_${Date.now()}`; - const itemId = `item_${Date.now()}`; - let sequenceNumber = 0; - - const encoder = new TextEncoder(); - - // openai response format: - // https://github.com/vercel/ai/blob/811119c1808d7b62a4857bcad42353808cdba17c/packages/openai/src/responses/openai-responses-api.ts#L322 - - // Send response.created event - readableStreamController.enqueue( - encoder.encode( - `data: ${JSON.stringify({ - type: 'response.created', - response: { - id: responseId, - object: 'realtime.response', - status: 'in_progress', - output: [], - usage: { - input_tokens: 0, - output_tokens: 0, - total_tokens: 0, - }, - }, - sequence_number: sequenceNumber++, - })}\n\n` - ) - ); - - // Send output_item.added event - readableStreamController.enqueue( - encoder.encode( - `data: ${JSON.stringify({ - type: 'response.output_item.added', - response_id: responseId, - output_index: 0, - item: { - id: itemId, - object: 'realtime.item', - type: 'message', - role: 'assistant', - content: [], - }, - sequence_number: sequenceNumber++, - })}\n\n` - ) - ); - - for (const chunk of chunks) { - if (chunk.type === 'error') { - readableStreamController.enqueue( - encoder.encode( - `data: ${JSON.stringify({ - type: `error`, - response_id: responseId, - item_id: itemId, - output_index: 0, - error: { - type: 'model_error', - code: 'model_error', - message: chunk.content, - }, - sequence_number: sequenceNumber++, - })}\n\n` - ) - ); - } else { - readableStreamController.enqueue( - encoder.encode( - `data: ${JSON.stringify({ - type: 'response.output_text.delta', - response_id: responseId, - item_id: itemId, - output_index: 0, - delta: chunk.content, - sequence_number: sequenceNumber++, - })}\n\n` - ) - ); - } - } - - const content = chunks - .filter((c) => c.type === 'text') - .map((c) => c.content) - .join(''); - - // Send output_item.done event - readableStreamController.enqueue( - encoder.encode( - `data: ${JSON.stringify({ - type: 'response.output_item.done', - response_id: responseId, - output_index: 0, - item: { - id: itemId, - object: 'realtime.item', - type: 'message', - role: 'assistant', - content: [ - { - type: 'text', - text: content, - }, - ], + describe('getQueryFromUserInput and getAggregationFromUserInput', function () { + type Chunk = { type: 'text' | 'error'; content: string }; + let atlasAiService: AtlasAiService; + const mockConnectionInfo = getMockConnectionInfo(); + + function streamChunkResponse( + readableStreamController: ReadableStreamController, + chunks: Chunk[] + ) { + const responseId = `resp_${Date.now()}`; + const itemId = `item_${Date.now()}`; + let sequenceNumber = 0; + + const encoder = new TextEncoder(); + + // openai response format: + // https://github.com/vercel/ai/blob/811119c1808d7b62a4857bcad42353808cdba17c/packages/openai/src/responses/openai-responses-api.ts#L322 + + // Send response.created event + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.created', + response: { + id: responseId, + object: 'realtime.response', + status: 'in_progress', + output: [], + usage: { + input_tokens: 0, + output_tokens: 0, + total_tokens: 0, }, - sequence_number: sequenceNumber++, - })}\n\n` - ) - ); - - // Send response.completed event - const tokenCount = Math.ceil(content.length / 4); // assume 4 chars per token - readableStreamController.enqueue( - encoder.encode( - `data: ${JSON.stringify({ - type: 'response.completed', - response: { - id: responseId, - object: 'realtime.response', - status: 'completed', - output: [ - { - id: itemId, - object: 'realtime.item', - type: 'message', - role: 'assistant', - content: [ - { - type: 'text', - text: content, - }, - ], - }, - ], - usage: { - input_tokens: 10, - output_tokens: tokenCount, - total_tokens: 10 + tokenCount, + }, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + + // Send output_item.added event + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.output_item.added', + response_id: responseId, + output_index: 0, + item: { + id: itemId, + object: 'realtime.item', + type: 'message', + role: 'assistant', + content: [], + }, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + + for (const chunk of chunks) { + if (chunk.type === 'error') { + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: `error`, + response_id: responseId, + item_id: itemId, + output_index: 0, + error: { + type: 'model_error', + code: 'model_error', + message: chunk.content, }, - }, - sequence_number: sequenceNumber++, - })}\n\n` - ) - ); - } - - function streamableFetchMock(chunks: Chunk[]) { - const readableStream = new ReadableStream({ - start(controller) { - streamChunkResponse(controller, chunks); - controller.close(); - }, - }); - return new Response(readableStream, { - headers: { 'Content-Type': 'text/event-stream' }, - }); + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + } else { + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.output_text.delta', + response_id: responseId, + item_id: itemId, + output_index: 0, + delta: chunk.content, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + } } - beforeEach(async function () { - const mockAtlasService = new MockAtlasService(); - await preferences.savePreferences({ - enableChatbotEndpointForGenAI: true, - telemetryAtlasUserId: '1234', - }); - atlasAiService = new AtlasAiService({ - apiURLPreset: 'cloud', - atlasService: mockAtlasService as any, - preferences, - logger: createNoopLogger(), - }); - // Enable the AI feature - const fetchStub = sandbox.stub().resolves( - makeResponse({ - features: { - GEN_AI_COMPASS: { - enabled: true, + const content = chunks + .filter((c) => c.type === 'text') + .map((c) => c.content) + .join(''); + + // Send output_item.done event + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.output_item.done', + response_id: responseId, + output_index: 0, + item: { + id: itemId, + object: 'realtime.item', + type: 'message', + role: 'assistant', + content: [ + { + type: 'text', + text: content, + }, + ], + }, + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + + // Send response.completed event + const tokenCount = Math.ceil(content.length / 4); // assume 4 chars per token + readableStreamController.enqueue( + encoder.encode( + `data: ${JSON.stringify({ + type: 'response.completed', + response: { + id: responseId, + object: 'realtime.response', + status: 'completed', + output: [ + { + id: itemId, + object: 'realtime.item', + type: 'message', + role: 'assistant', + content: [ + { + type: 'text', + text: content, + }, + ], + }, + ], + usage: { + input_tokens: 10, + output_tokens: tokenCount, + total_tokens: 10 + tokenCount, }, }, - }) - ); - global.fetch = fetchStub; - await atlasAiService['setupAIAccess'](); + sequence_number: sequenceNumber++, + })}\n\n` + ) + ); + } + + function streamableFetchMock(chunks: Chunk[]) { + const readableStream = new ReadableStream({ + start(controller) { + streamChunkResponse(controller, chunks); + controller.close(); + }, }); - - after(function () { - global.fetch = initialFetch; + return new Response(readableStream, { + headers: { 'Content-Type': 'text/event-stream' }, + }); + } + + beforeEach(async function () { + const mockAtlasService = new MockAtlasService(); + atlasAiService = new AtlasAiService({ + apiURLPreset: 'cloud', + atlasService: mockAtlasService as any, + preferences, + logger: createNoopLogger(), }); + // Enable the AI feature + const fetchStub = sandbox.stub().resolves( + makeResponse({ + features: { + GEN_AI_COMPASS: { + enabled: true, + }, + }, + }) + ); + global.fetch = fetchStub; + await atlasAiService['setupAIAccess'](); + }); - const testCases = [ - { - functionName: 'getQueryFromUserInput', - successResponse: { - request: [ - { type: 'text', content: 'Hello' }, - { type: 'text', content: ' world' }, - { - type: 'text', - content: '. This is some non relevant text in the output', + after(function () { + global.fetch = initialFetch; + }); + + const testCases = [ + { + functionName: 'getQueryFromUserInput', + successResponse: { + request: [ + { type: 'text', content: 'Hello' }, + { type: 'text', content: ' world' }, + { + type: 'text', + content: '. This is some non relevant text in the output', + }, + { type: 'text', content: '{test: ' }, + { type: 'text', content: '"pineapple"' }, + { type: 'text', content: '}' }, + ] as Chunk[], + response: { + content: { + aggregation: { + pipeline: '', }, - { type: 'text', content: '{test: ' }, - { type: 'text', content: '"pineapple"' }, - { type: 'text', content: '}' }, - ] as Chunk[], - response: { - content: { - aggregation: { - pipeline: '', - }, - query: { - filter: "{test:'pineapple'}", - project: null, - sort: null, - skip: null, - limit: null, - }, + query: { + filter: "{test:'pineapple'}", + project: null, + sort: null, + skip: null, + limit: null, }, }, }, - invalidModelResponse: { - request: [ - { type: 'text', content: 'Hello' }, - { type: 'text', content: ' world.' }, - { type: 'text', content: '{test: ' }, - { type: 'text', content: '"pineapple"' }, - { type: 'text', content: '}' }, - { type: 'error', content: 'Model crashed!' }, - ] as Chunk[], - errorMessage: 'Model crashed!', - }, }, - { - functionName: 'getAggregationFromUserInput', - successResponse: { - request: [ - { type: 'text', content: 'Hello' }, - { type: 'text', content: ' world' }, - { - type: 'text', - content: '. This is some non relevant text in the output', - }, - { type: 'text', content: '[{$count: ' }, - { type: 'text', content: '"pineapple"' }, - { type: 'text', content: '}]' }, - ] as Chunk[], - response: { - content: { - aggregation: { - pipeline: "[{$count:'pineapple'}]", - }, + invalidModelResponse: { + request: [ + { type: 'text', content: 'Hello' }, + { type: 'text', content: ' world.' }, + { type: 'text', content: '{test: ' }, + { type: 'text', content: '"pineapple"' }, + { type: 'text', content: '}' }, + { type: 'error', content: 'Model crashed!' }, + ] as Chunk[], + errorMessage: 'Model crashed!', + }, + }, + { + functionName: 'getAggregationFromUserInput', + successResponse: { + request: [ + { type: 'text', content: 'Hello' }, + { type: 'text', content: ' world' }, + { + type: 'text', + content: '. This is some non relevant text in the output', + }, + { type: 'text', content: '[{$count: ' }, + { type: 'text', content: '"pineapple"' }, + { type: 'text', content: '}]' }, + ] as Chunk[], + response: { + content: { + aggregation: { + pipeline: "[{$count:'pineapple'}]", }, }, }, - invalidModelResponse: { - request: [ - { type: 'text', content: 'Hello' }, - { type: 'text', content: ' world.' }, - { type: 'text', content: '[{test: ' }, - { type: 'text', content: '"pineapple"' }, - { type: 'text', content: '}]' }, - { type: 'error', content: 'Model crashed!' }, - ] as Chunk[], - errorMessage: 'Model crashed!', - }, }, - ] as const; - - for (const { - functionName, - successResponse, - invalidModelResponse, - } of testCases) { - describe(functionName, function () { - it('makes a post request with the user input to the endpoint in the environment', async function () { - const fetchStub = sandbox - .stub() - .resolves(streamableFetchMock(successResponse.request)); - global.fetch = fetchStub; - - const input = { - userInput: 'test', - signal: new AbortController().signal, - collectionName: 'jam', - databaseName: 'peanut', - schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, - sampleDocuments: [ - { _id: new ObjectId('642d766b7300158b1f22e972') }, - ], - requestId: 'abc', - enableStorage: true, - }; + invalidModelResponse: { + request: [ + { type: 'text', content: 'Hello' }, + { type: 'text', content: ' world.' }, + { type: 'text', content: '[{test: ' }, + { type: 'text', content: '"pineapple"' }, + { type: 'text', content: '}]' }, + { type: 'error', content: 'Model crashed!' }, + ] as Chunk[], + errorMessage: 'Model crashed!', + }, + }, + ] as const; + + for (const { + functionName, + successResponse, + invalidModelResponse, + } of testCases) { + describe(functionName, function () { + it('makes a post request with the user input to the endpoint in the environment', async function () { + const fetchStub = sandbox + .stub() + .resolves(streamableFetchMock(successResponse.request)); + global.fetch = fetchStub; - const res = await atlasAiService[functionName]( - input as any, - mockConnectionInfo - ); + const input = { + userInput: 'test', + signal: new AbortController().signal, + collectionName: 'jam', + databaseName: 'peanut', + schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, + sampleDocuments: [ + { _id: new ObjectId('642d766b7300158b1f22e972') }, + ], + requestId: 'abc', + enableStorage: true, + }; + + const res = await atlasAiService[functionName]( + input as any, + mockConnectionInfo + ); - expect(fetchStub).to.have.been.calledOnce; + expect(fetchStub).to.have.been.calledOnce; - const { args } = fetchStub.firstCall; + const { args } = fetchStub.firstCall; + const requestHeaders = args[1].headers as Record; + expect(requestHeaders['x-client-request-id']).to.equal( + input.requestId + ); + expect(requestHeaders['entrypoint']).to.equal( + 'natural-language-to-mql' + ); + const requestBody = JSON.parse(args[1].body as string); + const { userId, ...restOfMetadata } = requestBody.metadata; + expect(restOfMetadata).to.deep.equal({ + store: 'true', + sensitiveStorage: 'sensitive', + }); + expect(userId).to.be.a('string').that.is.not.empty; + expect(requestBody.instructions).to.be.a('string'); + expect(requestBody.input).to.be.an('array'); + + const { role, content } = requestBody.input[0]; + expect(role).to.equal('user'); + expect(content[0].text).to.include( + `Database name: "${input.databaseName}"` + ); + expect(content[0].text).to.include( + `Collection name: "${input.collectionName}"` + ); + expect(content[0].text).to.include( + `_id: 'ObjectId`, + 'includes schema information in the prompt' + ); + expect(res).to.deep.eq(successResponse.response); + }); - const requestHeaders = args[1].headers as Record; - expect(requestHeaders['x-client-request-id']).to.equal( - input.requestId - ); + it('should throw an error when the stream contains an error chunk', async function () { + const fetchStub = sandbox + .stub() + .resolves(streamableFetchMock(invalidModelResponse.request)); + global.fetch = fetchStub; - const requestBody = JSON.parse(args[1].body as string); - expect(requestBody.model).to.equal('mongodb-chat-latest'); - const { userId, ...restOfMetadata } = requestBody.metadata; - expect(restOfMetadata).to.deep.equal({ - store: 'true', - sensitiveStorage: 'sensitive', - }); - expect(userId).to.be.a('string').that.is.not.empty; - expect(requestBody.instructions).to.be.a('string'); - expect(requestBody.input).to.be.an('array'); - - const { role, content } = requestBody.input[0]; - expect(role).to.equal('user'); - expect(content[0].text).to.include( - `Database name: "${input.databaseName}"` - ); - expect(content[0].text).to.include( - `Collection name: "${input.collectionName}"` + try { + await atlasAiService[functionName]( + { + userInput: 'test', + collectionName: 'test', + databaseName: 'peanut', + requestId: 'abc', + signal: new AbortController().signal, + enableStorage: true, + }, + mockConnectionInfo ); - expect(content[0].text).to.include( - `_id: 'ObjectId`, - 'includes schema information in the prompt' + expect.fail(`Expected ${functionName} to throw`); + } catch (err) { + expect((err as Error).message).to.match( + new RegExp(invalidModelResponse.errorMessage, 'i') ); - expect(res).to.deep.eq(successResponse.response); - }); - - it('should throw an error when the stream contains an error chunk', async function () { - const fetchStub = sandbox - .stub() - .resolves(streamableFetchMock(invalidModelResponse.request)); - global.fetch = fetchStub; - - try { - await atlasAiService[functionName]( - { - userInput: 'test', - collectionName: 'test', - databaseName: 'peanut', - requestId: 'abc', - signal: new AbortController().signal, - enableStorage: false, - }, - mockConnectionInfo - ); - expect.fail(`Expected ${functionName} to throw`); - } catch (err) { - expect((err as Error).message).to.match( - new RegExp(invalidModelResponse.errorMessage, 'i') - ); - } - }); + } }); - } - }); + }); + } }); }); diff --git a/packages/compass-generative-ai/src/atlas-ai-service.ts b/packages/compass-generative-ai/src/atlas-ai-service.ts index 01165b4310d..6fa1b3e7a09 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.ts @@ -344,9 +344,7 @@ export class AtlasAiService { }, }); }, - // TODO(COMPASS-10125): Switch the model to `mongodb-slim-latest` when - // enabling this feature (to use edu-chatbot for GenAI). - }).responses('mongodb-chat-latest'); + }).responses('mongodb-slim-latest'); } /** diff --git a/packages/compass-preferences-model/src/feature-flags.ts b/packages/compass-preferences-model/src/feature-flags.ts index d3d1ebfcb92..3771376e526 100644 --- a/packages/compass-preferences-model/src/feature-flags.ts +++ b/packages/compass-preferences-model/src/feature-flags.ts @@ -226,7 +226,7 @@ export const FEATURE_FLAG_DEFINITIONS = [ }, { name: 'enableChatbotEndpointForGenAI', - stage: 'development', + stage: 'released', atlasCloudFeatureFlagName: null, description: { short: 'Enable Chatbot API for Generative AI',