// responses.js / src/routes/responses.ts
// Uploaded by Wauplin (HF Staff) using huggingface_hub - commit 3d97d52 (verified) - 20.3 kB
import { type Response as ExpressResponse } from "express";
import { type ValidatedRequest } from "../middleware/validation.js";
import type { CreateResponseParams, McpServerParams, McpApprovalRequestParams } from "../schemas.js";
import { generateUniqueId } from "../lib/generateUniqueId.js";
import { InferenceClient } from "@huggingface/inference";
import type {
ChatCompletionInputMessage,
ChatCompletionInputMessageChunkType,
ChatCompletionInput,
} from "@huggingface/tasks";
import type {
Response,
ResponseStreamEvent,
ResponseContentPartAddedEvent,
ResponseOutputMessage,
ResponseFunctionToolCall,
ResponseOutputItem,
} from "openai/resources/responses/responses";
import type {
ChatCompletionInputTool,
ChatCompletionStreamOutputUsage,
} from "@huggingface/tasks/dist/commonjs/tasks/chat-completion/inference.js";
import { callMcpTool, connectMcpServer } from "../mcp.js";
/**
 * Error raised while translating a chat-completion stream into response
 * stream events. The streaming `catch` block in `postCreateResponse` checks
 * for this class and forwards its message to the client as-is.
 */
class StreamingError extends Error {
	// Set as an instance field so that `error.name` identifies this class.
	override name = "StreamingError";

	constructor(message: string) {
		super(message);
	}
}
/**
 * Express handler that creates a response by forwarding the request to a
 * chat-completion backend through `InferenceClient`.
 *
 * Steps, in order:
 *  1. Read the API key from the `Authorization` header (401 if absent).
 *  2. Convert `instructions` / `input` into chat-completion messages.
 *  3. Resolve tools: plain function tools are forwarded as-is; MCP tools are
 *     listed from the request input or from the MCP server.
 *  4. If the input contains an approved `mcp_approval_response`, call the MCP
 *     tool directly and return without calling the LLM.
 *  5. Otherwise call the LLM, either streaming (server-sent events) or not.
 *
 * @param req - Validated request whose body follows `CreateResponseParams`.
 * @param res - Express response; written as JSON or as an event stream.
 * @returns A promise that resolves once the response has been sent.
 */
