File size: 4,914 Bytes
6927c07 3b8d251 2a3d5f5 70f88f9 2a3d5f5 5d4b860 6494f5a 6927c07 7ebc805 2a29fbb 7ebc805 6494f5a fcb61ba c575ee3 2327de3 c575ee3 2327de3 c575ee3 6fb59d2 3a36a44 2327de3 ea5c624 70f88f9 3a36a44 a544611 2cb3f09 2327de3 7efad13 5d4b860 c575ee3 cae55a7 6927c07 225b553 6927c07 cae55a7 fcb61ba 6494f5a fcb61ba 3b8d251 225b553 cae55a7 6494f5a 3b8d251 6494f5a cae55a7 f4987a4 cae55a7 f4987a4 6494f5a f4987a4 cae55a7 da37d94 70f88f9 3a36a44 da37d94 cae55a7 6494f5a cae55a7 da37d94 70f88f9 3a36a44 da37d94 cae55a7 fcb61ba cae55a7 7465cab 7efad13 6494f5a c575ee3 a544611 2327de3 a544611 6927c07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { createDataStream } from 'ai';
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts';
import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
import SwitchableStream from '~/lib/.server/llm/switchable-stream';
import type { IProviderSetting } from '~/types/model';
import { createScopedLogger } from '~/utils/logger';
/**
 * Route action entry point.
 *
 * Hands the incoming action arguments to `chatAction` unchanged and resolves
 * with whatever that call produces.
 */
export async function action(args: ActionFunctionArgs) {
  const response = await chatAction(args);

  return response;
}
// Module-level logger created with the scope name 'api.chat'; chatAction uses
// it for its debug, info and error output.
const logger = createScopedLogger('api.chat');
/**
 * Parses a raw `Cookie` header string into a name → value map.
 *
 * Entries are separated by `;`. Within an entry, everything before the first
 * `=` is the name and everything after it (including further `=` characters)
 * is the value. Names and values are trimmed and URI-decoded.
 *
 * - Entries with an empty name (e.g. `""` or `"=value"`) are skipped.
 * - An entry without `=` yields an empty-string value.
 * - A name or value that is not valid percent-encoding is kept verbatim
 *   instead of aborting the whole parse.
 * - If a name occurs more than once, the last occurrence wins.
 *
 * @param cookieHeader - Raw header text; may be an empty string.
 * @returns Plain object mapping decoded cookie names to decoded values.
 */
function parseCookies(cookieHeader: string): Record<string, string> {
  // decodeURIComponent throws URIError on malformed sequences such as "%" or
  // "%E0%A4%A"; one bad cookie must not make every other cookie unreadable.
  const decodeOrKeep = (raw: string): string => {
    try {
      return decodeURIComponent(raw);
    } catch {
      return raw;
    }
  };

  const cookies: Record<string, string> = {};

  for (const item of cookieHeader.split(';')) {
    const [rawName = '', ...valueParts] = item.trim().split('=');
    const name = rawName.trim();

    if (!name) {
      continue;
    }

    // Re-join on '=' so values that themselves contain '=' (base64 padding,
    // nested key=value text) survive intact.
    cookies[decodeOrKeep(name)] = decodeOrKeep(valueParts.join('=').trim());
  }

  return cookies;
}
/**
 * Handles a chat request and streams the model output back to the caller.
 *
 * Flow, as implemented below:
 * 1. Reads `messages`, `files`, `promptId` and `contextOptimization` from the
 *    JSON request body.
 * 2. Reads the `apiKeys` and `providers` cookies (each a JSON string) from the
 *    `Cookie` header; a missing cookie is treated as `{}`.
 * 3. Calls `streamText` and feeds its data stream into a `SwitchableStream`,
 *    whose readable side is the response body.
 * 4. When a segment finishes with reason `'length'`, appends the partial
 *    answer plus `CONTINUE_PROMPT` to `messages` and starts another segment,
 *    up to `MAX_RESPONSE_SEGMENTS` switches.
 * 5. When a segment finishes for any other reason, emits one final `usage`
 *    message annotation with the token totals summed over all segments and
 *    closes the stream.
 *
 * @param args - Remix action arguments; only `context` and `request` are used.
 * @returns A 200 `Response` whose body is the switchable stream.
 * @throws Response 401 if the error caught around the initial `streamText`
 *   call has a message containing `'API key'`; 500 for any other caught error.
 *   Failures while reading the body or parsing the cookies happen before the
 *   `try` block and are not converted.
 */
