mukaddamzaid commited on
Commit
e87a324
·
1 Parent(s): 547981f

feat: add Cohere Command A model

Browse files
ai/providers.ts CHANGED
@@ -3,7 +3,7 @@ import { openai } from "@ai-sdk/openai";
3
  import { google } from "@ai-sdk/google";
4
  import { groq } from "@ai-sdk/groq";
5
  import { customProvider, wrapLanguageModel, extractReasoningMiddleware } from "ai";
6
-
7
  export interface ModelInfo {
8
  provider: string;
9
  name: string;
@@ -21,14 +21,14 @@ const languageModels = {
21
  // "grok-3": xai("grok-3-latest"),
22
  // "grok-3-mini": xai("grok-3-mini-fast-latest"),
23
  "gpt-4.1-mini": openai("gpt-4.1-mini"),
24
- "gpt-4.1-nano": openai("gpt-4.1-nano"),
25
  "gemini-2-flash": google("gemini-2.0-flash-001"),
26
  "qwen-qwq": wrapLanguageModel(
27
  {
28
  model: groq("qwen-qwq-32b"),
29
  middleware
30
  }
31
- )
 
32
  };
33
 
34
  export const modelDetails: Record<keyof typeof languageModels, ModelInfo> = {
@@ -53,13 +53,6 @@ export const modelDetails: Record<keyof typeof languageModels, ModelInfo> = {
53
  apiVersion: "gpt-4.1-mini",
54
  capabilities: [ "Balance", "Creative", "Vision"]
55
  },
56
- "gpt-4.1-nano": {
57
- provider: "OpenAI",
58
- name: "GPT-4.1 Nano",
59
- description: "Smallest and fastest GPT-4.1 variant designed for efficient rapid responses.",
60
- apiVersion: "gpt-4.1-nano",
61
- capabilities: ["Rapid", "Compact", "Efficient", "Vision"]
62
- },
63
  "gemini-2-flash": {
64
  provider: "Google",
65
  name: "Gemini 2 Flash",
@@ -73,6 +66,13 @@ export const modelDetails: Record<keyof typeof languageModels, ModelInfo> = {
73
  description: "Latest version of Alibaba's Qwen QWQ with strong reasoning and coding capabilities.",
74
  apiVersion: "qwen-qwq",
75
  capabilities: ["Reasoning", "Efficient", "Agentic"]
 
 
 
 
 
 
 
76
  }
77
  };
78
 
 
3
  import { google } from "@ai-sdk/google";
4
  import { groq } from "@ai-sdk/groq";
5
  import { customProvider, wrapLanguageModel, extractReasoningMiddleware } from "ai";
6
+ import { cohere } from "@ai-sdk/cohere";
7
  export interface ModelInfo {
8
  provider: string;
9
  name: string;
 
21
  // "grok-3": xai("grok-3-latest"),
22
  // "grok-3-mini": xai("grok-3-mini-fast-latest"),
23
  "gpt-4.1-mini": openai("gpt-4.1-mini"),
 
24
  "gemini-2-flash": google("gemini-2.0-flash-001"),
25
  "qwen-qwq": wrapLanguageModel(
26
  {
27
  model: groq("qwen-qwq-32b"),
28
  middleware
29
  }
30
+ ),
31
+ "command-a": cohere('command-a-03-2025')
32
  };
33
 
34
  export const modelDetails: Record<keyof typeof languageModels, ModelInfo> = {
 
53
  apiVersion: "gpt-4.1-mini",
54
  capabilities: [ "Balance", "Creative", "Vision"]
55
  },
 
 
 
 
 
 
 
56
  "gemini-2-flash": {
57
  provider: "Google",
58
  name: "Gemini 2 Flash",
 
66
  description: "Latest version of Alibaba's Qwen QWQ with strong reasoning and coding capabilities.",
67
  apiVersion: "qwen-qwq",
68
  capabilities: ["Reasoning", "Efficient", "Agentic"]
69
+ },
70
+ "command-a": {
71
+ provider: "Cohere",
72
+ name: "Command A",
73
+ description: "Latest version of Cohere's Command A with strong reasoning and coding capabilities.",
74
+ apiVersion: "command-a-03-2025",
75
+ capabilities: ["Smart", "Fast", "Reasoning"]
76
  }
77
  };
78
 
app/api/chat/route.ts CHANGED
@@ -1,12 +1,10 @@
1
  import { model, type modelID } from "@/ai/providers";
2
- import { composioTools } from "@/ai/tools";
3
  import { streamText, type UIMessage } from "ai";
4
- import { openai } from '@ai-sdk/openai';
5
  import { appendResponseMessages } from 'ai';
6
  import { saveChat, saveMessages, convertToDBMessages } from '@/lib/chat-store';
7
  import { nanoid } from 'nanoid';
8
  import { db } from '@/lib/db';
9
- import { messages, chats } from '@/lib/db/schema';
10
  import { eq, and } from 'drizzle-orm';
11
 
12
  import { experimental_createMCPClient as createMCPClient, MCPTransport } from 'ai';
@@ -14,7 +12,7 @@ import { Experimental_StdioMCPTransport as StdioMCPTransport } from 'ai/mcp-stdi
14
  import { spawn } from "child_process";
15
 
16
  // Allow streaming responses up to 30 seconds
17
- export const maxDuration = 30;
18
 
19
  interface KeyValuePair {
20
  key: string;
@@ -79,13 +77,13 @@ export async function POST(req: Request) {
79
  // Initialize tools
80
  let tools = {};
81
  const mcpClients: any[] = [];
82
-
83
  // Process each MCP server configuration
84
  for (const mcpServer of mcpServers) {
85
  try {
86
  // Create appropriate transport based on type
87
  let transport: MCPTransport | { type: 'sse', url: string, headers?: Record<string, string> };
88
-
89
  if (mcpServer.type === 'sse') {
90
  // Convert headers array to object for SSE transport
91
  const headers: Record<string, string> = {};
@@ -94,9 +92,9 @@ export async function POST(req: Request) {
94
  if (header.key) headers[header.key] = header.value || '';
95
  });
96
  }
97
-
98
- transport = {
99
- type: 'sse' as const,
100
  url: mcpServer.url,
101
  headers: Object.keys(headers).length > 0 ? headers : undefined
102
  };
@@ -106,7 +104,7 @@ export async function POST(req: Request) {
106
  console.warn("Skipping stdio MCP server due to missing command or args");
107
  continue;
108
  }
109
-
110
  // Convert env array to object for stdio transport
111
  const env: Record<string, string> = {};
112
  if (mcpServer.env && mcpServer.env.length > 0) {
@@ -131,7 +129,7 @@ export async function POST(req: Request) {
131
  console.log("installed python package", packageName);
132
  });
133
  }
134
-
135
  transport = new StdioMCPTransport({
136
  command: mcpServer.command,
137
  args: mcpServer.args,
@@ -141,14 +139,14 @@ export async function POST(req: Request) {
141
  console.warn(`Skipping MCP server with unsupported transport type: ${mcpServer.type}`);
142
  continue;
143
  }
144
-
145
  const mcpClient = await createMCPClient({ transport });
146
  mcpClients.push(mcpClient);
147
-
148
  const mcptools = await mcpClient.tools();
149
-
150
  console.log(`MCP tools from ${mcpServer.type} transport:`, Object.keys(mcptools));
151
-
152
  // Add MCP tools to tools object
153
  tools = { ...tools, ...mcptools };
154
  } catch (error) {
@@ -205,13 +203,13 @@ export async function POST(req: Request) {
205
  messages,
206
  responseMessages: response.messages,
207
  });
208
-
209
  await saveChat({
210
  id,
211
  userId,
212
  messages: allMessages,
213
  });
214
-
215
  const dbMessages = convertToDBMessages(allMessages, id);
216
  await saveMessages({ messages: dbMessages });
217
  // close all mcp clients
 
1
  import { model, type modelID } from "@/ai/providers";
 
2
  import { streamText, type UIMessage } from "ai";
 
3
  import { appendResponseMessages } from 'ai';
4
  import { saveChat, saveMessages, convertToDBMessages } from '@/lib/chat-store';
5
  import { nanoid } from 'nanoid';
6
  import { db } from '@/lib/db';
7
+ import { chats } from '@/lib/db/schema';
8
  import { eq, and } from 'drizzle-orm';
9
 
10
  import { experimental_createMCPClient as createMCPClient, MCPTransport } from 'ai';
 
12
  import { spawn } from "child_process";
13
 
14
  // Allow streaming responses up to 30 seconds
15
+ export const maxDuration = 120;
16
 
17
  interface KeyValuePair {
18
  key: string;
 
77
  // Initialize tools
78
  let tools = {};
79
  const mcpClients: any[] = [];
80
+
81
  // Process each MCP server configuration
82
  for (const mcpServer of mcpServers) {
83
  try {
84
  // Create appropriate transport based on type
85
  let transport: MCPTransport | { type: 'sse', url: string, headers?: Record<string, string> };
86
+
87
  if (mcpServer.type === 'sse') {
88
  // Convert headers array to object for SSE transport
89
  const headers: Record<string, string> = {};
 
92
  if (header.key) headers[header.key] = header.value || '';
93
  });
94
  }
95
+
96
+ transport = {
97
+ type: 'sse' as const,
98
  url: mcpServer.url,
99
  headers: Object.keys(headers).length > 0 ? headers : undefined
100
  };
 
104
  console.warn("Skipping stdio MCP server due to missing command or args");
105
  continue;
106
  }
107
+
108
  // Convert env array to object for stdio transport
109
  const env: Record<string, string> = {};
110
  if (mcpServer.env && mcpServer.env.length > 0) {
 
129
  console.log("installed python package", packageName);
130
  });
131
  }
132
+
133
  transport = new StdioMCPTransport({
134
  command: mcpServer.command,
135
  args: mcpServer.args,
 
139
  console.warn(`Skipping MCP server with unsupported transport type: ${mcpServer.type}`);
140
  continue;
141
  }
142
+
143
  const mcpClient = await createMCPClient({ transport });
144
  mcpClients.push(mcpClient);
145
+
146
  const mcptools = await mcpClient.tools();
147
+
148
  console.log(`MCP tools from ${mcpServer.type} transport:`, Object.keys(mcptools));
149
+
150
  // Add MCP tools to tools object
151
  tools = { ...tools, ...mcptools };
152
  } catch (error) {
 
203
  messages,
204
  responseMessages: response.messages,
205
  });
206
+
207
  await saveChat({
208
  id,
209
  userId,
210
  messages: allMessages,
211
  });
212
+
213
  const dbMessages = convertToDBMessages(allMessages, id);
214
  await saveMessages({ messages: dbMessages });
215
  // close all mcp clients
components/model-picker.tsx CHANGED
@@ -42,6 +42,8 @@ export const ModelPicker = ({ selectedModel, setSelectedModel }: ModelPickerProp
42
  return <Zap className="h-3 w-3 text-red-500" />;
43
  case 'groq':
44
  return <Sparkles className="h-3 w-3 text-blue-500" />;
 
 
45
  default:
46
  return <Info className="h-3 w-3 text-blue-500" />;
47
  }
 
42
  return <Zap className="h-3 w-3 text-red-500" />;
43
  case 'groq':
44
  return <Sparkles className="h-3 w-3 text-blue-500" />;
45
+ case 'cohere':
46
+ return <Sparkles className="h-3 w-3 text-yellow-500" />;
47
  default:
48
  return <Info className="h-3 w-3 text-blue-500" />;
49
  }