export const postCreateResponse = async (
	req: ValidatedRequest<CreateResponseParams>,
	res: ExpressResponse
): Promise<void> => {
	// The key is the second space-separated token of the Authorization header.
	const apiKey = req.headers.authorization?.split(" ")[1];

	if (!apiKey) {
		res.status(401).json({
			success: false,
			error: "Unauthorized",
		});
		return;
	}

	const client = new InferenceClient(apiKey);

	// Instructions, when provided, become the leading system message.
	const messages: ChatCompletionInputMessage[] = req.body.instructions
		? [{ role: "system", content: req.body.instructions }]
		: [];

	if (Array.isArray(req.body.input)) {
		messages.push(
			...req.body.input
				.map((item) => {
					switch (item.type) {
						case "function_call":
							return {
								// hacky but best fit for now
								role: "assistant",
								name: `function_call ${item.name} ${item.call_id}`,
								content: item.arguments,
							};
						case "function_call_output":
							return {
								// hacky but best fit for now
								role: "assistant",
								name: `function_call_output ${item.call_id}`,
								content: item.output,
							};
						case "message":
							return {
								role: item.role,
								content:
									typeof item.content === "string"
										? item.content
										: item.content
												.map((content) => {
													switch (content.type) {
														case "input_image":
															return {
																type: "image_url" as ChatCompletionInputMessageChunkType,
																image_url: {
																	url: content.image_url,
																},
															};
														case "output_text":
															// Empty output text is dropped by the filter below.
															return content.text
																? {
																		type: "text" as ChatCompletionInputMessageChunkType,
																		text: content.text,
																	}
																: undefined;
														case "refusal":
															// Refusals are not forwarded to the model.
															return undefined;
														case "input_text":
															return {
																type: "text" as ChatCompletionInputMessageChunkType,
																text: content.text,
															};
													}
												})
												.filter((item) => item !== undefined),
							};
						case "mcp_list_tools": {
							// Hacky: will be dropped by filter
							return {
								role: "assistant",
								name: "mcp_list_tools",
								content: "",
							};
						}
						case "mcp_approval_request": {
							return {
								role: "assistant",
								name: "mcp_approval_request",
								content: `MCP approval request (${item.id}). Server: '${item.server_label}'. Tool: '${item.name}'. Arguments: '${item.arguments}'.`,
							};
						}
						case "mcp_approval_response": {
							return {
								role: "assistant",
								name: "mcp_approval_response",
								content: `MCP approval response (${item.id}). Approved: ${item.approve}. Reason: ${item.reason}.`,
							};
						}
					}
				})
				// Messages with empty content (string or chunk list) are not sent to the model.
				.filter((message) => message.content?.length !== 0)
		);
	} else {
		messages.push({ role: "user", content: req.body.input });
	}

	// Output items produced before the LLM call (MCP tool listings) plus the LLM output.
	const output: ResponseOutputItem[] = [];
	let tools: ChatCompletionInputTool[] | undefined = [];
	// Maps an MCP tool name to the MCP server definition that exposes it.
	const mcpToolsMapping: Record<string, McpServerParams> = {};

	if (req.body.tools) {
		await Promise.all(
			req.body.tools.map(async (tool) => {
				switch (tool.type) {
					case "function":
						tools?.push({
							type: tool.type,
							function: {
								name: tool.name,
								parameters: tool.parameters,
								description: tool.description,
								strict: tool.strict,
							},
						});
						break;
					case "mcp": {
						let mcpListTools: ResponseOutputItem.McpListTools | undefined;

						// If MCP list tools is already in the input, use it
						if (Array.isArray(req.body.input)) {
							for (const item of req.body.input) {
								if (item.type === "mcp_list_tools" && item.server_label === tool.server_label) {
									mcpListTools = item;
									console.debug(`Using MCP list tools from input for server '${tool.server_label}'`);
									break;
								}
							}
						}

						// Otherwise, list tools from MCP server
						if (!mcpListTools) {
							try {
								const mcp = await connectMcpServer(tool);
								console.debug("Listing MCP tools from server");
								const mcpTools = await mcp.listTools();
								console.debug(`Fetched ${mcpTools.tools.length} tools from MCP server '${tool.server_label}'`);

								// All tools are returned in Response object
								mcpListTools = {
									id: generateUniqueId("mcp_list_tools"),
									type: "mcp_list_tools",
									server_label: tool.server_label,
									tools: mcpTools.tools.map((mcpTool) => ({
										input_schema: mcpTool.inputSchema,
										name: mcpTool.name,
										annotations: mcpTool.annotations,
										description: mcpTool.description,
									})),
								};
							} catch (error) {
								// A failing MCP server is reported in the output item rather than failing the request.
								console.error("Error listing tools from MCP server", error);
								mcpListTools = {
									id: generateUniqueId("mcp_list_tools"),
									type: "mcp_list_tools",
									server_label: tool.server_label,
									tools: [],
									error: `Failed to list tools from MCP server '${tool.server_label}': ${error instanceof Error ? error.message : "Unknown error"}`,
								};
							}
							// Only freshly fetched listings are added to the output.
							output.push(mcpListTools);
						}

						// Only allowed tools are forwarded to the LLM
						// (an empty allow-list means every listed tool is allowed).
						const allowedTools = tool.allowed_tools
							? Array.isArray(tool.allowed_tools)
								? tool.allowed_tools
								: tool.allowed_tools.tool_names
							: [];
						if (mcpListTools?.tools) {
							for (const mcpTool of mcpListTools.tools) {
								if (allowedTools.length === 0 || allowedTools.includes(mcpTool.name)) {
									tools?.push({
										type: "function" as const,
										function: {
											name: mcpTool.name,
											parameters: mcpTool.input_schema,
											description: mcpTool.description ?? undefined,
										},
									});
								}
								// Every listed tool is mapped to its server, whether or not it is allowed.
								mcpToolsMapping[mcpTool.name] = tool;
							}
						}
						break;
					}
				}
			})
		);
	}

	// An empty tool list is sent as "no tools" rather than as an empty array.
	if (tools.length === 0) {
		tools = undefined;
	}

	// The model may be given as "<provider>@<model>".
	const model = req.body.model.includes("@") ? req.body.model.split("@")[1] : req.body.model;
	const provider = req.body.model.includes("@") ? req.body.model.split("@")[0] : undefined;

	const payload: ChatCompletionInput = {
		// main params
		model: model,
		provider: provider,
		messages: messages,
		stream: req.body.stream,
		// options
		max_tokens: req.body.max_output_tokens === null ? undefined : req.body.max_output_tokens,
		response_format: req.body.text?.format
			? {
					type: req.body.text.format.type,
					json_schema:
						req.body.text.format.type === "json_schema"
							? {
									description: req.body.text.format.description,
									name: req.body.text.format.name,
									schema: req.body.text.format.schema,
									strict: req.body.text.format.strict,
								}
							: undefined,
				}
			: undefined,
		temperature: req.body.temperature,
		tool_choice:
			typeof req.body.tool_choice === "string"
				? req.body.tool_choice
				: req.body.tool_choice
					? {
							type: "function",
							function: {
								name: req.body.tool_choice.name,
							},
						}
					: undefined,
		tools,
		top_p: req.body.top_p,
	};

	// Response object returned to the client; it is mutated as the request progresses.
	const responseObject: Omit<Response, "incomplete_details" | "output_text" | "parallel_tool_calls"> = {
		created_at: Math.floor(new Date().getTime() / 1000),
		error: null,
		id: generateUniqueId("resp"),
		instructions: req.body.instructions,
		max_output_tokens: req.body.max_output_tokens,
		metadata: req.body.metadata,
		model: req.body.model,
		object: "response",
		output,
		// parallel_tool_calls: req.body.parallel_tool_calls,
		status: "in_progress",
		text: req.body.text,
		tool_choice: req.body.tool_choice ?? "auto",
		tools: req.body.tools ?? [],
		temperature: req.body.temperature,
		top_p: req.body.top_p,
		usage: {
			input_tokens: 0,
			input_tokens_details: { cached_tokens: 0 },
			output_tokens: 0,
			output_tokens_details: { reasoning_tokens: 0 },
			total_tokens: 0,
		},
	};

	// MCP approval requests => do not call LLM at all
	if (Array.isArray(req.body.input)) {
		for (const item of req.body.input) {
			// Note: currently supporting only 1 mcp_approval_response per request
			if (item.type === "mcp_approval_response" && item.approve) {
				const approvalRequest = req.body.input.find(
					(i) => i.type === "mcp_approval_request" && i.id === item.approval_request_id
				) as McpApprovalRequestParams | undefined;
				console.log("approvalRequest", approvalRequest);

				if (approvalRequest) {
					const toolParams = mcpToolsMapping[approvalRequest.name];
					if (!toolParams) {
						// The approved tool is not exposed by any MCP server resolved for this
						// request: answer with a failed response instead of dereferencing undefined.
						responseObject.status = "failed";
						const errorMessage = `MCP tool '${approvalRequest.name}' for approval request '${item.approval_request_id}' not found in the tools of this request`;
						console.error(errorMessage);
						responseObject.error = {
							code: "server_error",
							message: errorMessage,
						};
						res.json(responseObject);
						return;
					}
					responseObject.output.push(
						await callMcpTool(toolParams, approvalRequest.name, toolParams.server_label, approvalRequest.arguments)
					);
					responseObject.status = "completed";
					res.json(responseObject);
					return;
				} else {
					responseObject.status = "failed";
					const errorMessage = `MCP approval response for approval request '${item.approval_request_id}' not found`;
					console.error(errorMessage);
					responseObject.error = {
						code: "server_error",
						message: errorMessage,
					};
					res.json(responseObject);
					return;
				}
			}
		}
	}

	// Streaming mode
	if (req.body.stream) {
		res.setHeader("Content-Type", "text/event-stream");
		res.setHeader("Connection", "keep-alive");
		let sequenceNumber = 0;

		// Emit events in sequence
		const emitEvent = (event: ResponseStreamEvent) => {
			res.write(`data: ${JSON.stringify(event)}\n\n`);
		};

		try {
			// Response created event
			emitEvent({
				type: "response.created",
				response: responseObject as Response,
				sequence_number: sequenceNumber++,
			});

			// Response in progress event
			emitEvent({
				type: "response.in_progress",
				response: responseObject as Response,
				sequence_number: sequenceNumber++,
			});

			const stream = client.chatCompletionStream(payload);
			let usage: ChatCompletionStreamOutputUsage | undefined;

			// NOTE(review): the loop below assumes a single output item at index 0. If `output`
			// was pre-populated (e.g. with a freshly listed `mcp_list_tools` item), the first
			// delta hits one of the "Not implemented" errors.
			for await (const chunk of stream) {
				if (chunk.usage) {
					usage = chunk.usage;
				}

				// A chunk may carry no choice at all (e.g. a usage-only chunk with an empty
				// `choices` array); there is nothing to emit for it.
				const delta = chunk.choices?.[0]?.delta;
				if (!delta) {
					continue;
				}

				if (delta.content) {
					if (responseObject.output.length === 0) {
						const outputObject: ResponseOutputMessage = {
							id: generateUniqueId("msg"),
							type: "message",
							role: "assistant",
							status: "in_progress",
							content: [],
						};
						responseObject.output = [outputObject];

						// Response output item added event
						emitEvent({
							type: "response.output_item.added",
							output_index: 0,
							item: outputObject,
							sequence_number: sequenceNumber++,
						});
					}

					const outputObject = responseObject.output.at(-1);
					if (!outputObject || outputObject.type !== "message") {
						throw new StreamingError("Not implemented: only single output item type is supported in streaming mode.");
					}

					if (outputObject.content.length === 0) {
						// Response content part added event
						const contentPart: ResponseContentPartAddedEvent["part"] = {
							type: "output_text",
							text: "",
							annotations: [],
						};
						outputObject.content.push(contentPart);

						emitEvent({
							type: "response.content_part.added",
							item_id: outputObject.id,
							output_index: 0,
							content_index: 0,
							part: contentPart,
							sequence_number: sequenceNumber++,
						});
					}

					const contentPart = outputObject.content.at(-1);
					if (!contentPart || contentPart.type !== "output_text") {
						throw new StreamingError("Not implemented: only output_text is supported in streaming mode.");
					}

					// Add text delta
					contentPart.text += delta.content;
					emitEvent({
						type: "response.output_text.delta",
						item_id: outputObject.id,
						output_index: 0,
						content_index: 0,
						delta: delta.content,
						sequence_number: sequenceNumber++,
					});
				} else if (delta.tool_calls && delta.tool_calls.length > 0) {
					if (delta.tool_calls.length > 1) {
						throw new StreamingError("Not implemented: only single tool call is supported in streaming mode.");
					}
					const toolCallDelta = delta.tool_calls[0];

					if (responseObject.output.length === 0) {
						// The first tool-call chunk must name the function being called.
						if (!toolCallDelta.function.name) {
							throw new StreamingError("Tool call function name is required.");
						}

						const outputObject: ResponseFunctionToolCall = {
							type: "function_call",
							id: generateUniqueId("fc"),
							call_id: toolCallDelta.id,
							name: toolCallDelta.function.name,
							arguments: "",
						};
						responseObject.output = [outputObject];

						// Response output item added event
						emitEvent({
							type: "response.output_item.added",
							output_index: 0,
							item: outputObject,
							sequence_number: sequenceNumber++,
						});
					}

					const outputObject = responseObject.output.at(-1);
					if (!outputObject || !outputObject.id || outputObject.type !== "function_call") {
						throw new StreamingError("Not implemented: can only support single output item type in streaming mode.");
					}

					// A chunk without an arguments fragment contributes an empty delta, so the
					// accumulated arguments never receive the text "undefined".
					const argumentsDelta = toolCallDelta.function.arguments ?? "";
					outputObject.arguments += argumentsDelta;
					emitEvent({
						type: "response.function_call_arguments.delta",
						item_id: outputObject.id,
						output_index: 0,
						delta: argumentsDelta,
						sequence_number: sequenceNumber++,
					});
				}
			}

			// Close the last output item with the matching "done" events.
			const lastOutputItem = responseObject.output.at(-1);

			if (lastOutputItem) {
				if (lastOutputItem?.type === "message") {
					const contentPart = lastOutputItem.content.at(-1);
					if (contentPart?.type === "output_text") {
						emitEvent({
							type: "response.output_text.done",
							item_id: lastOutputItem.id,
							output_index: responseObject.output.length - 1,
							content_index: lastOutputItem.content.length - 1,
							text: contentPart.text,
							sequence_number: sequenceNumber++,
						});

						emitEvent({
							type: "response.content_part.done",
							item_id: lastOutputItem.id,
							output_index: responseObject.output.length - 1,
							content_index: lastOutputItem.content.length - 1,
							part: contentPart,
							sequence_number: sequenceNumber++,
						});
					} else {
						throw new StreamingError("Not implemented: only output_text is supported in streaming mode.");
					}

					// Response output item done event
					lastOutputItem.status = "completed";
					emitEvent({
						type: "response.output_item.done",
						output_index: responseObject.output.length - 1,
						item: lastOutputItem,
						sequence_number: sequenceNumber++,
					});
				} else if (lastOutputItem?.type === "function_call") {
					if (!lastOutputItem.id) {
						throw new StreamingError("Function call id is required.");
					}

					emitEvent({
						type: "response.function_call_arguments.done",
						item_id: lastOutputItem.id,
						output_index: responseObject.output.length - 1,
						arguments: lastOutputItem.arguments,
						sequence_number: sequenceNumber++,
					});

					lastOutputItem.status = "completed";
					emitEvent({
						type: "response.output_item.done",
						output_index: responseObject.output.length - 1,
						item: lastOutputItem,
						sequence_number: sequenceNumber++,
					});
				} else {
					throw new StreamingError("Not implemented: only message output is supported in streaming mode.");
				}
			}

			// Response completed event
			responseObject.status = "completed";
			if (usage) {
				responseObject.usage = {
					input_tokens: usage.prompt_tokens,
					input_tokens_details: { cached_tokens: 0 },
					output_tokens: usage.completion_tokens,
					output_tokens_details: { reasoning_tokens: 0 },
					total_tokens: usage.total_tokens,
				};
			}
			emitEvent({
				type: "response.completed",
				response: responseObject as Response,
				sequence_number: sequenceNumber++,
			});
		} catch (streamError) {
			// Headers are already sent at this point, so the failure is reported as a stream event.
			console.error("Error in streaming chat completion:", streamError);

			let message = "An error occurred while streaming from inference server.";
			if (streamError instanceof StreamingError) {
				message = streamError.message;
			} else if (
				typeof streamError === "object" &&
				streamError &&
				"message" in streamError &&
				typeof streamError.message === "string"
			) {
				message = streamError.message;
			}
			responseObject.status = "failed";
			responseObject.error = {
				code: "server_error",
				message,
			};
			emitEvent({
				type: "response.failed",
				response: responseObject as Response,
				sequence_number: sequenceNumber++,
			});
		}
		res.end();
		return;
	}

	// Non-streaming mode
	try {
		const chatCompletionResponse = await client.chatCompletion(payload);

		responseObject.status = "completed";
		for (const choice of chatCompletionResponse.choices) {
			if (choice.message.content) {
				responseObject.output.push({
					id: generateUniqueId("msg"),
					type: "message",
					role: "assistant",
					status: "completed",
					content: [
						{
							type: "output_text",
							text: choice.message.content,
							annotations: [],
						},
					],
				});
			}
			if (choice.message.tool_calls) {
				for (const toolCall of choice.message.tool_calls) {
					if (toolCall.function.name in mcpToolsMapping) {
						// Tool exposed by an MCP server: either ask for approval or call it right away.
						const toolParams = mcpToolsMapping[toolCall.function.name];

						// Check if approval is required
						const approvalRequired =
							toolParams.require_approval === "always"
								? true
								: toolParams.require_approval === "never"
									? false
									: toolParams.require_approval.always?.tool_names?.includes(toolCall.function.name)
										? true
										: toolParams.require_approval.never?.tool_names?.includes(toolCall.function.name)
											? false
											: true; // behavior is undefined in specs, let's default to requiring approval

						if (approvalRequired) {
							// TODO: Implement approval logic
							console.log(`Requesting approval for MCP tool '${toolCall.function.name}'`);
							responseObject.output.push({
								type: "mcp_approval_request",
								id: generateUniqueId("mcp_approval_request"),
								name: toolCall.function.name,
								server_label: toolParams.server_label,
								arguments: toolCall.function.arguments,
							});
						} else {
							responseObject.output.push(
								await callMcpTool(
									toolParams,
									toolCall.function.name,
									toolParams.server_label,
									toolCall.function.arguments
								)
							);
						}
					} else {
						// Regular function tool: the call is returned to the client to execute.
						responseObject.output.push({
							type: "function_call",
							id: generateUniqueId("fc"),
							call_id: toolCall.id,
							name: toolCall.function.name,
							arguments: toolCall.function.arguments,
							status: "completed",
						});
					}
				}
			}
		}

		responseObject.usage = {
			input_tokens: chatCompletionResponse.usage.prompt_tokens,
			input_tokens_details: { cached_tokens: 0 },
			output_tokens: chatCompletionResponse.usage.completion_tokens,
			output_tokens_details: { reasoning_tokens: 0 },
			total_tokens: chatCompletionResponse.usage.total_tokens,
		};

		res.json(responseObject);
	} catch (error) {
		console.error(error);
		res.status(500).json({
			success: false,
			error: error instanceof Error ? error.message : "Unknown error",
		});
	}
};