codacus committed on
Commit
32bfdd9
·
unverified ·
1 Parent(s): 39a0724

feat: added more dynamic models, sorted and remove duplicate models (#1206)

Browse files
app/lib/modules/llm/manager.ts CHANGED
@@ -118,12 +118,14 @@ export class LLMManager {
118
  return dynamicModels;
119
  }),
120
  );
 
 
 
 
121
 
122
  // Combine static and dynamic models
123
- const modelList = [
124
- ...dynamicModels.flat(),
125
- ...Array.from(this._providers.values()).flatMap((p) => p.staticModels || []),
126
- ];
127
  this._modelList = modelList;
128
 
129
  return modelList;
@@ -178,8 +180,12 @@ export class LLMManager {
178
  logger.error(`Error getting dynamic models ${provider.name} :`, err);
179
  return [];
180
  });
 
 
 
 
181
 
182
- return [...dynamicModels, ...staticModels];
183
  }
184
  getStaticModelListFromProvider(providerArg: BaseProvider) {
185
  const provider = this._providers.get(providerArg.name);
 
118
  return dynamicModels;
119
  }),
120
  );
121
+ const staticModels = Array.from(this._providers.values()).flatMap((p) => p.staticModels || []);
122
+ const dynamicModelsFlat = dynamicModels.flat();
123
+ const dynamicModelKeys = dynamicModelsFlat.map((d) => `${d.name}-${d.provider}`);
124
+ const filteredStaticModesl = staticModels.filter((m) => !dynamicModelKeys.includes(`${m.name}-${m.provider}`));
125
 
126
  // Combine static and dynamic models
127
+ const modelList = [...dynamicModelsFlat, ...filteredStaticModesl];
128
+ modelList.sort((a, b) => a.name.localeCompare(b.name));
 
 
129
  this._modelList = modelList;
130
 
131
  return modelList;
 
180
  logger.error(`Error getting dynamic models ${provider.name} :`, err);
181
  return [];
182
  });
183
+ const dynamicModelsName = dynamicModels.map((d) => d.name);
184
+ const filteredStaticList = staticModels.filter((m) => !dynamicModelsName.includes(m.name));
185
+ const modelList = [...dynamicModels, ...filteredStaticList];
186
+ modelList.sort((a, b) => a.name.localeCompare(b.name));
187
 
188
+ return modelList;
189
  }
