feat: added more dynamic models, sorted them, and removed duplicate models (#1206)
Browse files
app/lib/modules/llm/manager.ts
CHANGED
|
@@ -118,12 +118,14 @@ export class LLMManager {
|
|
| 118 |
return dynamicModels;
|
| 119 |
}),
|
| 120 |
);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
// Combine static and dynamic models
|
| 123 |
-
const modelList = [
|
| 124 |
-
|
| 125 |
-
...Array.from(this._providers.values()).flatMap((p) => p.staticModels || []),
|
| 126 |
-
];
|
| 127 |
this._modelList = modelList;
|
| 128 |
|
| 129 |
return modelList;
|
|
@@ -178,8 +180,12 @@ export class LLMManager {
|
|
| 178 |
logger.error(`Error getting dynamic models ${provider.name} :`, err);
|
| 179 |
return [];
|
| 180 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
return
|
| 183 |
}
|
| 184 |
getStaticModelListFromProvider(providerArg: BaseProvider) {
|
| 185 |
const provider = this._providers.get(providerArg.name);
|
|
|
|
| 118 |
return dynamicModels;
|
| 119 |
}),
|
| 120 |
);
|
| 121 |
+
const staticModels = Array.from(this._providers.values()).flatMap((p) => p.staticModels || []);
|
| 122 |
+
const dynamicModelsFlat = dynamicModels.flat();
|
| 123 |
+
const dynamicModelKeys = dynamicModelsFlat.map((d) => `${d.name}-${d.provider}`);
|
| 124 |
+
const filteredStaticModesl = staticModels.filter((m) => !dynamicModelKeys.includes(`${m.name}-${m.provider}`));
|
| 125 |
|
| 126 |
// Combine static and dynamic models
|
| 127 |
+
const modelList = [...dynamicModelsFlat, ...filteredStaticModesl];
|
| 128 |
+
modelList.sort((a, b) => a.name.localeCompare(b.name));
|
|
|
|
|
|
|
| 129 |
this._modelList = modelList;
|
| 130 |
|
| 131 |
return modelList;
|
|
|
|
| 180 |
logger.error(`Error getting dynamic models ${provider.name} :`, err);
|
| 181 |
return [];
|
| 182 |
});
|
| 183 |
+
const dynamicModelsName = dynamicModels.map((d) => d.name);
|
| 184 |
+
const filteredStaticList = staticModels.filter((m) => !dynamicModelsName.includes(m.name));
|
| 185 |
+
const modelList = [...dynamicModels, ...filteredStaticList];
|
| 186 |
+
modelList.sort((a, b) => a.name.localeCompare(b.name));
|
| 187 |
|
| 188 |
+
return modelList;
|
| 189 |
}
|
| 190 |
getStaticModelListFromProvider(providerArg: BaseProvider) {
|
| 191 |
const provider = this._providers.get(providerArg.name);
|
app/lib/modules/llm/providers/google.ts
CHANGED
|
@@ -14,7 +14,12 @@ export default class GoogleProvider extends BaseProvider {
|
|
| 14 |
|
| 15 |
staticModels: ModelInfo[] = [
|
| 16 |
{ name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google', maxTokenAllowed: 8192 },
|
| 17 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
{ name: 'gemini-2.0-flash-exp', label: 'Gemini 2.0 Flash', provider: 'Google', maxTokenAllowed: 8192 },
|
| 19 |
{ name: 'gemini-1.5-flash-002', label: 'Gemini 1.5 Flash-002', provider: 'Google', maxTokenAllowed: 8192 },
|
| 20 |
{ name: 'gemini-1.5-flash-8b', label: 'Gemini 1.5 Flash-8b', provider: 'Google', maxTokenAllowed: 8192 },
|
|
@@ -23,6 +28,41 @@ export default class GoogleProvider extends BaseProvider {
|
|
| 23 |
{ name: 'gemini-exp-1206', label: 'Gemini exp-1206', provider: 'Google', maxTokenAllowed: 8192 },
|
| 24 |
];
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
getModelInstance(options: {
|
| 27 |
model: string;
|
| 28 |
serverEnv: any;
|
|
|
|
| 14 |
|
| 15 |
staticModels: ModelInfo[] = [
|
| 16 |
{ name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google', maxTokenAllowed: 8192 },
|
| 17 |
+
{
|
| 18 |
+
name: 'gemini-2.0-flash-thinking-exp-01-21',
|
| 19 |
+
label: 'Gemini 2.0 Flash-thinking-exp-01-21',
|
| 20 |
+
provider: 'Google',
|
| 21 |
+
maxTokenAllowed: 65536,
|
| 22 |
+
},
|
| 23 |
{ name: 'gemini-2.0-flash-exp', label: 'Gemini 2.0 Flash', provider: 'Google', maxTokenAllowed: 8192 },
|
| 24 |
{ name: 'gemini-1.5-flash-002', label: 'Gemini 1.5 Flash-002', provider: 'Google', maxTokenAllowed: 8192 },
|
| 25 |
{ name: 'gemini-1.5-flash-8b', label: 'Gemini 1.5 Flash-8b', provider: 'Google', maxTokenAllowed: 8192 },
|
|
|
|
| 28 |
{ name: 'gemini-exp-1206', label: 'Gemini exp-1206', provider: 'Google', maxTokenAllowed: 8192 },
|
| 29 |
];
|
| 30 |
|
| 31 |
+
async getDynamicModels(
|
| 32 |
+
apiKeys?: Record<string, string>,
|
| 33 |
+
settings?: IProviderSetting,
|
| 34 |
+
serverEnv?: Record<string, string>,
|
| 35 |
+
): Promise<ModelInfo[]> {
|
| 36 |
+
const { apiKey } = this.getProviderBaseUrlAndKey({
|
| 37 |
+
apiKeys,
|
| 38 |
+
providerSettings: settings,
|
| 39 |
+
serverEnv: serverEnv as any,
|
| 40 |
+
defaultBaseUrlKey: '',
|
| 41 |
+
defaultApiTokenKey: 'GOOGLE_GENERATIVE_AI_API_KEY',
|
| 42 |
+
});
|
| 43 |
+
|
| 44 |
+
if (!apiKey) {
|
| 45 |
+
throw `Missing Api Key configuration for ${this.name} provider`;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
const response = await fetch(`https://generativelanguage.googleapis.com/v1beta/models?key=${apiKey}`, {
|
| 49 |
+
headers: {
|
| 50 |
+
['Content-Type']: 'application/json',
|
| 51 |
+
},
|
| 52 |
+
});
|
| 53 |
+
|
| 54 |
+
const res = (await response.json()) as any;
|
| 55 |
+
|
| 56 |
+
const data = res.models.filter((model: any) => model.outputTokenLimit > 8000);
|
| 57 |
+
|
| 58 |
+
return data.map((m: any) => ({
|
| 59 |
+
name: m.name.replace('models/', ''),
|
| 60 |
+
label: `${m.displayName} - context ${Math.floor((m.inputTokenLimit + m.outputTokenLimit) / 1000) + 'k'}`,
|
| 61 |
+
provider: this.name,
|
| 62 |
+
maxTokenAllowed: m.inputTokenLimit + m.outputTokenLimit || 8000,
|
| 63 |
+
}));
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
getModelInstance(options: {
|
| 67 |
model: string;
|
| 68 |
serverEnv: any;
|
app/lib/modules/llm/providers/groq.ts
CHANGED
|
@@ -19,9 +19,51 @@ export default class GroqProvider extends BaseProvider {
|
|
| 19 |
{ name: 'llama-3.2-3b-preview', label: 'Llama 3.2 3b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
|
| 20 |
{ name: 'llama-3.2-1b-preview', label: 'Llama 3.2 1b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
|
| 21 |
{ name: 'llama-3.3-70b-versatile', label: 'Llama 3.3 70b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
|
| 22 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
];
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
getModelInstance(options: {
|
| 26 |
model: string;
|
| 27 |
serverEnv: Env;
|
|
|
|
| 19 |
{ name: 'llama-3.2-3b-preview', label: 'Llama 3.2 3b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
|
| 20 |
{ name: 'llama-3.2-1b-preview', label: 'Llama 3.2 1b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
|
| 21 |
{ name: 'llama-3.3-70b-versatile', label: 'Llama 3.3 70b (Groq)', provider: 'Groq', maxTokenAllowed: 8000 },
|
| 22 |
+
{
|
| 23 |
+
name: 'deepseek-r1-distill-llama-70b',
|
| 24 |
+
label: 'Deepseek R1 Distill Llama 70b (Groq)',
|
| 25 |
+
provider: 'Groq',
|
| 26 |
+
maxTokenAllowed: 131072,
|
| 27 |
+
},
|
| 28 |
];
|
| 29 |
|
| 30 |
+
async getDynamicModels(
|
| 31 |
+
apiKeys?: Record<string, string>,
|
| 32 |
+
settings?: IProviderSetting,
|
| 33 |
+
serverEnv?: Record<string, string>,
|
| 34 |
+
): Promise<ModelInfo[]> {
|
| 35 |
+
const { apiKey } = this.getProviderBaseUrlAndKey({
|
| 36 |
+
apiKeys,
|
| 37 |
+
providerSettings: settings,
|
| 38 |
+
serverEnv: serverEnv as any,
|
| 39 |
+
defaultBaseUrlKey: '',
|
| 40 |
+
defaultApiTokenKey: 'GROQ_API_KEY',
|
| 41 |
+
});
|
| 42 |
+
|
| 43 |
+
if (!apiKey) {
|
| 44 |
+
throw `Missing Api Key configuration for ${this.name} provider`;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
const response = await fetch(`https://api.groq.com/openai/v1/models`, {
|
| 48 |
+
headers: {
|
| 49 |
+
Authorization: `Bearer ${apiKey}`,
|
| 50 |
+
},
|
| 51 |
+
});
|
| 52 |
+
|
| 53 |
+
const res = (await response.json()) as any;
|
| 54 |
+
|
| 55 |
+
const data = res.data.filter(
|
| 56 |
+
(model: any) => model.object === 'model' && model.active && model.context_window > 8000,
|
| 57 |
+
);
|
| 58 |
+
|
| 59 |
+
return data.map((m: any) => ({
|
| 60 |
+
name: m.id,
|
| 61 |
+
label: `${m.id} - context ${m.context_window ? Math.floor(m.context_window / 1000) + 'k' : 'N/A'} [ by ${m.owned_by}]`,
|
| 62 |
+
provider: this.name,
|
| 63 |
+
maxTokenAllowed: m.context_window || 8000,
|
| 64 |
+
}));
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
getModelInstance(options: {
|
| 68 |
model: string;
|
| 69 |
serverEnv: Env;
|
app/routes/api.models.ts
CHANGED
|
@@ -67,11 +67,11 @@ export async function loader({
|
|
| 67 |
const provider = llmManager.getProvider(params.provider);
|
| 68 |
|
| 69 |
if (provider) {
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
:
|
| 74 |
-
|
| 75 |
}
|
| 76 |
} else {
|
| 77 |
// Update all models
|
|
|
|
| 67 |
const provider = llmManager.getProvider(params.provider);
|
| 68 |
|
| 69 |
if (provider) {
|
| 70 |
+
modelList = await llmManager.getModelListFromProvider(provider, {
|
| 71 |
+
apiKeys,
|
| 72 |
+
providerSettings,
|
| 73 |
+
serverEnv: context.cloudflare?.env,
|
| 74 |
+
});
|
| 75 |
}
|
| 76 |
} else {
|
| 77 |
// Update all models
|