import { type Response as ExpressResponse } from "express";
import { type ValidatedRequest } from "../middleware/validation.js";
import type { CreateResponseParams, McpServerParams, McpApprovalRequestParams } from "../schemas.js";
import { generateUniqueId } from "../lib/generateUniqueId.js";
import { InferenceClient } from "@huggingface/inference";
import type {
  ChatCompletionInputMessage,
  ChatCompletionInputMessageChunkType,
  ChatCompletionInput,
} from "@huggingface/tasks";
import type {
  Response,
  ResponseStreamEvent,
  ResponseContentPartAddedEvent,
  ResponseOutputMessage,
  ResponseFunctionToolCall,
  ResponseOutputItem,
} from "openai/resources/responses/responses";
import type {
  ChatCompletionInputFunctionDefinition,
  ChatCompletionInputTool,
} from "@huggingface/tasks/dist/commonjs/tasks/chat-completion/inference.js";
import { callMcpTool, connectMcpServer } from "../mcp.js";

class StreamingError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "StreamingError";
  }
}

type IncompleteResponse = Omit<Response, "incomplete_details" | "output_text" | "parallel_tool_calls">;

const SEQUENCE_NUMBER_PLACEHOLDER = -1;

export const postCreateResponse = async (
  req: ValidatedRequest<CreateResponseParams>,
  res: ExpressResponse
): Promise<void> => {
  // To avoid duplicated code, we run all requests as a stream.
  const events = runCreateResponseStream(req, res);

  // Then we return in the correct format depending on the user's 'stream' flag.
  if (req.body.stream) {
    res.setHeader("Content-Type", "text/event-stream");
    res.setHeader("Connection", "keep-alive");
    console.debug("Stream request");
    for await (const event of events) {
      console.debug(`Event #${event.sequence_number}: ${event.type}`);
      res.write(`data: ${JSON.stringify(event)}\n\n`);
    }
    res.end();
  } else {
    console.debug("Non-stream request");
    for await (const event of events) {
      if (event.type === "response.completed" || event.type === "response.failed") {
        console.debug(event.type);
        res.json(event.response);
      }
    }
  }
};

/*
 * Top-level stream.
 *
 * Handles the response lifecycle and executes the inner logic (MCP list tools, MCP tool calls, LLM call, etc.).
 * Handles sequence numbers by overwriting them in the events.
 */
async function* runCreateResponseStream(
  req: ValidatedRequest<CreateResponseParams>,
  res: ExpressResponse
): AsyncGenerator<ResponseStreamEvent> {
  let sequenceNumber = 0;

  // Prepare the response object that will be iteratively populated
  const responseObject: IncompleteResponse = {
    created_at: Math.floor(new Date().getTime() / 1000),
    error: null,
    id: generateUniqueId("resp"),
    instructions: req.body.instructions,
    max_output_tokens: req.body.max_output_tokens,
    metadata: req.body.metadata,
    model: req.body.model,
    object: "response",
    output: [],
    // parallel_tool_calls: req.body.parallel_tool_calls,
    status: "in_progress",
    text: req.body.text,
    tool_choice: req.body.tool_choice ?? "auto",
    tools: req.body.tools ?? [],
    temperature: req.body.temperature,
    top_p: req.body.top_p,
    usage: {
      input_tokens: 0,
      input_tokens_details: { cached_tokens: 0 },
      output_tokens: 0,
      output_tokens_details: { reasoning_tokens: 0 },
      total_tokens: 0,
    },
  };

  // Response created event
  yield {
    type: "response.created",
    response: responseObject as Response,
    sequence_number: sequenceNumber++,
  };

  // Response in progress event
  yield {
    type: "response.in_progress",
    response: responseObject as Response,
    sequence_number: sequenceNumber++,
  };

  // Any events (LLM call, MCP call, list tools, etc.)
  try {
    for await (const event of innerRunStream(req, res, responseObject)) {
      yield { ...event, sequence_number: sequenceNumber++ };
    }
  } catch (error) {
    // Error event => stop
    console.error("Error in stream:", error);
    const message =
      typeof error === "object" && error && "message" in error && typeof error.message === "string"
        ? error.message
        : "An error occurred in stream";
    responseObject.status = "failed";
    responseObject.error = {
      code: "server_error",
      message,
    };
    yield {
      type: "response.failed",
      response: responseObject as Response,
      sequence_number: sequenceNumber++,
    };
    return;
  }

  // Response completed event
  responseObject.status = "completed";
  yield {
    type: "response.completed",
    response: responseObject as Response,
    sequence_number: sequenceNumber++,
  };
}

async function* innerRunStream(
  req: ValidatedRequest<CreateResponseParams>,
  res: ExpressResponse,
  responseObject: IncompleteResponse
): AsyncGenerator<ResponseStreamEvent> {
  // Retrieve API key from headers
  const apiKey = req.headers.authorization?.split(" ")[1];
  if (!apiKey) {
    res.status(401).json({
      success: false,
      error: "Unauthorized",
    });
    return;
  }

  // List MCP tools from server (if required) + prepare tools for the LLM
  let tools: ChatCompletionInputTool[] | undefined = [];
  const mcpToolsMapping: Record<string, McpServerParams> = {};
  if (req.body.tools) {
    for (const tool of req.body.tools) {
      switch (tool.type) {
        case "function":
          tools?.push({
            type: tool.type,
            function: {
              name: tool.name,
              parameters: tool.parameters,
              description: tool.description,
              strict: tool.strict,
            },
          });
          break;
        case "mcp": {
          let mcpListTools: ResponseOutputItem.McpListTools | undefined;
          // If MCP list tools is already in the input, use it
          if (Array.isArray(req.body.input)) {
            for (const item of req.body.input) {
              if (item.type === "mcp_list_tools" && item.server_label === tool.server_label) {
                mcpListTools = item;
                console.debug(`Using MCP list tools from input for server '${tool.server_label}'`);
                break;
              }
            }
          }
          // Otherwise, list tools from the MCP server
          if (!mcpListTools) {
            for await (const event of listMcpToolsStream(tool, responseObject)) {
              yield event;
            }
            mcpListTools = responseObject.output.at(-1) as ResponseOutputItem.McpListTools;
          }
          // Only allowed tools are forwarded to the LLM
          const allowedTools = tool.allowed_tools
            ? Array.isArray(tool.allowed_tools)
              ? tool.allowed_tools
              : tool.allowed_tools.tool_names
            : [];
          if (mcpListTools?.tools) {
            for (const mcpTool of mcpListTools.tools) {
              if (allowedTools.length === 0 || allowedTools.includes(mcpTool.name)) {
                tools?.push({
                  type: "function" as const,
                  function: {
                    name: mcpTool.name,
                    parameters: mcpTool.input_schema,
                    description: mcpTool.description ?? undefined,
                  },
                });
              }
              mcpToolsMapping[mcpTool.name] = tool;
            }
          }
          break;
        }
      }
    }
  }
  if (tools.length === 0) {
    tools = undefined;
  }

  // Resolve model and provider
  const model = req.body.model.includes("@") ? req.body.model.split("@")[1] : req.body.model;
  const provider = req.body.model.includes("@") ? req.body.model.split("@")[0] : undefined;

  // Format input to Chat Completion format
  const messages: ChatCompletionInputMessage[] = req.body.instructions
[{ role: "system", content: req.body.instructions }] : []; if (Array.isArray(req.body.input)) { messages.push( ...req.body.input .map((item) => { switch (item.type) { case "function_call": return { // hacky but best fit for now role: "assistant", name: `function_call ${item.name} ${item.call_id}`, content: item.arguments, }; case "function_call_output": return { // hacky but best fit for now role: "assistant", name: `function_call_output ${item.call_id}`, content: item.output, }; case "message": return { role: item.role, content: typeof item.content === "string" ? item.content : item.content .map((content) => { switch (content.type) { case "input_image": return { type: "image_url" as ChatCompletionInputMessageChunkType, image_url: { url: content.image_url, }, }; case "output_text": return content.text ? { type: "text" as ChatCompletionInputMessageChunkType, text: content.text, } : undefined; case "refusal": return undefined; case "input_text": return { type: "text" as ChatCompletionInputMessageChunkType, text: content.text, }; } }) .filter((item) => item !== undefined), }; case "mcp_list_tools": { // Hacky: will be dropped by filter since tools are passed as separate objects return { role: "assistant", name: "mcp_list_tools", content: "", }; } case "mcp_call": { return { role: "assistant", name: "mcp_call", content: `MCP call (${item.id}). Server: '${item.server_label}'. Tool: '${item.name}'. Arguments: '${item.arguments}'.`, }; } case "mcp_approval_request": { return { role: "assistant", name: "mcp_approval_request", content: `MCP approval request (${item.id}). Server: '${item.server_label}'. Tool: '${item.name}'. Arguments: '${item.arguments}'.`, }; } case "mcp_approval_response": { return { role: "assistant", name: "mcp_approval_response", content: `MCP approval response (${item.id}). Approved: ${item.approve}. Reason: ${item.reason}.`, }; } } }) .filter((message) => message.content?.length !== 0) ); } else { messages.push({ role: "user", content: req.body.input }); } // Prepare payload for the LLM const payload: ChatCompletionInput = { // main params model, provider, messages, stream: req.body.stream, // options max_tokens: req.body.max_output_tokens === null ? undefined : req.body.max_output_tokens, response_format: req.body.text?.format ? { type: req.body.text.format.type, json_schema: req.body.text.format.type === "json_schema" ? { description: req.body.text.format.description, name: req.body.text.format.name, schema: req.body.text.format.schema, strict: req.body.text.format.strict, } : undefined, } : undefined, temperature: req.body.temperature, tool_choice: typeof req.body.tool_choice === "string" ? req.body.tool_choice : req.body.tool_choice ? 
{ type: "function", function: { name: req.body.tool_choice.name, }, } : undefined, tools, top_p: req.body.top_p, }; // If MCP approval requests => execute them and return (no LLM call) if (Array.isArray(req.body.input)) { for (const item of req.body.input) { if (item.type === "mcp_approval_response" && item.approve) { const approvalRequest = req.body.input.find( (i) => i.type === "mcp_approval_request" && i.id === item.approval_request_id ) as McpApprovalRequestParams | undefined; const mcpCallId = "mcp_" + item.approval_request_id.split("_")[1]; const mcpCall = req.body.input.find((i) => i.type === "mcp_call" && i.id === mcpCallId); if (mcpCall) { // MCP call for that approval request has already been made, so we can skip it continue; } for await (const event of callApprovedMCPToolStream( item.approval_request_id, mcpCallId, approvalRequest, mcpToolsMapping, responseObject, payload )) { yield event; } } } } // Call the LLM until no new message is added to the payload. // New messages can be added if the LLM calls an MCP tool that is automatically run. // A maximum number of iterations is set to avoid infinite loops. let previousMessageCount: number; let currentMessageCount = payload.messages.length; const MAX_ITERATIONS = 5; // hard-coded let iterations = 0; do { previousMessageCount = currentMessageCount; for await (const event of handleOneTurnStream(apiKey, payload, responseObject, mcpToolsMapping)) { yield event; } currentMessageCount = payload.messages.length; iterations++; } while (currentMessageCount > previousMessageCount && iterations < MAX_ITERATIONS); } async function* listMcpToolsStream( tool: McpServerParams, responseObject: IncompleteResponse ): AsyncGenerator { const outputObject: ResponseOutputItem.McpListTools = { id: generateUniqueId("mcpl"), type: "mcp_list_tools", server_label: tool.server_label, tools: [], }; responseObject.output.push(outputObject); yield { type: "response.output_item.added", output_index: responseObject.output.length - 1, item: outputObject, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; yield { type: "response.mcp_list_tools.in_progress", sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; try { const mcp = await connectMcpServer(tool); const mcpTools = await mcp.listTools(); yield { type: "response.mcp_list_tools.completed", sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; outputObject.tools = mcpTools.tools.map((mcpTool) => ({ input_schema: mcpTool.inputSchema, name: mcpTool.name, annotations: mcpTool.annotations, description: mcpTool.description, })); yield { type: "response.output_item.done", output_index: responseObject.output.length - 1, item: outputObject, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; } catch (error) { const errorMessage = `Failed to list tools from MCP server '${tool.server_label}': ${error instanceof Error ? error.message : "Unknown error"}`; console.error(errorMessage); yield { type: "response.mcp_list_tools.failed", sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; throw new Error(errorMessage); } } /* * Call LLM and stream the response. */ async function* handleOneTurnStream( apiKey: string | undefined, payload: ChatCompletionInput, responseObject: IncompleteResponse, mcpToolsMapping: Record ): AsyncGenerator { const stream = new InferenceClient(apiKey).chatCompletionStream(payload); let previousInputTokens = responseObject.usage?.input_tokens ?? 0; let previousOutputTokens = responseObject.usage?.output_tokens ?? 0; let previousTotalTokens = responseObject.usage?.total_tokens ?? 
    0;

  for await (const chunk of stream) {
    if (chunk.usage) {
      // Overwrite usage with the latest chunk's usage
      responseObject.usage = {
        input_tokens: previousInputTokens + chunk.usage.prompt_tokens,
        input_tokens_details: { cached_tokens: 0 },
        output_tokens: previousOutputTokens + chunk.usage.completion_tokens,
        output_tokens_details: { reasoning_tokens: 0 },
        total_tokens: previousTotalTokens + chunk.usage.total_tokens,
      };
    }
    const delta = chunk.choices[0].delta;
    if (delta.content) {
      let currentOutputItem = responseObject.output.at(-1);

      // If start of a new message, create it
      if (currentOutputItem?.type !== "message" || currentOutputItem?.status !== "in_progress") {
        const outputObject: ResponseOutputMessage = {
          id: generateUniqueId("msg"),
          type: "message",
          role: "assistant",
          status: "in_progress",
          content: [],
        };
        responseObject.output.push(outputObject);

        // Response output item added event
        yield {
          type: "response.output_item.added",
          output_index: responseObject.output.length - 1,
          item: outputObject,
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
      }

      // If start of a new content part, create it
      currentOutputItem = responseObject.output.at(-1) as ResponseOutputMessage;
      if (currentOutputItem.content.length === 0) {
        // Response content part added event
        const contentPart: ResponseContentPartAddedEvent["part"] = {
          type: "output_text",
          text: "",
          annotations: [],
        };
        currentOutputItem.content.push(contentPart);
        yield {
          type: "response.content_part.added",
          item_id: currentOutputItem.id,
          output_index: responseObject.output.length - 1,
          content_index: currentOutputItem.content.length - 1,
          part: contentPart,
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };
      }

      const contentPart = currentOutputItem.content.at(-1);
      if (!contentPart || contentPart.type !== "output_text") {
        throw new StreamingError(
          `Not implemented: only output_text is supported in response.output[].content[].type. Got ${contentPart?.type}`
        );
      }

      // Add text delta
      contentPart.text += delta.content;
      yield {
        type: "response.output_text.delta",
        item_id: currentOutputItem.id,
        output_index: responseObject.output.length - 1,
        content_index: currentOutputItem.content.length - 1,
        delta: delta.content,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
    } else if (delta.tool_calls && delta.tool_calls.length > 0) {
      if (delta.tool_calls.length > 1) {
        console.log("Multiple tool calls are not supported. Only the first one will be processed.");
      }
      let currentOutputItem = responseObject.output.at(-1);

      if (delta.tool_calls[0].function.name) {
        const functionName = delta.tool_calls[0].function.name;
        // Tool call with a name => new tool call
        let newOutputObject:
          | ResponseOutputItem.McpCall
          | ResponseFunctionToolCall
          | ResponseOutputItem.McpApprovalRequest;
        if (functionName in mcpToolsMapping) {
          if (requiresApproval(functionName, mcpToolsMapping)) {
            newOutputObject = {
              id: generateUniqueId("mcpr"),
              type: "mcp_approval_request",
              name: functionName,
              server_label: mcpToolsMapping[functionName].server_label,
              arguments: "",
            };
          } else {
            newOutputObject = {
              type: "mcp_call",
              id: generateUniqueId("mcp"),
              name: functionName,
              server_label: mcpToolsMapping[functionName].server_label,
              arguments: "",
            };
          }
        } else {
          newOutputObject = {
            type: "function_call",
            id: generateUniqueId("fc"),
            call_id: delta.tool_calls[0].id,
            name: functionName,
            arguments: "",
          };
        }

        // Response output item added event
        responseObject.output.push(newOutputObject);
        yield {
          type: "response.output_item.added",
          output_index: responseObject.output.length - 1,
          item: newOutputObject,
          sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
        };

        if (newOutputObject.type === "mcp_call") {
          yield {
            type: "response.mcp_call.in_progress",
            sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
            item_id: newOutputObject.id,
            output_index: responseObject.output.length - 1,
          };
        }
      }

      if (delta.tool_calls[0].function.arguments) {
        // Current item is necessarily a tool call
        currentOutputItem = responseObject.output.at(-1) as
          | ResponseOutputItem.McpCall
          | ResponseFunctionToolCall
          | ResponseOutputItem.McpApprovalRequest;
        currentOutputItem.arguments += delta.tool_calls[0].function.arguments;

        if (currentOutputItem.type === "mcp_call" || currentOutputItem.type === "function_call") {
          yield {
            type:
              currentOutputItem.type === "mcp_call"
("response.mcp_call_arguments.delta" as "response.mcp_call.arguments_delta") // bug workaround (see https://github.com/openai/openai-node/issues/1562) : "response.function_call_arguments.delta", item_id: currentOutputItem.id as string, output_index: responseObject.output.length - 1, delta: delta.tool_calls[0].function.arguments, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; } } } } const lastOutputItem = responseObject.output.at(-1); if (lastOutputItem) { if (lastOutputItem?.type === "message") { const contentPart = lastOutputItem.content.at(-1); if (contentPart?.type === "output_text") { yield { type: "response.output_text.done", item_id: lastOutputItem.id, output_index: responseObject.output.length - 1, content_index: lastOutputItem.content.length - 1, text: contentPart.text, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; yield { type: "response.content_part.done", item_id: lastOutputItem.id, output_index: responseObject.output.length - 1, content_index: lastOutputItem.content.length - 1, part: contentPart, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; } else { throw new StreamingError("Not implemented: only output_text is supported in streaming mode."); } // Response output item done event lastOutputItem.status = "completed"; yield { type: "response.output_item.done", output_index: responseObject.output.length - 1, item: lastOutputItem, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; } else if (lastOutputItem?.type === "function_call") { yield { type: "response.function_call_arguments.done", item_id: lastOutputItem.id as string, output_index: responseObject.output.length - 1, arguments: lastOutputItem.arguments, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; lastOutputItem.status = "completed"; yield { type: "response.output_item.done", output_index: responseObject.output.length - 1, item: lastOutputItem, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; } else if (lastOutputItem?.type === "mcp_call") { yield { type: "response.mcp_call_arguments.done" as "response.mcp_call.arguments_done", // bug workaround (see https://github.com/openai/openai-node/issues/1562) item_id: lastOutputItem.id as string, output_index: responseObject.output.length - 1, arguments: lastOutputItem.arguments, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; // Call MCP tool const toolParams = mcpToolsMapping[lastOutputItem.name]; const toolResult = await callMcpTool(toolParams, lastOutputItem.name, lastOutputItem.arguments); if (toolResult.error) { lastOutputItem.error = toolResult.error; yield { type: "response.mcp_call.failed", sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; } else { lastOutputItem.output = toolResult.output; yield { type: "response.mcp_call.completed", sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; } yield { type: "response.output_item.done", output_index: responseObject.output.length - 1, item: lastOutputItem, sequence_number: SEQUENCE_NUMBER_PLACEHOLDER, }; // Updating the payload for next LLM call payload.messages.push( { role: "assistant", tool_calls: [ { id: lastOutputItem.id, type: "function", function: { name: lastOutputItem.name, arguments: lastOutputItem.arguments, // Hacky: type is not correct in inference.js. Will fix it but in the meantime we need to cast it. // TODO: fix it in the inference.js package. Should be "arguments" and not "parameters". } as unknown as ChatCompletionInputFunctionDefinition, }, ], }, { role: "tool", tool_call_id: lastOutputItem.id, content: lastOutputItem.output ? lastOutputItem.output : lastOutputItem.error ? 
              ? `Error: ${lastOutputItem.error}`
              : "",
        }
      );
    } else if (lastOutputItem?.type === "mcp_approval_request") {
      yield {
        type: "response.output_item.done",
        output_index: responseObject.output.length - 1,
        item: lastOutputItem,
        sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
      };
    } else {
      throw new StreamingError(
        `Not implemented: expected message, function_call, or mcp_call, got ${lastOutputItem?.type}`
      );
    }
  }
}

/*
 * Perform an approved MCP tool call and stream the response.
 */
async function* callApprovedMCPToolStream(
  approval_request_id: string,
  mcpCallId: string,
  approvalRequest: McpApprovalRequestParams | undefined,
  mcpToolsMapping: Record<string, McpServerParams>,
  responseObject: IncompleteResponse,
  payload: ChatCompletionInput
): AsyncGenerator<ResponseStreamEvent> {
  if (!approvalRequest) {
    throw new Error(`MCP approval request '${approval_request_id}' not found`);
  }

  const outputObject: ResponseOutputItem.McpCall = {
    type: "mcp_call",
    id: mcpCallId,
    name: approvalRequest.name,
    server_label: approvalRequest.server_label,
    arguments: approvalRequest.arguments,
  };
  responseObject.output.push(outputObject);

  // Response output item added event
  yield {
    type: "response.output_item.added",
    output_index: responseObject.output.length - 1,
    item: outputObject,
    sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
  };
  yield {
    type: "response.mcp_call.in_progress",
    item_id: outputObject.id,
    output_index: responseObject.output.length - 1,
    sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
  };

  const toolParams = mcpToolsMapping[approvalRequest.name];
  const toolResult = await callMcpTool(toolParams, approvalRequest.name, approvalRequest.arguments);
  if (toolResult.error) {
    outputObject.error = toolResult.error;
    yield {
      type: "response.mcp_call.failed",
      sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
    };
  } else {
    outputObject.output = toolResult.output;
    yield {
      type: "response.mcp_call.completed",
      sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
    };
  }
  yield {
    type: "response.output_item.done",
    output_index: responseObject.output.length - 1,
    item: outputObject,
    sequence_number: SEQUENCE_NUMBER_PLACEHOLDER,
  };

  // Update the payload for the next LLM call
  payload.messages.push(
    {
      role: "assistant",
      tool_calls: [
        {
          id: outputObject.id,
          type: "function",
          function: {
            name: outputObject.name,
            arguments: outputObject.arguments,
            // Hacky: type is not correct in inference.js. Will fix it but in the meantime we need to cast it.
            // TODO: fix it in the inference.js package. Should be "arguments" and not "parameters".
          } as unknown as ChatCompletionInputFunctionDefinition,
        },
      ],
    },
    {
      role: "tool",
      tool_call_id: outputObject.id,
      content: outputObject.output ? outputObject.output : outputObject.error ? `Error: ${outputObject.error}` : "",
    }
  );
}

function requiresApproval(toolName: string, mcpToolsMapping: Record<string, McpServerParams>): boolean {
  const toolParams = mcpToolsMapping[toolName];
  return toolParams.require_approval === "always"
    ? true
    : toolParams.require_approval === "never"
      ? false
      : toolParams.require_approval.always?.tool_names?.includes(toolName)
        ? true
        : toolParams.require_approval.never?.tool_names?.includes(toolName)
          ? false
          : true; // behavior is undefined in the specs, so default to requiring approval
}
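
/*
 * Minimal usage sketch (illustration only; not part of this module). Assuming this
 * handler is mounted at `POST /v1/responses` and the server listens on port 3000 —
 * both assumptions, adjust to your setup — a client can point the official `openai`
 * SDK at it. The model id below is also just an example.
 *
 *   import OpenAI from "openai";
 *
 *   const client = new OpenAI({
 *     baseURL: "http://localhost:3000/v1",
 *     apiKey: process.env.HF_TOKEN, // forwarded as the Bearer token read in innerRunStream
 *   });
 *
 *   const response = await client.responses.create({
 *     model: "meta-llama/Llama-3.3-70B-Instruct",
 *     instructions: "You are a concise assistant.",
 *     input: "Say hello in one sentence.",
 *   });
 *   console.log(response.output_text);
 */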