victor HF Staff commited on
Commit
8cfdcec
Β·
1 Parent(s): 2c7a77a

feat: Add Hugging Face model support

Browse files
.env.example CHANGED
@@ -9,3 +9,5 @@ OPENAI_BASE_URL="" # optional – leave blank to use api.openai.com
9
  # When pointing to Azure: OPENAI_BASE_URL=https://<resource>.openai.azure.com/openai/deployments/<deployment-name>
10
  # Optional extra headers as JSON, eg: {"api-key":"abc","organization":"org_xyz"}
11
  OPENAI_EXTRA_HEADERS=""
 
 
 
9
  # When pointing to Azure: OPENAI_BASE_URL=https://<resource>.openai.azure.com/openai/deployments/<deployment-name>
10
  # Optional extra headers as JSON, eg: {"api-key":"abc","organization":"org_xyz"}
11
  OPENAI_EXTRA_HEADERS=""
12
+
13
+ HF_TOKEN=
app/api/chat/route.ts CHANGED
@@ -1,5 +1,6 @@
1
  import { openai } from "@/lib/openai-client";
2
- import type { ModelID } from "@/lib/models";
 
3
  import { saveChat } from "@/lib/chat-store";
4
  import { nanoid } from "nanoid";
5
  import { db } from "@/lib/db";
