import { type Response as ExpressResponse } from "express";
import { type ValidatedRequest } from "../middleware/validation.js";
import type { CreateResponseParams, McpServerParams, McpApprovalRequestParams } from "../schemas.js";
import { generateUniqueId } from "../lib/generateUniqueId.js";
import { InferenceClient } from "@huggingface/inference";
import type {
  ChatCompletionInputMessage,
  ChatCompletionInputMessageChunkType,
  ChatCompletionInput,
} from "@huggingface/tasks";
import type {
  Response,
  ResponseStreamEvent,
  ResponseContentPartAddedEvent,
  ResponseOutputMessage,
  ResponseFunctionToolCall,
  ResponseOutputItem,
} from "openai/resources/responses/responses";
import type {
  ChatCompletionInputFunctionDefinition,
  ChatCompletionInputTool,
} from "@huggingface/tasks/dist/commonjs/tasks/chat-completion/inference.js";
import { callMcpTool, connectMcpServer } from "../mcp.js";
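/*
 * Error raised while translating the chat-completion stream into Responses
 * API events (e.g. an unsupported content type).
 */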
class StreamingError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "StreamingError";
  }
}
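// Response under construction: fields not handled yet (incomplete_details, output_text, parallel_tool_calls) are omitted from the type.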
type IncompleteResponse = Omit<Response, "incomplete_details" | "output_text" | "parallel_tool_calls">;
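// Inner generators emit this placeholder; runCreateResponseStream overwrites it with the real sequence number.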
const SEQUENCE_NUMBER_PLACEHOLDER = -1;
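/*
 * Express handler for the Create Response endpoint.
 *
 * All requests are processed internally as a stream of events. When the client
 * sets 'stream: true', events are forwarded as Server-Sent Events; otherwise
 * only the final 'response.completed' / 'response.failed' payload is returned
 * as JSON.
 */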
export const postCreateResponse = async (
  req: ValidatedRequest<CreateResponseParams>,
  res: ExpressResponse
): Promise<void> => {
  // To avoid duplicated code, we run all requests as a stream.
  const events = runCreateResponseStream(req, res);
  // Then we return in the correct format depending on the user's 'stream' flag.
  if (req.body.stream) {
    res.setHeader("Content-Type", "text/event-stream");
    res.setHeader("Connection", "keep-alive");
    console.debug("Stream request");
    for await (const event of events) {
      console.debug(`Event #${event.sequence_number}: ${event.type}`);
      res.write(`data: ${JSON.stringify(event)}\n\n`);
    }
    res.end();
  } else {
    console.debug("Non-stream request");
    for await (const event of events) {
      if (event.type === "response.completed" || event.type === "response.failed") {
        console.debug(event.type);
        res.json(event.response);
      }
    }
  }
};
/*
 * Top-level stream.
 *
 * Handles the response lifecycle and executes the inner logic (MCP tool listing, MCP tool calls, LLM calls, etc.).
 * Assigns sequence numbers by overwriting the placeholder values in the emitted events.
 */
async function* runCreateResponseStream(
  req: ValidatedRequest<CreateResponseParams>,
  res: ExpressResponse
): AsyncGenerator<ResponseStreamEvent> {
  let sequenceNumber = 0;
  // Prepare response object that will be iteratively populated
  const responseObject: IncompleteResponse = {
    created_at: Math.floor(new Date().getTime() / 1000),
    error: null,
    id: generateUniqueId("resp"),
    instructions: req.body.instructions,
    max_output_tokens: req.body.max_output_tokens,
    metadata: req.body.metadata,
    model: req.body.model,
    object: "response",
    output: [],
    // parallel_tool_calls: req.body.parallel_tool_calls,
    status: "in_progress",
    text: req.body.text,
    tool_choice: req.body.tool_choice ?? "auto",
    tools: req.body.tools ?? [],
    temperature: req.body.temperature,
    top_p: req.body.top_p,
    usage: {
      input_tokens: 0,
      input_tokens_details: { cached_tokens: 0 },
      output_tokens: 0,
      output_tokens_details: { reasoning_tokens: 0 },
      total_tokens: 0,
    },
  };
  // Response created event
  yield {
    type: "response.created",
    response: responseObject as Response,
    sequence_number: sequenceNumber++,
  };
  // Response in progress event
  yield {
    type: "response.in_progress",
    response: responseObject as Response,
    sequence_number: sequenceNumber++,
  };
  // Any events (LLM call, MCP call, list tools, etc.)
  try {
    for await (const event of innerRunStream(req, res, responseObject)) {
      yield { ...event, sequence_number: sequenceNumber++ };
    }
  } catch (error) {
    // Error event => stop
    console.error("Error in stream:", error);
    const message =
      typeof error === "object" && error && "message" in error && typeof error.message === "string"
        ? error.message
        : "An error occurred in stream";
    responseObject.status = "failed";
    responseObject.error = {
      code: "server_error",
      message,
    };
    yield {
      type: "response.failed",
      response: responseObject as Response,
      sequence_number: sequenceNumber++,
    };
    return;
  }
  // Response completed event
  responseObject.status = "completed";
  yield {
    type: "response.completed",
    response: responseObject as Response,
    sequence_number: sequenceNumber++,
  };
}
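/*
 * Inner stream.
 *
 * Authenticates the request, resolves tools (listing MCP tools when needed),
 * converts the Responses API input into chat-completion messages, executes
 * already-approved MCP calls, then runs LLM turns until no new message is
 * added to the payload (or MAX_ITERATIONS is reached).
 */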
async function* innerRunStream(
  req: ValidatedRequest<CreateResponseParams>,
  res: ExpressResponse,
  responseObject: IncompleteResponse
): AsyncGenerator<ResponseStreamEvent> {
  // Retrieve API key from headers
  const apiKey = req.headers.authorization?.split(" ")[1];
  if (!apiKey) {
    res.status(401).json({
      success: false,
      error: "Unauthorized",
    });
    return;
  }
  // List MCP tools from server (if required) + prepare tools for the LLM
  let tools: ChatCompletionInputTool[] | undefined = [];
  const mcpToolsMapping: Record<string, McpServerParams> = {};
  if (req.body.tools) {
    for (const tool of req.body.tools) {
      switch (tool.type) {
        case "function":
          tools?.push({
            type: tool.type,
            function: {
              name: tool.name,
              parameters: tool.parameters,
              description: tool.description,
              strict: tool.strict,
            },
          });
          break;
        case "mcp": {
          let mcpListTools: ResponseOutputItem.McpListTools | undefined;
          // If MCP list tools is already in the input, use it
          if (Array.isArray(req.body.input)) {
            for (const item of req.body.input) {
              if (item.type === "mcp_list_tools" && item.server_label === tool.server_label) {
                mcpListTools = item;
                console.debug(`Using MCP list tools from input for server '${tool.server_label}'`);
                break;
              }
            }
          }
          // Otherwise, list tools from MCP server
          if (!mcpListTools) {
            for await (const event of listMcpToolsStream(tool, responseObject)) {
              yield event;
            }
            mcpListTools = responseObject.output.at(-1) as ResponseOutputItem.McpListTools;
          }
          // Only allowed tools are forwarded to the LLM
          const allowedTools = tool.allowed_tools
            ? Array.isArray(tool.allowed_tools)
              ? tool.allowed_tools
              : tool.allowed_tools.tool_names
            : [];
          if (mcpListTools?.tools) {
            for (const mcpTool of mcpListTools.tools) {
              if (allowedTools.length === 0 || allowedTools.includes(mcpTool.name)) {
                tools?.push({
                  type: "function" as const,
                  function: {
                    name: mcpTool.name,
                    parameters: mcpTool.input_schema,
                    description: mcpTool.description ?? undefined,
                  },
                });
              }
              mcpToolsMapping[mcpTool.name] = tool;
            }
            break;
          }
        }
      }
    }
  }
  if (tools.length === 0) {
    tools = undefined;
  }
  // Prepare payload for the LLM
  // Resolve model and provider
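  // e.g. "hf-inference@meta-llama/Llama-3.1-8B-Instruct" (illustrative ids) => provider "hf-inference", model "meta-llama/Llama-3.1-8B-Instruct"; a bare model id leaves the provider undefined.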
  const model = req.body.model.includes("@") ? req.body.model.split("@")[1] : req.body.model;
  const provider = req.body.model.includes("@") ? req.body.model.split("@")[0] : undefined;
  // Format input to Chat Completion format
  const messages: ChatCompletionInputMessage[] = req.body.instructions
    ? [{ role: "system", content: req.body.instructions }]
    : [];
  if (Array.isArray(req.body.input)) {
    messages.push(
      ...req.body.input
        .map((item) => {
          switch (item.type) {
            case "function_call":
              return {
                // hacky but best fit for now
                role: "assistant",
                name: `function_call ${item.name} ${item.call_id}`,
                content: item.arguments,
              };
            case "function_call_output":
              return {
                // hacky but best fit for now
                role: "assistant",
                name: `function_call_output ${item.call_id}`,
                content: item.output,
              };
            case "message":
              return {
                role: item.role,
                content:
                  typeof item.content === "string"
                    ? item.content
                    : item.content
                        .map((content) => {
                          switch (content.type) {
                            case "input_image":
                              return {
                                type: "image_url" as ChatCompletionInputMessageChunkType,
                                image_url: {
                                  url: content.image_url,
                                },
                              };
                            case "output_text":
                              return content.text
                                ? {
                                    type: "text" as ChatCompletionInputMessageChunkType,
                                    text: content.text,
                                  }
                                : undefined;
                            case "refusal":
                              return undefined;
                            case "input_text":
                              return {
                                type: "text" as ChatCompletionInputMessageChunkType,
                                text: content.text,
                              };
                          }
                        })
                        .filter((item) => item !== undefined),
              };
            case "mcp_list_tools": {
              // Hacky: will be dropped by filter since tools are passed as separate objects
              return {
                role: "assistant",
                name: "mcp_list_tools",
                content: "",
              };
            }
            case "mcp_call": {
              return {
                role: "assistant",
                name: "mcp_call",
                content: `MCP call (${item.id}). Server: '${item.server_label}'. Tool: '${item.name}'. Arguments: '${item.arguments}'.`,
              };
            }
            case "mcp_approval_request": {
              return {
                role: "assistant",
                name: "mcp_approval_request",
                content: `MCP approval request (${item.id}). Server: '${item.server_label}'. Tool: '${item.name}'. Arguments: '${item.arguments}'.`,
              };
            }
            case "mcp_approval_response": {
              return {
                role: "assistant",
                name: "mcp_approval_response",
                content: `MCP approval response (${item.id}). Approved: ${item.approve}. Reason: ${item.reason}.`,
              };
            }
          }
        })
        .filter((message) => message.content?.length !== 0)
    );
  } else {
    messages.push({ role: "user", content: req.body.input });
  }
  // Prepare payload for the LLM
  const payload: ChatCompletionInput = {
    // main params
    model,
    provider,
    messages,
    stream: req.body.stream,
    // options
    max_tokens: req.body.max_output_tokens === null ? undefined : req.body.max_output_tokens,
    response_format: req.body.text?.format
      ? {
          type: req.body.text.format.type,
          json_schema:
            req.body.text.format.type === "json_schema"
              ? {
                  description: req.body.text.format.description,
                  name: req.body.text.format.name,
                  schema: req.body.text.format.schema,
                  strict: req.body.text.format.strict,
                }
              : undefined,
        }
      : undefined,
    temperature: req.body.temperature,
    tool_choice:
      typeof req.body.tool_choice === "string"
        ? req.body.tool_choice
        : req.body.tool_choice
          ? {
              type: "function",
              function: {
                name: req.body.tool_choice.name,
              },
            }
          : undefined,
    tools,
    top_p: req.body.top_p,
  };
  // If the input contains approved MCP approval responses, execute the corresponding tool calls first (before any LLM call).
  if (Array.isArray(req.body.input)) {
    for (const item of req.body.input) {
      if (item.type === "mcp_approval_response" && item.approve) {
        const approvalRequest = req.body.input.find(
          (i) => i.type === "mcp_approval_request" && i.id === item.approval_request_id
        ) as McpApprovalRequestParams | undefined;
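        // Approval request ids look like "mcpr_<suffix>" (see generateUniqueId("mcpr")); the matching call id is "mcp_<suffix>".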
        const mcpCallId = "mcp_" + item.approval_request_id.split("_")[1];
        const mcpCall = req.body.input.find((i) => i.type === "mcp_call" && i.id === mcpCallId);
        if (mcpCall) {
          // MCP call for that approval request has already been made, so we can skip it
          continue;
        }
        for await (const event of callApprovedMCPToolStream(
          item.approval_request_id,
          mcpCallId,
          approvalRequest,
          mcpToolsMapping,
          responseObject,
          payload
        )) {
          yield event;
        }
      }
    }
  }
  // Call the LLM until no new message is added to the payload.
  // New messages can be added if the LLM calls an MCP tool that is automatically run.
  // A maximum number of iterations is set to avoid infinite loops.
  let previousMessageCount: number;
  let currentMessageCount = payload.messages.length;
  const MAX_ITERATIONS = 5; // hard-coded
  let iterations = 0;
  do {
    previousMessageCount = currentMessageCount;
    for await (const event of handleOneTurnStream(apiKey, payload, responseObject, mcpToolsMapping)) {
      yield event;
    }
    currentMessageCount = payload.messages.length;
    iterations++;
  } while (currentMessageCount > previousMessageCount && iterations < MAX_ITERATIONS);
}
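/*
 * Connect to an MCP server, list its tools, and stream the corresponding
 * 'mcp_list_tools' output item events. The item is appended to
 * responseObject.output so callers can read the tool list back from there.
 */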
async function* listMcpToolsStream(
  tool: McpServerParams,
  responseObject: IncompleteResponse
): AsyncGenerator<ResponseStreamEvent> {
  const outputObject: ResponseOutputItem.McpListTools = {
    id: generateUniqueId("mcpl"),
    type: "mcp_list_tools",
    server_label: tool.server_label,
    tools: [],
  };
  responseObject.output.push(outputObject);
  yield {
    type: "response.output_item.added",
    output_index: responseObject.output.length - 1,
    item: outputObject,
    sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
  };
  yield {
    type: "response.mcp_list_tools.in_progress",
    sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
  };
  try {
    const mcp = await connectMcpServer(tool);
    const mcpTools = await mcp.listTools();
    yield {
      type: "response.mcp_list_tools.completed",
      sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
    };
    outputObject.tools = mcpTools.tools.map((mcpTool) => ({
      input_schema: mcpTool.inputSchema,
      name: mcpTool.name,
      annotations: mcpTool.annotations,
      description: mcpTool.description,
    }));
    yield {
      type: "response.output_item.done",
      output_index: responseObject.output.length - 1,
      item: outputObject,
      sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
    };
  } catch (error) {
    const errorMessage = `Failed to list tools from MCP server '${tool.server_label}': ${error instanceof Error ? error.message : "Unknown error"}`;
    console.error(errorMessage);
    yield {
      type: "response.mcp_list_tools.failed",
      sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
    };
    throw new Error(errorMessage);
  }
}
/*
 * Run a single LLM turn and stream the response.
 *
 * Chat-completion chunks are translated into Responses API events. MCP tool
 * calls that do not require approval are executed inline, and their results
 * are appended to the payload for the next turn.
 */
async function* handleOneTurnStream(
  apiKey: string | undefined,
  payload: ChatCompletionInput,
  responseObject: IncompleteResponse,
  mcpToolsMapping: Record<string, McpServerParams>
): AsyncGenerator<ResponseStreamEvent> {
  const stream = new InferenceClient(apiKey).chatCompletionStream(payload);
  let previousInputTokens = responseObject.usage?.input_tokens ?? 0;
  let previousOutputTokens = responseObject.usage?.output_tokens ?? 0;
  let previousTotalTokens = responseObject.usage?.total_tokens ?? 0;
  for await (const chunk of stream) {
    if (chunk.usage) {
      // Overwrite usage with the latest chunk's usage
      responseObject.usage = {
        input_tokens: previousInputTokens + chunk.usage.prompt_tokens,
        input_tokens_details: { cached_tokens: 0 },
        output_tokens: previousOutputTokens + chunk.usage.completion_tokens,
        output_tokens_details: { reasoning_tokens: 0 },
        total_tokens: previousTotalTokens + chunk.usage.total_tokens,
      };
    }
    const delta = chunk.choices[0].delta;
    if (delta.content) {
      let currentOutputItem = responseObject.output.at(-1);
      // If start of a new message, create it
      if (currentOutputItem?.type !== "message" || currentOutputItem?.status !== "in_progress") {
        const outputObject: ResponseOutputMessage = {
          id: generateUniqueId("msg"),
          type: "message",
          role: "assistant",
          status: "in_progress",
          content: [],
        };
        responseObject.output.push(outputObject);
        // Response output item added event
        yield {
          type: "response.output_item.added",
          output_index: responseObject.output.length - 1, // index of the message just pushed
          item: outputObject,
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
      }
      // If start of a new content part, create it
      currentOutputItem = responseObject.output.at(-1) as ResponseOutputMessage;
      if (currentOutputItem.content.length === 0) {
        // Response content part added event
        const contentPart: ResponseContentPartAddedEvent["part"] = {
          type: "output_text",
          text: "",
          annotations: [],
        };
        currentOutputItem.content.push(contentPart);
        yield {
          type: "response.content_part.added",
          item_id: currentOutputItem.id,
          output_index: responseObject.output.length - 1,
          content_index: currentOutputItem.content.length - 1,
          part: contentPart,
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
      }
      const contentPart = currentOutputItem.content.at(-1);
      if (!contentPart || contentPart.type !== "output_text") {
        throw new StreamingError(
          `Not implemented: only output_text is supported in response.output[].content[].type. Got ${contentPart?.type}`
        );
      }
      // Add text delta
      contentPart.text += delta.content;
      yield {
        type: "response.output_text.delta",
        item_id: currentOutputItem.id,
        output_index: responseObject.output.length - 1,
        content_index: currentOutputItem.content.length - 1,
        delta: delta.content,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
    } else if (delta.tool_calls && delta.tool_calls.length > 0) {
      if (delta.tool_calls.length > 1) {
        console.warn("Multiple tool calls are not supported. Only the first one will be processed.");
      }
      let currentOutputItem = responseObject.output.at(-1);
      if (delta.tool_calls[0].function.name) {
        const functionName = delta.tool_calls[0].function.name;
        // Tool call with a name => new tool call
        let newOutputObject:
          | ResponseOutputItem.McpCall
          | ResponseFunctionToolCall
          | ResponseOutputItem.McpApprovalRequest;
        if (functionName in mcpToolsMapping) {
          if (requiresApproval(functionName, mcpToolsMapping)) {
            newOutputObject = {
              id: generateUniqueId("mcpr"),
              type: "mcp_approval_request",
              name: functionName,
              server_label: mcpToolsMapping[functionName].server_label,
              arguments: "",
            };
          } else {
            newOutputObject = {
              type: "mcp_call",
              id: generateUniqueId("mcp"),
              name: functionName,
              server_label: mcpToolsMapping[functionName].server_label,
              arguments: "",
            };
          }
        } else {
          newOutputObject = {
            type: "function_call",
            id: generateUniqueId("fc"),
            call_id: delta.tool_calls[0].id,
            name: functionName,
            arguments: "",
          };
        }
        // Response output item added event
        responseObject.output.push(newOutputObject);
        yield {
          type: "response.output_item.added",
          output_index: responseObject.output.length - 1,
          item: newOutputObject,
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
        if (newOutputObject.type === "mcp_call") {
          yield {
            type: "response.mcp_call.in_progress",
            sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
            item_id: newOutputObject.id,
            output_index: responseObject.output.length - 1,
          };
        }
      }
      if (delta.tool_calls[0].function.arguments) {
        // Current item is necessarily a tool call
        currentOutputItem = responseObject.output.at(-1) as
          | ResponseOutputItem.McpCall
          | ResponseFunctionToolCall
          | ResponseOutputItem.McpApprovalRequest;
        currentOutputItem.arguments += delta.tool_calls[0].function.arguments;
        if (currentOutputItem.type === "mcp_call" || currentOutputItem.type === "function_call") {
          yield {
            type:
              currentOutputItem.type === "mcp_call"
                ? ("response.mcp_call_arguments.delta" as "response.mcp_call.arguments_delta") // bug workaround (see https://github.com/openai/openai-node/issues/1562)
                : "response.function_call_arguments.delta",
            item_id: currentOutputItem.id as string,
            output_index: responseObject.output.length - 1,
            delta: delta.tool_calls[0].function.arguments,
            sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
          };
        }
      }
    }
  }
  const lastOutputItem = responseObject.output.at(-1);
  if (lastOutputItem) {
    if (lastOutputItem.type === "message") {
      const contentPart = lastOutputItem.content.at(-1);
      if (contentPart?.type === "output_text") {
        yield {
          type: "response.output_text.done",
          item_id: lastOutputItem.id,
          output_index: responseObject.output.length - 1,
          content_index: lastOutputItem.content.length - 1,
          text: contentPart.text,
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
        yield {
          type: "response.content_part.done",
          item_id: lastOutputItem.id,
          output_index: responseObject.output.length - 1,
          content_index: lastOutputItem.content.length - 1,
          part: contentPart,
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
      } else {
        throw new StreamingError("Not implemented: only output_text is supported in streaming mode.");
      }
      // Response output item done event
      lastOutputItem.status = "completed";
      yield {
        type: "response.output_item.done",
        output_index: responseObject.output.length - 1,
        item: lastOutputItem,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
    } else if (lastOutputItem.type === "function_call") {
      yield {
        type: "response.function_call_arguments.done",
        item_id: lastOutputItem.id as string,
        output_index: responseObject.output.length - 1,
        arguments: lastOutputItem.arguments,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
      lastOutputItem.status = "completed";
      yield {
        type: "response.output_item.done",
        output_index: responseObject.output.length - 1,
        item: lastOutputItem,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
    } else if (lastOutputItem.type === "mcp_call") {
      yield {
        type: "response.mcp_call_arguments.done" as "response.mcp_call.arguments_done", // bug workaround (see https://github.com/openai/openai-node/issues/1562)
        item_id: lastOutputItem.id as string,
        output_index: responseObject.output.length - 1,
        arguments: lastOutputItem.arguments,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
      // Call MCP tool
      const toolParams = mcpToolsMapping[lastOutputItem.name];
      const toolResult = await callMcpTool(toolParams, lastOutputItem.name, lastOutputItem.arguments);
      if (toolResult.error) {
        lastOutputItem.error = toolResult.error;
        yield {
          type: "response.mcp_call.failed",
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
      } else {
        lastOutputItem.output = toolResult.output;
        yield {
          type: "response.mcp_call.completed",
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
      }
      yield {
        type: "response.output_item.done",
        output_index: responseObject.output.length - 1,
        item: lastOutputItem,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
      // Updating the payload for next LLM call
      payload.messages.push(
        {
          role: "assistant",
          tool_calls: [
            {
              id: lastOutputItem.id,
              type: "function",
              function: {
                name: lastOutputItem.name,
                arguments: lastOutputItem.arguments,
                // Hacky: type is not correct in inference.js. Will fix it but in the meantime we need to cast it.
                // TODO: fix it in the inference.js package. Should be "arguments" and not "parameters".
              } as unknown as ChatCompletionInputFunctionDefinition,
            },
          ],
        },
        {
          role: "tool",
          tool_call_id: lastOutputItem.id,
          content: lastOutputItem.output
            ? lastOutputItem.output
            : lastOutputItem.error
              ? `Error: ${lastOutputItem.error}`
              : "",
        }
      );
    } else if (lastOutputItem.type === "mcp_approval_request") {
      yield {
        type: "response.output_item.done",
        output_index: responseObject.output.length - 1,
        item: lastOutputItem,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
    } else {
      throw new StreamingError(
        `Not implemented: expected message, function_call, mcp_call, or mcp_approval_request, got ${lastOutputItem.type}`
      );
    }
  }
}
/*
 * Perform an approved MCP tool call and stream the response.
 */
async function* callApprovedMCPToolStream(
  approval_request_id: string,
  mcpCallId: string,
  approvalRequest: McpApprovalRequestParams | undefined,
  mcpToolsMapping: Record<string, McpServerParams>,
  responseObject: IncompleteResponse,
  payload: ChatCompletionInput
): AsyncGenerator<ResponseStreamEvent> {
  if (!approvalRequest) {
    throw new Error(`MCP approval request '${approval_request_id}' not found`);
  }
  const outputObject: ResponseOutputItem.McpCall = {
    type: "mcp_call",
    id: mcpCallId,
    name: approvalRequest.name,
    server_label: approvalRequest.server_label,
    arguments: approvalRequest.arguments,
  };
  responseObject.output.push(outputObject);
  // Response output item added event
  yield {
    type: "response.output_item.added",
    output_index: responseObject.output.length - 1,
    item: outputObject,
    sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
  };
  yield {
    type: "response.mcp_call.in_progress",
    item_id: outputObject.id,
    output_index: responseObject.output.length - 1,
    sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
  };
  const toolParams = mcpToolsMapping[approvalRequest.name];
  const toolResult = await callMcpTool(toolParams, approvalRequest.name, approvalRequest.arguments);
  if (toolResult.error) {
    outputObject.error = toolResult.error;
    yield {
      type: "response.mcp_call.failed",
      sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
    };
  } else {
    outputObject.output = toolResult.output;
    yield {
      type: "response.mcp_call.completed",
      sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
    };
  }
  yield {
    type: "response.output_item.done",
    output_index: responseObject.output.length - 1,
    item: outputObject,
    sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
  };
  // Updating the payload for next LLM call
  payload.messages.push(
    {
      role: "assistant",
      tool_calls: [
        {
          id: outputObject.id,
          type: "function",
          function: {
            name: outputObject.name,
            arguments: outputObject.arguments,
            // Hacky: type is not correct in inference.js. Will fix it but in the meantime we need to cast it.
            // TODO: fix it in the inference.js package. Should be "arguments" and not "parameters".
          } as unknown as ChatCompletionInputFunctionDefinition,
        },
      ],
    },
    {
      role: "tool",
      tool_call_id: outputObject.id,
      content: outputObject.output ? outputObject.output : outputObject.error ? `Error: ${outputObject.error}` : "",
    }
  );
}
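/*
 * Whether calling this MCP tool requires an explicit user approval,
 * based on the server's 'require_approval' setting ("always", "never",
 * or per-tool allow/deny lists).
 */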
function requiresApproval(toolName: string, mcpToolsMapping: Record<string, McpServerParams>): boolean {
  const toolParams = mcpToolsMapping[toolName];
  return toolParams.require_approval === "always"
    ? true
    : toolParams.require_approval === "never"
      ? false
      : toolParams.require_approval.always?.tool_names?.includes(toolName)
        ? true
        : toolParams.require_approval.never?.tool_names?.includes(toolName)
          ? false
          : true; // behavior is undefined in specs, let's default to true
}