package.json CHANGED
@@ -13,6 +13,7 @@
13
  "db:studio": "drizzle-kit studio"
14
  },
15
  "dependencies": {
 
16
  "@ai-sdk/google": "^1.2.12",
17
  "@ai-sdk/groq": "^1.2.8",
18
  "@ai-sdk/openai": "^1.3.16",
 
13
  "db:studio": "drizzle-kit studio"
14
  },
15
  "dependencies": {
16
+ "@ai-sdk/cohere": "^1.2.9",
17
  "@ai-sdk/google": "^1.2.12",
18
  "@ai-sdk/groq": "^1.2.8",
19
  "@ai-sdk/openai": "^1.3.16",
pnpm-lock.yaml CHANGED
@@ -8,6 +8,9 @@ importers:
8
 
9
  .:
10
  dependencies:
 
 
 
11
  '@ai-sdk/google':
12
  specifier: ^1.2.12
13
  version: 1.2.12([email protected])
@@ -195,6 +198,12 @@ importers:
195
 
196
  packages:
197
 
 
 
 
 
 
 
198
  '@ai-sdk/[email protected]':
199
  resolution: {integrity: sha512-A8AYqCmBs9SJFiAOP6AX0YEDHWTDrCaUDiRY2cdMSKjJiEknvwnPrAAKf3idgVqYaM2kS0qWz5v9v4pBzXDx+w==}
200
  engines: {node: '>=18'}
@@ -4492,6 +4501,12 @@ packages:
4492
 
4493
  snapshots:
4494
 
 
 
 
 
 
 
4495
4496
  dependencies:
4497
  '@ai-sdk/provider': 1.1.3
 
8
 
9
  .:
10
  dependencies:
11
+ '@ai-sdk/cohere':
12
+ specifier: ^1.2.9
13
+ version: 1.2.9([email protected])
14
  '@ai-sdk/google':
15
  specifier: ^1.2.12
16
  version: 1.2.12([email protected])
 
198
 
199
  packages:
200
 
201
+ '@ai-sdk/[email protected]':
202
+ resolution: {integrity: sha512-ENbQT2bDt1FN+DOLkoFT9n+cojLv7zdN5GCaDtkaqJUuRa+6lHqNO4tRvB79Jg8DY03mhDermty9EhJW2zHoTA==}
203
+ engines: {node: '>=18'}
204
+ peerDependencies:
205
+ zod: ^3.0.0
206
+
207
  '@ai-sdk/[email protected]':
208
  resolution: {integrity: sha512-A8AYqCmBs9SJFiAOP6AX0YEDHWTDrCaUDiRY2cdMSKjJiEknvwnPrAAKf3idgVqYaM2kS0qWz5v9v4pBzXDx+w==}
209
  engines: {node: '>=18'}
 
4501
 
4502
  snapshots:
4503
 
4504
4505
+ dependencies:
4506
+ '@ai-sdk/provider': 1.1.3
4507
+ '@ai-sdk/provider-utils': 2.2.7([email protected])
4508
+ zod: 3.24.2
4509
+
4510
4511
  dependencies:
4512
  '@ai-sdk/provider': 1.1.3