190
  getStaticModelListFromProvider(providerArg: BaseProvider) {
191
  const provider = this._providers.get(providerArg.name);
app/lib/modules/llm/providers/google.ts CHANGED
@@ -14,7 +14,12 @@ export default class GoogleProvider extends BaseProvider {
14
 
15
  staticModels: ModelInfo[] = [
16
  { name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google', maxTokenAllowed: 8192 },
17
- { name: 'gemini-2.0-flash-thinking-exp-01-21', label: 'Gemini 2.0 Flash-thinking-exp-01-21', provider: 'Google', maxTokenAllowed: 65536 },
 
 
 
 
 
18
  { name: 'gemini-2.0-flash-exp', label: 'Gemini 2.0 Flash', provider: 'Google', maxTokenAllowed: 8192 },
19
  { name: 'gemini-1.5-flash-002', label: 'Gemini 1.5 Flash-002', provider: 'Google', maxTokenAllowed: 8192 },
20
  { name: 'gemini-1.5-flash-8b', label: 'Gemini 1.5 Flash-8b', provider: 'Google', maxTokenAllowed: 8192 },
@@ -23,6 +28,41 @@ export default class GoogleProvider extends BaseProvider {
23
  { name: 'gemini-exp-1206', label: 'Gemini exp-1206', provider: 'Google', maxTokenAllowed: 8192 },
24
  ];
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  getModelInstance(options: {
27
  model: string;
28
  serverEnv: any;
 
14
 
15
  staticModels: ModelInfo[] = [
16
  { name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google', maxTokenAllowed: 8192 },
17
+ {
18
+ name: 'gemini-2.0-flash-thinking-exp-01-21',
19
+ label: 'Gemini 2.0 Flash-thinking-exp-01-21',
20
+ provider: 'Google',
21
+ maxTokenAllowed: 65536,
22
+ },
23
  { name: 'gemini-2.0-flash-exp', label: 'Gemini 2.0 Flash', provider: 'Google', maxTokenAllowed: 8192 },
24
  { name: 'gemini-1.5-flash-002', label: 'Gemini 1.5 Flash-002', provider: 'Google', maxTokenAllowed: 8192 },
25
  { name: 'gemini-1.5-flash-8b', label: 'Gemini 1.5 Flash-8b', provider: 'Google', maxTokenAllowed: 8192 },
 
28
  { name: 'gemini-exp-1206', label: 'Gemini exp-1206', provider: 'Google', maxTokenAllowed: 8192 },
29
  ];
30
 
31
+ async getDynamicModels(
32
+ apiKeys?: Record<string, string>,
33
+ settings?: IProviderSetting,
34
+ serverEnv?: Record<string, string>,
35
+ ): Promise<ModelInfo[]> {
36
+ const { apiKey } = this.getProviderBaseUrlAndKey({
37
+ apiKeys,
38
+ providerSettings: settings,
39
+ serverEnv: serverEnv as any,
40
+ defaultBaseUrlKey: '',
41
+ defaultApiTokenKey: 'GOOGLE_GENERATIVE_AI_API_KEY',
42
+ });
43
+
44
+ if (!apiKey) {
45
+ throw `Missing Api Key configuration for ${this.name} provider`;
46
+ }
47
+
48
+ const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`, {
49
+ headers: {
50
+ ['Content-Type']: 'application/json',
51
+ },
52
+ });
53
+
54
+ const res = (await response.json()) as any;
55
+
56
+ const data = res.models.filter((model: any) => model.outputTokenLimit > 8000);
57
+
58
+ return data.map((m: any) => ({
59
+ name: m.name.replace('models/', ''),
60
+ label: `${m.displayName} - context ${Math.floor((m.inputTokenLimit + m.outputTokenLimit) / 1000) + 'k'}`,
61
+ provider: this.name,
62
+ maxTokenAllowed: m.inputTokenLimit + m.outputTokenLimit || 8000,
63
+ }));
64
+ }
65
+
66
  getModelInstance(options: {
67
  model: string;
68
  serverEnv: any;
app/lib/modules/llm/providers/groq.ts CHANGED
@@ -19,9 +19,51 @@ export default class GroqProvider extends BaseProvider {
19
  { name: 'llama-3.2-3b-preview', label: 'Llama 3.2 3b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
20
  { name: 'llama-3.2-1b-preview', label: 'Llama 3.2 1b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
21
  { name: 'llama-3.3-70b-versatile', label: 'Llama 3.3 70b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
22
- { name: 'deepseek-r1-distill-llama-70b', label: 'Deepseek R1 Distill Llama 70b (Groq)', provider: 'Groq', maxTokenAllowed: 131072 },
 
 
 
 
 
23
  ];
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  getModelInstance(options: {
26
  model: string;
27
  serverEnv: Env;
 
19
  { name: 'llama-3.2-3b-preview', label: 'Llama 3.2 3b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
20
  { name: 'llama-3.2-1b-preview', label: 'Llama 3.2 1b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
21
  { name: 'llama-3.3-70b-versatile', label: 'Llama 3.3 70b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
22
+ {
23
+ name: 'deepseek-r1-distill-llama-70b',
24
+ label: 'Deepseek R1 Distill Llama 70b (Groq)',
25
+ provider: 'Groq',
26
+ maxTokenAllowed: 131072,
27
+ },
28
  ];
29
 
30
+ async getDynamicModels(
31
+ apiKeys?: Record<string, string>,
32
+ settings?: IProviderSetting,
33
+ serverEnv?: Record<string, string>,
34
+ ): Promise<ModelInfo[]> {
35
+ const { apiKey } = this.getProviderBaseUrlAndKey({
36
+ apiKeys,
37
+ providerSettings: settings,
38
+ serverEnv: serverEnv as any,
39
+ defaultBaseUrlKey: '',
40
+ defaultApiTokenKey: 'GROQ_API_KEY',
41
+ });
42
+
43
+ if (!apiKey) {
44
+ throw `Missing Api Key configuration for ${this.name} provider`;
45
+ }
46
+
47
+ const response = await fetch(`https://api.groq.com/openai/v1/models`, {
48
+ headers: {
49
+ Authorization: `Bearer ${apiKey}`,
50
+ },
51
+ });
52
+
53
+ const res = (await response.json()) as any;
54
+
55
+ const data = res.data.filter(
56
+ (model: any) => model.object === 'model' && model.active && model.context_window > 8000,
57
+ );
58
+
59
+ return data.map((m: any) => ({
60
+ name: m.id,
61
+ label: `${m.id} - context ${m.context_window ? Math.floor(m.context_window / 1000) + 'k' : 'N/A'} [ by ${m.owned_by}]`,
62
+ provider: this.name,
63
+ maxTokenAllowed: m.context_window || 8000,
64
+ }));
65
+ }
66
+
67
  getModelInstance(options: {
68
  model: string;
69
  serverEnv: Env;
app/routes/api.models.ts CHANGED
@@ -67,11 +67,11 @@ export async function loader({
67
  const provider = llmManager.getProvider(params.provider);
68
 
69
  if (provider) {
70
- const staticModels = provider.staticModels;
71
- const dynamicModels = provider.getDynamicModels
72
- ? await provider.getDynamicModels(apiKeys, providerSettings, context.cloudflare?.env)
73
- : [];
74
- modelList = [...staticModels, ...dynamicModels];
75
  }
76
  } else {
77
  // Update all models
 
67
  const provider = llmManager.getProvider(params.provider);
68
 
69
  if (provider) {
70
+ modelList = await llmManager.getModelListFromProvider(provider, {
71
+ apiKeys,
72
+ providerSettings,
73
+ serverEnv: context.cloudflare?.env,
74
+ });
75
  }
76
  } else {
77
  // Update all models