async function chatAction({ context, request }: ActionFunctionArgs) {
  const { messages, files, promptId, contextOptimization } = await request.json<{
    messages: Messages;
    files: any;
    promptId?: string;
    contextOptimization: boolean;
  }>();

  // Parse the Cookie header once and reuse the result for both settings blobs.
  const cookies = parseCookies(request.headers.get('Cookie') || '');
  const apiKeys = JSON.parse(cookies.apiKeys || '{}');
  const providerSettings: Record<string, IProviderSetting> = JSON.parse(cookies.providers || '{}');

  const stream = new SwitchableStream();

  // Token counts accumulated across every segment of this response.
  const cumulativeUsage = {
    completionTokens: 0,
    promptTokens: 0,
    totalTokens: 0,
  };

  try {
    const options: StreamingOptions = {
      toolChoice: 'none',
      onFinish: async ({ text: content, finishReason, usage }) => {
        logger.debug('usage', JSON.stringify(usage));

        if (usage) {
          cumulativeUsage.completionTokens += usage.completionTokens || 0;
          cumulativeUsage.promptTokens += usage.promptTokens || 0;
          cumulativeUsage.totalTokens += usage.totalTokens || 0;
        }

        if (finishReason !== 'length') {
          // Final segment: append a usage annotation, then close the stream.
          const encoder = new TextEncoder();
          const usageStream = createDataStream({
            async execute(dataStream) {
              dataStream.writeMessageAnnotation({
                type: 'usage',
                value: {
                  completionTokens: cumulativeUsage.completionTokens,
                  promptTokens: cumulativeUsage.promptTokens,
                  totalTokens: cumulativeUsage.totalTokens,
                },
              });
            },
            onError: (error: any) => `Custom error: ${error.message}`,
          }).pipeThrough(
            new TransformStream({
              transform: (chunk, controller) => {
                // Convert the string stream to a byte stream
                const str = typeof chunk === 'string' ? chunk : JSON.stringify(chunk);
                controller.enqueue(encoder.encode(str));
              },
            }),
          );

          await stream.switchSource(usageStream);

          // Yield one macrotask before closing, presumably so the usage chunk
          // is forwarded first (SwitchableStream internals are not visible here).
          await new Promise((resolve) => setTimeout(resolve, 0));
          stream.close();

          return;
        }

        // The segment was cut off by the token limit: continue in a new segment
        // unless the switch budget is already used up.
        if (stream.switches >= MAX_RESPONSE_SEGMENTS) {
          throw Error('Cannot continue message: Maximum segments reached');
        }

        const switchesLeft = MAX_RESPONSE_SEGMENTS - stream.switches;

        logger.info(`Reached max token limit (${MAX_TOKENS}): Continuing message (${switchesLeft} switches left)`);

        messages.push({ role: 'assistant', content });
        messages.push({ role: 'user', content: CONTINUE_PROMPT });

        // Reuses the same `options`, so this callback runs again for the next segment.
        const result = await streamText({
          messages,
          env: context.cloudflare.env,
          options,
          apiKeys,
          files,
          providerSettings,
          promptId,
          contextOptimization,
        });

        stream.switchSource(result.toDataStream());

        return;
      },
    };

    const result = await streamText({
      messages,
      env: context.cloudflare.env,
      options,
      apiKeys,
      files,
      providerSettings,
      promptId,
      contextOptimization,
    });

    stream.switchSource(result.toDataStream());

    return new Response(stream.readable, {
      status: 200,
      headers: {
        // Must be the real header name: a `contentType` key is sent as a
        // literal "contenttype" header and leaves Content-Type unset.
        'Content-Type': 'text/plain; charset=utf-8',
      },
    });
  } catch (error: unknown) {
    logger.error(error);

    // Anything can be thrown, so read `message` defensively before matching on it.
    const message: unknown = (error as { message?: unknown } | null | undefined)?.message;

    if (typeof message === 'string' && message.includes('API key')) {
      throw new Response('Invalid or missing API key', {
        status: 401,
        statusText: 'Unauthorized',
      });
    }

    throw new Response(null, {
      status: 500,
      statusText: 'Internal Server Error',
    });
  }
}
|