codacus committed
Commit 2e49905 · unverified · 2 parents: a0eb0a0, 7efad13

Merge pull request #513 from thecodacus/together-ai-dynamic-model-list

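This merge wires dynamic model-list fetching into the Together provider: streamText now resolves models through an async getModelList(apiKeys), the chat route forwards API keys parsed from the apiKeys cookie, ProviderInfo.getDynamicModels gains an optional apiKeys parameter, and constants.ts adds getTogetherModels plus cookie-aware key lookups.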
app/lib/.server/llm/stream-text.ts CHANGED
@@ -1,11 +1,8 @@
-// eslint-disable-next-line @typescript-eslint/ban-ts-comment
-// @ts-nocheck – TODO: Provider proper types
-
 import { convertToCoreMessages, streamText as _streamText } from 'ai';
 import { getModel } from '~/lib/.server/llm/model';
 import { MAX_TOKENS } from './constants';
 import { getSystemPrompt } from './prompts';
-import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
+import { DEFAULT_MODEL, DEFAULT_PROVIDER, getModelList, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
 
 interface ToolResult<Name extends string, Args, Result> {
   toolCallId: string;
@@ -43,7 +40,7 @@ function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
    * Extract provider
    * const providerMatch = message.content.match(PROVIDER_REGEX);
    */
-  const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;
+  const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER.name;
 
   const cleanedContent = Array.isArray(message.content)
     ? message.content.map((item) => {
@@ -61,10 +58,15 @@ function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
   return { model, provider, content: cleanedContent };
 }
 
-export function streamText(messages: Messages, env: Env, options?: StreamingOptions, apiKeys?: Record<string, string>) {
+export async function streamText(
+  messages: Messages,
+  env: Env,
+  options?: StreamingOptions,
+  apiKeys?: Record<string, string>,
+) {
   let currentModel = DEFAULT_MODEL;
-  let currentProvider = DEFAULT_PROVIDER;
-
+  let currentProvider = DEFAULT_PROVIDER.name;
+  const MODEL_LIST = await getModelList(apiKeys || {});
   const processedMessages = messages.map((message) => {
     if (message.role === 'user') {
       const { model, provider, content } = extractPropertiesFromMessage(message);
@@ -86,10 +88,10 @@ export function streamText(messages: Messages, env: Env, options?: StreamingOptions, apiKeys?: Record<string, string>) {
   const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;
 
   return _streamText({
-    ...options,
-    model: getModel(currentProvider, currentModel, env, apiKeys),
+    model: getModel(currentProvider, currentModel, env, apiKeys) as any,
     system: getSystemPrompt(),
     maxTokens: dynamicMaxTokens,
-    messages: convertToCoreMessages(processedMessages),
+    messages: convertToCoreMessages(processedMessages as any),
+    ...options,
   });
 }
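Because streamText is now async (it awaits getModelList before matching each message's model and provider against the fetched MODEL_LIST), every call site must await it and can pass the per-user key record as the fourth argument. A minimal sketch of the new call shape, mirroring the updated call in app/routes/api.chat.ts below; the option and key values are illustrative:

  // Messages, Env, and StreamingOptions are the project's own types.
  // streamText now fetches the dynamic model list itself, so the caller
  // only forwards the apiKeys record it parsed from the request cookies.
  const result = await streamText(
    messages,                        // Messages from the request body
    context.cloudflare.env,          // Env carrying provider credentials
    { toolChoice: 'none' },          // StreamingOptions passed through to _streamText
    { Together: 'example-api-key' }, // apiKeys from the apiKeys cookie
  );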
app/routes/api.chat.ts CHANGED
@@ -1,6 +1,3 @@
-// eslint-disable-next-line @typescript-eslint/ban-ts-comment
-// @ts-nocheck – TODO: Provider proper types
-
 import { type ActionFunctionArgs } from '@remix-run/cloudflare';
 import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
 import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
@@ -11,8 +8,8 @@ export async function action(args: ActionFunctionArgs) {
   return chatAction(args);
 }
 
-function parseCookies(cookieHeader) {
-  const cookies = {};
+function parseCookies(cookieHeader: string) {
+  const cookies: any = {};
 
   // Split the cookie string by semicolons and spaces
   const items = cookieHeader.split(';').map((cookie) => cookie.trim());
@@ -32,7 +29,7 @@ function parseCookies(cookieHeader) {
 }
 
 async function chatAction({ context, request }: ActionFunctionArgs) {
-  const { messages, model } = await request.json<{
+  const { messages } = await request.json<{
     messages: Messages;
     model: string;
   }>();
@@ -40,15 +37,13 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
   const cookieHeader = request.headers.get('Cookie');
 
   // Parse the cookie's value (returns an object or null if no cookie exists)
-  const apiKeys = JSON.parse(parseCookies(cookieHeader).apiKeys || '{}');
+  const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
 
   const stream = new SwitchableStream();
 
   try {
     const options: StreamingOptions = {
       toolChoice: 'none',
-      apiKeys,
-      model,
       onFinish: async ({ text: content, finishReason }) => {
         if (finishReason !== 'length') {
           return stream.close();
@@ -65,7 +60,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
         messages.push({ role: 'assistant', content });
         messages.push({ role: 'user', content: CONTINUE_PROMPT });
 
-        const result = await streamText(messages, context.cloudflare.env, options);
+        const result = await streamText(messages, context.cloudflare.env, options, apiKeys);
 
         return stream.switchSource(result.toAIStream());
       },
@@ -81,7 +76,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
         contentType: 'text/plain; charset=utf-8',
       },
     });
-  } catch (error) {
+  } catch (error: any) {
     console.log(error);
 
     if (error.message?.includes('API key')) {
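For context, the apiKeys cookie the route reads is a JSON object keyed by provider name. A sketch of the round trip, assuming the client stores keys with js-cookie as the reads in app/utils/constants.ts below suggest (names and values illustrative):

  // Client: persist user-entered keys as one JSON-encoded cookie.
  Cookies.set('apiKeys', JSON.stringify({ Together: 'example-key' }));

  // Server: parseCookies() splits the Cookie header into name/value pairs,
  // so the stored JSON can be recovered and handed to streamText:
  const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}');
  // -> { Together: 'example-key' }, or {} when the cookie is missing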
app/types/model.ts CHANGED
@@ -3,7 +3,7 @@ import type { ModelInfo } from '~/utils/types';
 export type ProviderInfo = {
   staticModels: ModelInfo[];
   name: string;
-  getDynamicModels?: () => Promise<ModelInfo[]>;
+  getDynamicModels?: (apiKeys?: Record<string, string>) => Promise<ModelInfo[]>;
   getApiKeyLink?: string;
   labelForGetApiKey?: string;
   icon?: string;
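With the widened signature, a provider's dynamic lookup can prefer a user-supplied key over environment configuration. A minimal conforming provider as a sketch — the 'Example' name, endpoint, and response shape are hypothetical, not a real service:

  const exampleProvider: ProviderInfo = {
    name: 'Example',
    staticModels: [],
    // apiKeys arrives from the user's cookie via getModelList; return no
    // dynamic models when no key is available.
    getDynamicModels: async (apiKeys?: Record<string, string>) => {
      const key = apiKeys?.Example ?? '';
      if (!key) {
        return [];
      }
      const response = await fetch('https://api.example.com/v1/models', {
        headers: { Authorization: `Bearer ${key}` },
      });
      const models = (await response.json()) as { id: string }[];
      return models.map((m) => ({
        name: m.id,
        label: m.id,
        provider: 'Example',
        maxTokenAllowed: 8000,
      }));
    },
  };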
app/utils/constants.ts CHANGED
@@ -1,3 +1,4 @@
+import Cookies from 'js-cookie';
 import type { ModelInfo, OllamaApiResponse, OllamaModel } from './types';
 import type { ProviderInfo } from '~/types/model';
 
@@ -262,6 +263,7 @@ const PROVIDER_LIST: ProviderInfo[] = [
   },
   {
     name: 'Together',
+    getDynamicModels: getTogetherModels,
     staticModels: [
      {
        name: 'Qwen/Qwen2.5-Coder-32B-Instruct',
@@ -293,6 +295,61 @@ const staticModels: ModelInfo[] = PROVIDER_LIST.map((p) => p.staticModels).flat();
 
 export let MODEL_LIST: ModelInfo[] = [...staticModels];
 
+export async function getModelList(apiKeys: Record<string, string>) {
+  MODEL_LIST = [
+    ...(
+      await Promise.all(
+        PROVIDER_LIST.filter(
+          (p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
+        ).map((p) => p.getDynamicModels(apiKeys)),
+      )
+    ).flat(),
+    ...staticModels,
+  ];
+  return MODEL_LIST;
+}
+
+async function getTogetherModels(apiKeys?: Record<string, string>): Promise<ModelInfo[]> {
+  try {
+    const baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';
+    const provider = 'Together';
+
+    if (!baseUrl) {
+      return [];
+    }
+
+    let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
+
+    if (apiKeys && apiKeys[provider]) {
+      apiKey = apiKeys[provider];
+    }
+
+    if (!apiKey) {
+      return [];
+    }
+
+    const response = await fetch(`${baseUrl}/models`, {
+      headers: {
+        Authorization: `Bearer ${apiKey}`,
+      },
+    });
+    const res = (await response.json()) as any;
+    const data: any[] = (res || []).filter((model: any) => model.type == 'chat');
+
+    return data.map((m: any) => ({
+      name: m.id,
+      label: `${m.display_name} - in:$${m.pricing.input.toFixed(
+        2,
+      )} out:$${m.pricing.output.toFixed(2)} - context ${Math.floor(m.context_length / 1000)}k`,
+      provider,
+      maxTokenAllowed: 8000,
+    }));
+  } catch (e) {
+    console.error('Error getting OpenAILike models:', e);
+    return [];
+  }
+}
+
 const getOllamaBaseUrl = () => {
   const defaultBaseUrl = import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';
 
@@ -340,7 +397,14 @@ async function getOpenAILikeModels(): Promise<ModelInfo[]> {
     return [];
   }
 
-  const apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
+  let apiKey = import.meta.env.OPENAI_LIKE_API_KEY ?? '';
+
+  const apikeys = JSON.parse(Cookies.get('apiKeys') || '{}');
+
+  if (apikeys && apikeys.OpenAILike) {
+    apiKey = apikeys.OpenAILike;
+  }
+
   const response = await fetch(`${baseUrl}/models`, {
     headers: {
       Authorization: `Bearer ${apiKey}`,
@@ -414,16 +478,32 @@ async function getLMStudioModels(): Promise<ModelInfo[]> {
 }
 
 async function initializeModelList(): Promise<ModelInfo[]> {
+  let apiKeys: Record<string, string> = {};
+
+  try {
+    const storedApiKeys = Cookies.get('apiKeys');
+
+    if (storedApiKeys) {
+      const parsedKeys = JSON.parse(storedApiKeys);
+
+      if (typeof parsedKeys === 'object' && parsedKeys !== null) {
+        apiKeys = parsedKeys;
+      }
+    }
+  } catch (error: any) {
+    console.warn(`Failed to fetch apikeys from cookies:${error?.message}`);
+  }
   MODEL_LIST = [
     ...(
      await Promise.all(
        PROVIDER_LIST.filter(
          (p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
-        ).map((p) => p.getDynamicModels()),
+        ).map((p) => p.getDynamicModels(apiKeys)),
      )
     ).flat(),
     ...staticModels,
   ];
+
   return MODEL_LIST;
 }
 
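To make the label format in getTogetherModels concrete: assuming a Together /models entry carrying the fields the mapping reads (the values here are illustrative), the conversion works out as follows:

  // Hypothetical response entry:
  const m = {
    id: 'Qwen/Qwen2.5-Coder-32B-Instruct',
    display_name: 'Qwen 2.5 Coder 32B Instruct',
    type: 'chat',          // non-'chat' entries are filtered out
    context_length: 32768,
    pricing: { input: 0.8, output: 0.8 },
  };

  // getTogetherModels would map it to:
  // name:  'Qwen/Qwen2.5-Coder-32B-Instruct'
  // label: 'Qwen 2.5 Coder 32B Instruct - in:$0.80 out:$0.80 - context 32k'
  //        (toFixed(2) formats the prices; Math.floor(32768 / 1000) gives 32)
  // maxTokenAllowed: 8000 (hard-coded for every Together model)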
vite.config.ts CHANGED
@@ -27,7 +27,7 @@ export default defineConfig((config) => {
     chrome129IssuePlugin(),
     config.mode === 'production' && optimizeCssModules({ apply: 'build' }),
   ],
-  envPrefix:["VITE_","OPENAI_LIKE_API_","OLLAMA_API_BASE_URL","LMSTUDIO_API_BASE_URL"],
+  envPrefix: ["VITE_", "OPENAI_LIKE_API_", "OLLAMA_API_BASE_URL", "LMSTUDIO_API_BASE_URL","TOGETHER_API_BASE_URL"],
   css: {
     preprocessorOptions: {
       scss: {
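Vite only exposes environment variables whose names match an envPrefix entry on import.meta.env, so adding 'TOGETHER_API_BASE_URL' is what lets constants.ts read the Together base URL. A sketch of the wiring, with an illustrative .env value:

  // .env (value illustrative):
  //   TOGETHER_API_BASE_URL=https://api.together.xyz/v1

  // In app code processed by Vite, the variable is now visible because its
  // name matches an envPrefix entry:
  const baseUrl = import.meta.env.TOGETHER_API_BASE_URL || '';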