Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/handlers/imageToJsonHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { CONTENT_TYPES } from '../globals';

export async function getTransformedResponse(_: {
response: Response;
transformer: (base64Image: string) => Record<string, any>;
}): Promise<Response | undefined> {
const { response, transformer } = _;
try {
console.info(
'imageToJsonHandler > converting image response to base64 JSON'
);
const imageBuffer = await response.arrayBuffer();
// Simple ArrayBuffer to base64 conversion for environments like Cloudflare Workers
let binary = '';
const bytes = new Uint8Array(imageBuffer);
const len = bytes.byteLength;
for (let i = 0; i < len; i++) {
binary += String.fromCharCode(bytes[i]);
}
return new Response(JSON.stringify(transformer(btoa(binary))), {
headers: {
...Object.fromEntries(response.headers), // keep original headers
'content-type': CONTENT_TYPES.APPLICATION_JSON,
},
status: response.status,
statusText: response.statusText,
});
} catch (error) {
console.error(
'imageToJsonHandler > error converting image response to base64 JSON',
error
);
return response;
}
}
17 changes: 16 additions & 1 deletion src/handlers/responseHandlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ export async function responseHandler(
originalResponseJson?: Record<string, any> | null;
}> {
let responseTransformerFunction: Function | undefined;
const responseContentType = response.headers?.get('content-type');
const isSuccessStatusCode = [200, 246].includes(response.status);

if (typeof provider == 'object') {
Expand All @@ -73,6 +72,22 @@ export async function responseHandler(
responseTransformerFunction = providerTransformers?.[responseTransformer];
}

// if the original response is an image, convert it to JSON if the provider has a transformer
if (
responseTransformer === 'imageGenerate' && // check that we are on the imageGenerate route
responseTransformerFunction && // check that we have a transformer for this provider
providerTransformers?.[`imageToJson`] && // check that we have a 'imageToJson" transformer for this provider
response.headers
?.get('content-type')
?.startsWith(CONTENT_TYPES.GENERIC_IMAGE_PATTERN) // check that the original response content type is an image
) {
// transformers are async, because we read the body as an array buffer
response = await providerTransformers?.[`imageToJson`](response);
}

// read the final content type after transformations
const responseContentType = response.headers?.get('content-type');

// JSON to text/event-stream conversion is only allowed for unified routes: chat completions and completions.
// Set the transformer to OpenAI json to stream convertor function in that case.
if (responseTransformer && streamingMode && isCacheHit) {
Expand Down
12 changes: 12 additions & 0 deletions src/providers/segmind/imageGenerate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';
import { getTransformedResponse } from '../../handlers/imageToJsonHandler';

export const SegmindImageGenerateConfig: ProviderConfig = {
prompt: {
Expand Down Expand Up @@ -109,6 +110,17 @@ interface SegmindImageGenerateErrorResponse {
error?: string;
}

export const SegmindImageToJsonResponseTransform = async (
response: Response
) => {
return getTransformedResponse({
response,
transformer: (base64Image: string) => ({
image: base64Image,
}),
});
};

export const SegmindImageGenerateResponseTransform: (
response: SegmindImageGenerateResponse | SegmindImageGenerateErrorResponse,
responseStatus: number
Expand Down
2 changes: 2 additions & 0 deletions src/providers/segmind/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ import SegmindAIAPIConfig from './api';
import {
SegmindImageGenerateConfig,
SegmindImageGenerateResponseTransform,
SegmindImageToJsonResponseTransform,
} from './imageGenerate';

const SegmindConfig: ProviderConfigs = {
api: SegmindAIAPIConfig,
imageGenerate: SegmindImageGenerateConfig,
responseTransforms: {
imageGenerate: SegmindImageGenerateResponseTransform,
imageToJson: SegmindImageToJsonResponseTransform,
},
};

Expand Down