@@ -105,13 +106,20 @@ export async function POST(req: Request) {
105
 
106
  const { tools, cleanup } = await initializeMCPClients(mcpServers, req.signal);
107
 
108
- const completion = await openai.chat.completions.create(
 
 
 
 
 
109
  {
110
  model: selectedModel,
111
  stream: true,
112
  messages,
113
- tools: mcpToolsToOpenAITools(tools),
114
- tool_choice: "auto",
 
 
115
  },
116
  { signal: req.signal }
117
  );
 
1
  import { openai } from "@/lib/openai-client";
2
+ import { hf } from "@/lib/hf-client";
3
+ import { getModels, type ModelID } from "@/lib/models";
4
  import { saveChat } from "@/lib/chat-store";
5
  import { nanoid } from "nanoid";
6
  import { db } from "@/lib/db";
 
106
 
107
  const { tools, cleanup } = await initializeMCPClients(mcpServers, req.signal);
108
 
109
+ const hfModels = await getModels();
110
+ const client = hfModels.includes(selectedModel) ? hf : openai;
111
+
112
+ const openAITools = mcpToolsToOpenAITools(tools);
113
+
114
+ const completion = await client.chat.completions.create(
115
  {
116
  model: selectedModel,
117
  stream: true,
118
  messages,
119
+ ...(openAITools.length > 0 && {
120
+ tools: openAITools,
121
+ tool_choice: "auto",
122
+ }),
123
  },
124
  { signal: req.signal }
125
  );
components/chat.tsx CHANGED
@@ -1,6 +1,6 @@
1
  "use client";
2
 
3
- import { DEFAULT_MODEL, type ModelID } from "@/lib/models";
4
  import { Message, useChat } from "@ai-sdk/react";
5
  import { useState, useEffect, useMemo, useCallback } from "react";
6
  import { Textarea } from "./textarea";
@@ -30,7 +30,7 @@ export default function Chat() {
30
  const chatId = params?.id as string | undefined;
31
  const queryClient = useQueryClient();
32
 
33
- const [selectedModel, setSelectedModel] = useLocalStorage<ModelID>("selectedModel", DEFAULT_MODEL);
34
  const [userId, setUserId] = useState<string>('');
35
  const [generatedChatId, setGeneratedChatId] = useState<string>('');
36
 
@@ -41,6 +41,16 @@ export default function Chat() {
41
  useEffect(() => {
42
  setUserId(getUserId());
43
  }, []);
 
 
 
 
 
 
 
 
 
 
44
 
45
  // Generate a chat ID if needed
46
  useEffect(() => {
 
1
  "use client";
2
 
3
+ import { getDefaultModel, type ModelID } from "@/lib/models";
4
  import { Message, useChat } from "@ai-sdk/react";
5
  import { useState, useEffect, useMemo, useCallback } from "react";
6
  import { Textarea } from "./textarea";
 
30
  const chatId = params?.id as string | undefined;
31
  const queryClient = useQueryClient();
32
 
33
+ const [selectedModel, setSelectedModel] = useLocalStorage<ModelID>("selectedModel", "");
34
  const [userId, setUserId] = useState<string>('');
35
  const [generatedChatId, setGeneratedChatId] = useState<string>('');
36
 
 
41
  useEffect(() => {
42
  setUserId(getUserId());
43
  }, []);
44
+
45
+ useEffect(() => {
46
+ const fetchDefaultModel = async () => {
47
+ const defaultModel = await getDefaultModel();
48
+ if (!selectedModel) {
49
+ setSelectedModel(defaultModel);
50
+ }
51
+ };
52
+ fetchDefaultModel();
53
+ }, [selectedModel, setSelectedModel]);
54
 
55
  // Generate a chat ID if needed
56
  useEffect(() => {
components/message.tsx CHANGED
@@ -163,7 +163,7 @@ const PurePreviewMessage = ({
163
  >
164
  <div
165
  className={cn("flex flex-col gap-3 w-full", {
166
- "bg-secondary text-secondary-foreground px-4 py-3 rounded-2xl":
167
  message.role === "user",
168
  })}
169
  >
 
163
  >
164
  <div
165
  className={cn("flex flex-col gap-3 w-full", {
166
+ "bg-secondary text-secondary-foreground px-4 py-1.5 rounded-2xl":
167
  message.role === "user",
168
  })}
169
  >
components/model-picker.tsx CHANGED
@@ -1,5 +1,5 @@
1
  "use client";
2
- import { MODELS, DEFAULT_MODEL, ModelID } from "@/lib/models";
3
  import {
4
  Select,
5
  SelectContent,
@@ -10,7 +10,7 @@ import {
10
  } from "./ui/select";
11
  import { cn } from "@/lib/utils";
12
  import { Bot } from "lucide-react";
13
- import { useEffect } from "react";
14
 
15
  interface ModelPickerProps {
16
  selectedModel: ModelID;
@@ -18,19 +18,28 @@ interface ModelPickerProps {
18
  }
19
 
20
  export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProps) => {
21
- // Ensure we always have a valid model ID
22
- const validModelId = MODELS.includes(selectedModel) ? selectedModel : DEFAULT_MODEL;
23
-
24
- // If the selected model is invalid, update it to the default
25
  useEffect(() => {
26
- if (selectedModel !== validModelId) {
27
- setSelectedModel(validModelId);
28
- }
29
- }, [selectedModel, validModelId, setSelectedModel]);
 
 
 
 
 
 
 
 
 
 
30
 
31
  // Handle model change
32
  const handleModelChange = (modelId: string) => {
33
- if (MODELS.includes(modelId as ModelID)) {
34
  setSelectedModel(modelId as ModelID);
35
  }
36
  };
@@ -43,7 +52,7 @@ export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProp
43
  defaultValue={validModelId}
44
  >
45
  <SelectTrigger
46
- className="max-w-[200px] sm:max-w-fit sm:w-56 px-2 sm:px-3 h-8 sm:h-9 rounded-full group border-primary/20 bg-primary/5 hover:bg-primary/10 dark:bg-primary/10 dark:hover:bg-primary/20 transition-all duration-200 ring-offset-background focus:ring-2 focus:ring-primary/30 focus:ring-offset-2"
47
  >
48
  <SelectValue
49
  placeholder="Select model"
@@ -60,7 +69,7 @@ export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProp
60
  className="bg-background/95 dark:bg-muted/95 backdrop-blur-sm border-border/80 rounded-lg overflow-hidden p-0 w-[280px]"
61
  >
62
  <SelectGroup className="space-y-1 p-1">
63
- {MODELS.map((id) => (
64
  <SelectItem
65
  key={id}
66
  value={id}
 
1
  "use client";
2
+ import { getModels, getDefaultModel, ModelID } from "@/lib/models";
3
  import {
4
  Select,
5
  SelectContent,
 
10
  } from "./ui/select";
11
  import { cn } from "@/lib/utils";
12
  import { Bot } from "lucide-react";
13
+ import { useEffect, useState } from "react";
14
 
15
  interface ModelPickerProps {
16
  selectedModel: ModelID;
 
18
  }
19
 
20
  export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProps) => {
21
+ const [models, setModels] = useState<ModelID[]>([]);
22
+ const [validModelId, setValidModelId] = useState<ModelID>("");
23
+
 
24
  useEffect(() => {
25
+ const fetchModels = async () => {
26
+ const availableModels = await getModels();
27
+ setModels(availableModels);
28
+ const defaultModel = await getDefaultModel();
29
+ const currentModel = selectedModel || defaultModel;
30
+ const isValid = availableModels.includes(currentModel);
31
+ const newValidModelId = isValid ? currentModel : defaultModel;
32
+ setValidModelId(newValidModelId);
33
+ if (selectedModel !== newValidModelId) {
34
+ setSelectedModel(newValidModelId);
35
+ }
36
+ };
37
+ fetchModels();
38
+ }, [selectedModel, setSelectedModel]);
39
 
40
  // Handle model change
41
  const handleModelChange = (modelId: string) => {
42
+ if (models.includes(modelId as ModelID)) {
43
  setSelectedModel(modelId as ModelID);
44
  }
45
  };
 
52
  defaultValue={validModelId}
53
  >
54
  <SelectTrigger
55
+ className="max-w-[200px] sm:max-w-fit sm:w-80 px-2 sm:px-3 h-8 sm:h-9 rounded-full group border-primary/20 bg-primary/5 hover:bg-primary/10 dark:bg-primary/10 dark:hover:bg-primary/20 transition-all duration-200 ring-offset-background focus:ring-2 focus:ring-primary/30 focus:ring-offset-2"
56
  >
57
  <SelectValue
58
  placeholder="Select model"
 
69
  className="bg-background/95 dark:bg-muted/95 backdrop-blur-sm border-border/80 rounded-lg overflow-hidden p-0 w-[280px]"
70
  >
71
  <SelectGroup className="space-y-1 p-1">
72
+ {models.map((id) => (
73
  <SelectItem
74
  key={id}
75
  value={id}
lib/hf-client.ts ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import OpenAI from "openai";
2
+
3
+ export const hf = new OpenAI({
4
+ apiKey: process.env.HF_TOKEN,
5
+ baseURL: "https://router.huggingface.co/v1",
6
+ });
lib/models.ts CHANGED
@@ -2,8 +2,27 @@
2
  * List here only the model IDs your endpoint exposes.
3
  * Add/remove freely – nothing else in the codebase cares.
4
  */
5
- export const MODELS = ["gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"] as const;
6
 
7
- export type ModelID = (typeof MODELS)[number];
8
 
9
- export const DEFAULT_MODEL: ModelID = "gpt-4o-mini";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  * List here only the model IDs your endpoint exposes.
3
  * Add/remove freely – nothing else in the codebase cares.
4
  */
5
+ export type ModelID = string;
6
 
7
+ let modelsCache: string[] | null = null;
8
 
9
+ export async function getModels(): Promise<string[]> {
10
+ if (modelsCache) {
11
+ return modelsCache;
12
+ }
13
+ try {
14
+ const response = await fetch("https://router.huggingface.co/v1/models");
15
+ const data = await response.json();
16
+ const modelIds = data.data.slice(0, 5).map((model: any) => model.id);
17
+ modelsCache = modelIds;
18
+ return modelIds;
19
+ } catch (e) {
20
+ console.error(e);
21
+ return [];
22
+ }
23
+ }
24
+
25
+ export async function getDefaultModel(): Promise<ModelID> {
26
+ const models = await getModels();
27
+ return models[0] ?? "";
28
+ }