Spaces:
Running
Running
feat: Add Hugging Face model support
Browse files- .env.example +2 -0
- app/api/chat/route.ts +12 -4
- components/chat.tsx +12 -2
- components/message.tsx +1 -1
- components/model-picker.tsx +22 -13
- lib/hf-client.ts +6 -0
- lib/models.ts +22 -3
.env.example
CHANGED
@@ -9,3 +9,5 @@ OPENAI_BASE_URL="" # optional β leave blank to use api.openai.com
|
|
9 |
# When pointing to Azure: OPENAI_BASE_URL=https://<resource>.openai.azure.com/openai/deployments/<deployment-name>
|
10 |
# Optional extra headers as JSON, eg: {"api-key":"abc","organization":"org_xyz"}
|
11 |
OPENAI_EXTRA_HEADERS=""
|
|
|
|
|
|
9 |
# When pointing to Azure: OPENAI_BASE_URL=https://<resource>.openai.azure.com/openai/deployments/<deployment-name>
|
10 |
# Optional extra headers as JSON, eg: {"api-key":"abc","organization":"org_xyz"}
|
11 |
OPENAI_EXTRA_HEADERS=""
|
12 |
+
|
13 |
+
HF_TOKEN=
|
app/api/chat/route.ts
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import { openai } from "@/lib/openai-client";
|
2 |
-
import
|
|
|
3 |
import { saveChat } from "@/lib/chat-store";
|
4 |
import { nanoid } from "nanoid";
|
5 |
import { db } from "@/lib/db";
|
@@ -105,13 +106,20 @@ export async function POST(req: Request) {
|
|
105 |
|
106 |
const { tools, cleanup } = await initializeMCPClients(mcpServers, req.signal);
|
107 |
|
108 |
-
const
|
|
|
|
|
|
|
|
|
|
|
109 |
{
|
110 |
model: selectedModel,
|
111 |
stream: true,
|
112 |
messages,
|
113 |
-
|
114 |
-
|
|
|
|
|
115 |
},
|
116 |
{ signal: req.signal }
|
117 |
);
|
|
|
1 |
import { openai } from "@/lib/openai-client";
|
2 |
+
import { hf } from "@/lib/hf-client";
|
3 |
+
import { getModels, type ModelID } from "@/lib/models";
|
4 |
import { saveChat } from "@/lib/chat-store";
|
5 |
import { nanoid } from "nanoid";
|
6 |
import { db } from "@/lib/db";
|
|
|
106 |
|
107 |
const { tools, cleanup } = await initializeMCPClients(mcpServers, req.signal);
|
108 |
|
109 |
+
const hfModels = await getModels();
|
110 |
+
const client = hfModels.includes(selectedModel) ? hf : openai;
|
111 |
+
|
112 |
+
const openAITools = mcpToolsToOpenAITools(tools);
|
113 |
+
|
114 |
+
const completion = await client.chat.completions.create(
|
115 |
{
|
116 |
model: selectedModel,
|
117 |
stream: true,
|
118 |
messages,
|
119 |
+
...(openAITools.length > 0 && {
|
120 |
+
tools: openAITools,
|
121 |
+
tool_choice: "auto",
|
122 |
+
}),
|
123 |
},
|
124 |
{ signal: req.signal }
|
125 |
);
|
components/chat.tsx
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
"use client";
|
2 |
|
3 |
-
import {
|
4 |
import { Message, useChat } from "@ai-sdk/react";
|
5 |
import { useState, useEffect, useMemo, useCallback } from "react";
|
6 |
import { Textarea } from "./textarea";
|
@@ -30,7 +30,7 @@ export default function Chat() {
|
|
30 |
const chatId = params?.id as string | undefined;
|
31 |
const queryClient = useQueryClient();
|
32 |
|
33 |
-
const [selectedModel, setSelectedModel] = useLocalStorage<ModelID>("selectedModel",
|
34 |
const [userId, setUserId] = useState<string>('');
|
35 |
const [generatedChatId, setGeneratedChatId] = useState<string>('');
|
36 |
|
@@ -41,6 +41,16 @@ export default function Chat() {
|
|
41 |
useEffect(() => {
|
42 |
setUserId(getUserId());
|
43 |
}, []);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
// Generate a chat ID if needed
|
46 |
useEffect(() => {
|
|
|
1 |
"use client";
|
2 |
|
3 |
+
import { getDefaultModel, type ModelID } from "@/lib/models";
|
4 |
import { Message, useChat } from "@ai-sdk/react";
|
5 |
import { useState, useEffect, useMemo, useCallback } from "react";
|
6 |
import { Textarea } from "./textarea";
|
|
|
30 |
const chatId = params?.id as string | undefined;
|
31 |
const queryClient = useQueryClient();
|
32 |
|
33 |
+
const [selectedModel, setSelectedModel] = useLocalStorage<ModelID>("selectedModel", "");
|
34 |
const [userId, setUserId] = useState<string>('');
|
35 |
const [generatedChatId, setGeneratedChatId] = useState<string>('');
|
36 |
|
|
|
41 |
useEffect(() => {
|
42 |
setUserId(getUserId());
|
43 |
}, []);
|
44 |
+
|
45 |
+
useEffect(() => {
|
46 |
+
const fetchDefaultModel = async () => {
|
47 |
+
const defaultModel = await getDefaultModel();
|
48 |
+
if (!selectedModel) {
|
49 |
+
setSelectedModel(defaultModel);
|
50 |
+
}
|
51 |
+
};
|
52 |
+
fetchDefaultModel();
|
53 |
+
}, [selectedModel, setSelectedModel]);
|
54 |
|
55 |
// Generate a chat ID if needed
|
56 |
useEffect(() => {
|
components/message.tsx
CHANGED
@@ -163,7 +163,7 @@ const PurePreviewMessage = ({
|
|
163 |
>
|
164 |
<div
|
165 |
className={cn("flex flex-col gap-3 w-full", {
|
166 |
-
"bg-secondary text-secondary-foreground px-4 py-
|
167 |
message.role === "user",
|
168 |
})}
|
169 |
>
|
|
|
163 |
>
|
164 |
<div
|
165 |
className={cn("flex flex-col gap-3 w-full", {
|
166 |
+
"bg-secondary text-secondary-foreground px-4 py-1.5 rounded-2xl":
|
167 |
message.role === "user",
|
168 |
})}
|
169 |
>
|
components/model-picker.tsx
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
"use client";
|
2 |
-
import {
|
3 |
import {
|
4 |
Select,
|
5 |
SelectContent,
|
@@ -10,7 +10,7 @@ import {
|
|
10 |
} from "./ui/select";
|
11 |
import { cn } from "@/lib/utils";
|
12 |
import { Bot } from "lucide-react";
|
13 |
-
import { useEffect } from "react";
|
14 |
|
15 |
interface ModelPickerProps {
|
16 |
selectedModel: ModelID;
|
@@ -18,19 +18,28 @@ interface ModelPickerProps {
|
|
18 |
}
|
19 |
|
20 |
export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProps) => {
|
21 |
-
|
22 |
-
const validModelId =
|
23 |
-
|
24 |
-
// If the selected model is invalid, update it to the default
|
25 |
useEffect(() => {
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
// Handle model change
|
32 |
const handleModelChange = (modelId: string) => {
|
33 |
-
if (
|
34 |
setSelectedModel(modelId as ModelID);
|
35 |
}
|
36 |
};
|
@@ -43,7 +52,7 @@ export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProp
|
|
43 |
defaultValue={validModelId}
|
44 |
>
|
45 |
<SelectTrigger
|
46 |
-
className="max-w-[200px] sm:max-w-fit sm:w-
|
47 |
>
|
48 |
<SelectValue
|
49 |
placeholder="Select model"
|
@@ -60,7 +69,7 @@ export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProp
|
|
60 |
className="bg-background/95 dark:bg-muted/95 backdrop-blur-sm border-border/80 rounded-lg overflow-hidden p-0 w-[280px]"
|
61 |
>
|
62 |
<SelectGroup className="space-y-1 p-1">
|
63 |
-
{
|
64 |
<SelectItem
|
65 |
key={id}
|
66 |
value={id}
|
|
|
1 |
"use client";
|
2 |
+
import { getModels, getDefaultModel, ModelID } from "@/lib/models";
|
3 |
import {
|
4 |
Select,
|
5 |
SelectContent,
|
|
|
10 |
} from "./ui/select";
|
11 |
import { cn } from "@/lib/utils";
|
12 |
import { Bot } from "lucide-react";
|
13 |
+
import { useEffect, useState } from "react";
|
14 |
|
15 |
interface ModelPickerProps {
|
16 |
selectedModel: ModelID;
|
|
|
18 |
}
|
19 |
|
20 |
export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProps) => {
|
21 |
+
const [models, setModels] = useState<ModelID[]>([]);
|
22 |
+
const [validModelId, setValidModelId] = useState<ModelID>("");
|
23 |
+
|
|
|
24 |
useEffect(() => {
|
25 |
+
const fetchModels = async () => {
|
26 |
+
const availableModels = await getModels();
|
27 |
+
setModels(availableModels);
|
28 |
+
const defaultModel = await getDefaultModel();
|
29 |
+
const currentModel = selectedModel || defaultModel;
|
30 |
+
const isValid = availableModels.includes(currentModel);
|
31 |
+
const newValidModelId = isValid ? currentModel : defaultModel;
|
32 |
+
setValidModelId(newValidModelId);
|
33 |
+
if (selectedModel !== newValidModelId) {
|
34 |
+
setSelectedModel(newValidModelId);
|
35 |
+
}
|
36 |
+
};
|
37 |
+
fetchModels();
|
38 |
+
}, [selectedModel, setSelectedModel]);
|
39 |
|
40 |
// Handle model change
|
41 |
const handleModelChange = (modelId: string) => {
|
42 |
+
if (models.includes(modelId as ModelID)) {
|
43 |
setSelectedModel(modelId as ModelID);
|
44 |
}
|
45 |
};
|
|
|
52 |
defaultValue={validModelId}
|
53 |
>
|
54 |
<SelectTrigger
|
55 |
+
className="max-w-[200px] sm:max-w-fit sm:w-80 px-2 sm:px-3 h-8 sm:h-9 rounded-full group border-primary/20 bg-primary/5 hover:bg-primary/10 dark:bg-primary/10 dark:hover:bg-primary/20 transition-all duration-200 ring-offset-background focus:ring-2 focus:ring-primary/30 focus:ring-offset-2"
|
56 |
>
|
57 |
<SelectValue
|
58 |
placeholder="Select model"
|
|
|
69 |
className="bg-background/95 dark:bg-muted/95 backdrop-blur-sm border-border/80 rounded-lg overflow-hidden p-0 w-[280px]"
|
70 |
>
|
71 |
<SelectGroup className="space-y-1 p-1">
|
72 |
+
{models.map((id) => (
|
73 |
<SelectItem
|
74 |
key={id}
|
75 |
value={id}
|
lib/hf-client.ts
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import OpenAI from "openai";
|
2 |
+
|
3 |
+
export const hf = new OpenAI({
|
4 |
+
apiKey: process.env.HF_TOKEN,
|
5 |
+
baseURL: "https://router.huggingface.co/v1",
|
6 |
+
});
|
lib/models.ts
CHANGED
@@ -2,8 +2,27 @@
|
|
2 |
* List here only the model IDs your endpoint exposes.
|
3 |
* Add/remove freely β nothing else in the codebase cares.
|
4 |
*/
|
5 |
-
export
|
6 |
|
7 |
-
|
8 |
|
9 |
-
export
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
* List here only the model IDs your endpoint exposes.
|
3 |
* Add/remove freely β nothing else in the codebase cares.
|
4 |
*/
|
5 |
+
export type ModelID = string;
|
6 |
|
7 |
+
let modelsCache: string[] | null = null;
|
8 |
|
9 |
+
export async function getModels(): Promise<string[]> {
|
10 |
+
if (modelsCache) {
|
11 |
+
return modelsCache;
|
12 |
+
}
|
13 |
+
try {
|
14 |
+
const response = await fetch("https://router.huggingface.co/v1/models");
|
15 |
+
const data = await response.json();
|
16 |
+
const modelIds = data.data.slice(0, 5).map((model: any) => model.id);
|
17 |
+
modelsCache = modelIds;
|
18 |
+
return modelIds;
|
19 |
+
} catch (e) {
|
20 |
+
console.error(e);
|
21 |
+
return [];
|
22 |
+
}
|
23 |
+
}
|
24 |
+
|
25 |
+
export async function getDefaultModel(): Promise<ModelID> {
|
26 |
+
const models = await getModels();
|
27 |
+
return models[0] ?? "";
|
28 |
+
}
|