import grpc from '@grpc/grpc-js';
import protoLoader from '@grpc/proto-loader';
import { AutoRouter, cors, error, json } from 'itty-router';
import dotenv from 'dotenv';
import path, { dirname } from 'path';
import { fileURLToPath } from 'url';
import { createServerAdapter } from '@whatwg-node/server';
import { createServer } from 'http';

// Load environment variables.
dotenv.config();

// Directory of the current file (ESM replacement for __dirname).
const __dirname = dirname(fileURLToPath(import.meta.url));

/**
 * Reads an integer from an environment variable.
 *
 * @param {string} name - Environment variable name.
 * @param {number} fallback - Value used when the variable is unset or not a valid integer.
 * @param {number} [min=0] - Smallest accepted value; anything lower yields the fallback.
 * @returns {number} The parsed integer or the fallback.
 */
function readIntEnv(name, fallback, min = 0) {
  const parsed = Number.parseInt(process.env[name] ?? '', 10);
  return Number.isFinite(parsed) && parsed >= min ? parsed : fallback;
}

/** Runtime configuration, read once from the environment at start-up. */
class Config {
  constructor() {
    this.API_PREFIX = process.env.API_PREFIX || '/';
    this.API_KEY = process.env.API_KEY || '';
    // Environment values are strings; parse them so the retry loop and the
    // timer work with real numbers. At least one attempt is always made.
    this.MAX_RETRY_COUNT = readIntEnv('MAX_RETRY_COUNT', 3, 1);
    this.RETRY_DELAY = readIntEnv('RETRY_DELAY', 5000);
    this.COMMON_GRPC = 'runtime-native-io-vertex-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
    this.COMMON_PROTO = path.join(__dirname, '..', 'protos', 'VertexInferenceService.proto');
    this.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
    this.GPT_PROTO = path.join(__dirname, '..', 'protos', 'GPTInferenceService.proto');
    this.PORT = process.env.PORT || 8787;
  }
}

/** Loads a .proto file and exposes the resulting package definition. */
class GRPCHandler {
  /**
   * @param {string} protoFilePath - Path of the .proto file to load.
   */
  constructor(protoFilePath) {
    // Load the .proto file whose path was passed in.
    this.packageDefinition = protoLoader.loadSync(protoFilePath, {
      keepCase: true,
      longs: String,
      enums: String,
      defaults: true,
      oneofs: true,
    });
  }
}

const config = new Config();

// ---- Middleware ----

// CORS: allow any origin, method and exposed header.
const { preflight, corsify } = cors({
  origin: '*',
  allowMethods: '*',
  exposeHeaders: '*',
});

// Records when the request arrived so `logger` can report its duration.
const withStart = (request) => {
  request.start = Date.now();
};

/**
 * Bearer-token authentication. Does nothing when no API key is configured.
 *
 * @param {Request} request - Incoming request.
 * @returns {Response|undefined} A 401/403 response on failure, otherwise
 *   undefined so the next handler runs.
 */
const withAuth = (request) => {
  if (!config.API_KEY) {
    return undefined;
  }
  const authHeader = request.headers.get('Authorization');
  if (!authHeader || !authHeader.startsWith('Bearer ')) {
    return error(401, 'Unauthorized: Missing or invalid Authorization header');
  }
  const token = authHeader.slice('Bearer '.length);
  if (token !== config.API_KEY) {
    return error(403, 'Forbidden: Invalid API key');
  }
  return undefined;
};

// Logs method, status, URL and elapsed time for every handled request.
const logger = (res, req) => {
  console.log(req.method, res.status, req.url, Date.now() - req.start, 'ms');
};

// Public model id -> provider and the model name sent upstream.
const MODEL_INFO = {
  'claude-3-sonnet-20240229': { provider: 'anthropic', mapping: 'claude-3-sonnet@20240229' },
  'claude-3-opus-20240229': { provider: 'anthropic', mapping: 'claude-3-opus@20240229' },
  'claude-3-haiku-20240307': { provider: 'anthropic', mapping: 'claude-3-haiku@20240307' },
  'claude-3-5-sonnet-20240620': { provider: 'anthropic', mapping: 'claude-3-5-sonnet@20240620' },
  'gpt-4o-mini': { provider: 'openai', mapping: 'gpt-4o-mini' },
  'gpt-4o': { provider: 'openai', mapping: 'gpt-4o' },
  'gpt-4-turbo': { provider: 'openai', mapping: 'gpt-4-turbo' },
  'gpt-4': { provider: 'openai', mapping: 'gpt-4' },
  'gpt-3.5-turbo': { provider: 'openai', mapping: 'gpt-3.5-turbo' },
  'gemini-1.5-pro': { provider: 'google', mapping: 'gemini-1.5-pro' },
  'gemini-1.5-flash': { provider: 'google', mapping: 'gemini-1.5-flash' },
  'chat-bison': { provider: 'pieces-os', mapping: 'chat-bison' },
  'codechat-bison': { provider: 'pieces-os', mapping: 'codechat-bison' },
};

// ---- Routes ----

const router = AutoRouter({
  // Only timing and the CORS preflight run globally; auth is attached per route.
  before: [withStart, preflight],
  missing: () => error(404, '404 not found.'),
  finally: [corsify, logger],
});

// Root route: returns a usage description.
router.get('/', () =>
  json({
    service: 'AI Chat Completion Proxy',
    usage: {
      endpoint: '/v1/chat/completions',
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        Authorization: 'Bearer YOUR_API_KEY',
      },
      body: {
        model: `One of: ${Object.keys(MODEL_INFO).join(', ')}`,
        messages: [
          { role: 'system', content: 'You are a helpful assistant.' },
          { role: 'user', content: 'Hello, who are you?' },
        ],
        stream: false,
        temperature: 0.7,
        top_p: 1,
      },
    },
    note: 'Replace YOUR_API_KEY with your actual API key.',
  }),
);

// NOTE(review): with the default prefix '/' the paths below become
// '//v1/...'; presumably the router collapses duplicate slashes - confirm.

// Models route.
router.get(`${config.API_PREFIX}/v1/models`, withAuth, () => {
  // Reported as a Unix timestamp in seconds rather than milliseconds.
  const created = Math.floor(Date.now() / 1000);
  return json({
    object: 'list',
    data: Object.entries(MODEL_INFO).map(([modelId, info]) => ({
      id: modelId,
      object: 'model',
      created,
      owned_by: 'pieces-os',
      permission: [],
      root: modelId,
      parent: null,
      mapping: info.mapping,
      provider: info.provider,
    })),
  });
});

// Chat route.
router.post(`${config.API_PREFIX}/v1/chat/completions`, withAuth, (req) => handleCompletion(req));

// ---- gRPC plumbing ----

// One client per back end, created on first use and then reused, so the proto
// file is not re-parsed and no new channel is opened for every request.
const grpcClients = new Map();

/**
 * Returns the cached gRPC client for a back end, creating it on first use.
 *
 * @param {'gpt'|'vertex'} kind - Which inference service to talk to.
 * @returns {object} gRPC client exposing `Predict` and `PredictWithStream`.
 */
function getGrpcClient(kind) {
  const cached = grpcClients.get(kind);
  if (cached) {
    return cached;
  }
  // Use the system root certificates.
  const credentials = grpc.credentials.createSsl();
  let client;
  if (kind === 'gpt') {
    const { packageDefinition } = new GRPCHandler(config.GPT_PROTO);
    const services = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.gpt;
    client = new services.GPTInferenceService(config.GPT_GRPC, credentials);
  } else {
    const { packageDefinition } = new GRPCHandler(config.COMMON_PROTO);
    const services = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.vertex;
    client = new services.VertexInferenceService(config.COMMON_GRPC, credentials);
  }
  grpcClients.set(kind, client);
  return client;
}

/**
 * Builds the gRPC request for the chosen model and forwards it upstream.
 *
 * @param {string} models - Upstream model name; names containing 'gpt' use the GPT service.
 * @param {string} message - Flattened user/assistant conversation text.
 * @param {string} rules - Flattened system prompt text.
 * @param {boolean} stream - Whether to answer as a server-sent event stream.
 * @param {number} [temperature] - Sampling temperature (GPT service only).
 * @param {number} [top_p] - Nucleus sampling value (GPT service only).
 * @returns {Promise<Response>} OpenAI-style HTTP response.
 */
async function GrpcToPieces(models, message, rules, stream, temperature, top_p) {
  // temperature and top_p have no effect for non-GPT models.
  const isGpt = models.includes('gpt');
  const client = getGrpcClient(isGpt ? 'gpt' : 'vertex');
  const request = isGpt
    ? {
        models,
        messages: [
          { role: 0, message: rules }, // system
          { role: 1, message }, // user
        ],
        // `??` so that an explicit temperature of 0 is respected.
        temperature: temperature ?? 0.1,
        top_p: top_p ?? 1,
      }
    : {
        models,
        args: {
          messages: {
            unknown: 1,
            message,
          },
          rules,
        },
      };
  return await ConvertOpenai(client, request, models, stream);
}

/**
 * Flattens OpenAI-style chat messages into two strings.
 *
 * @param {Array<{role: string, content: string|Array<{text?: string}>}>} messages
 * @returns {Promise<{rules: string, message: string}>} `rules` holds the system
 *   messages, `message` holds the user/assistant turns; other roles are dropped.
 */
async function messagesProcess(messages) {
  let rules = '';
  let message = '';
  for (const { role, content } of messages) {
    // Multi-part content: concatenate the text of every part that has one.
    const contentStr = Array.isArray(content)
      ? content
          .filter((part) => part.text)
          .map((part) => part.text)
          .join('')
      : content;
    if (role === 'system') {
      rules += `system:${contentStr};\r\n`;
    } else if (role === 'user' || role === 'assistant') {
      message += `${role}:${contentStr};\r\n`;
    }
  }
  return { rules, message };
}

// Resolves after `ms` milliseconds.
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

// Picks the generated text out of an upstream reply; the two services nest it differently.
function extractMessage(payload, model) {
  return model.includes('gpt')
    ? payload.body.message_warpper.message.message
    : payload.args.args.args.message;
}

// Opens a streaming RPC and exposes it as a server-sent event HTTP response.
function createStreamResponse(client, request, model) {
  const call = client.PredictWithStream(request);
  const encoder = new TextEncoder();
  // Guards against closing or erroring the controller more than once.
  let finished = false;

  const body = new ReadableStream({
    start(controller) {
      const finish = (err) => {
        if (finished) {
          return;
        }
        finished = true;
        if (err) {
          controller.error(err);
        } else {
          controller.close();
        }
      };

      call.on('data', (response) => {
        if (finished) {
          return;
        }
        const responseCode = Number(response.response_code);
        if (responseCode === 204) {
          // 204 marks the end of the stream.
          finish();
          call.destroy();
        } else if (responseCode === 200) {
          const chunk = ChatCompletionStreamWithModel(extractMessage(response, model), model);
          controller.enqueue(encoder.encode(`data: ${JSON.stringify(chunk)}\n\n`));
        } else {
          finish(new Error(`Error: stream chunk is not success`));
          call.cancel();
        }
      });
      // Without these listeners an upstream failure would leave the HTTP
      // response hanging, and an unhandled 'error' event could crash the process.
      call.on('error', (err) => {
        console.error(err);
        finish(err);
      });
      call.on('end', () => finish());
    },
    cancel() {
      // The HTTP client went away: stop the upstream RPC as well.
      finished = true;
      call.cancel();
    },
  });

  return new Response(body, {
    headers: {
      'Content-Type': 'text/event-stream',
    },
  });
}

// Performs one unary Predict call; throws when the call fails or is not a 200.
async function predictOnce(client, request, model) {
  const reply = await new Promise((resolve, reject) => {
    client.Predict(request, (err, response) => {
      if (err) {
        reject(err);
      } else {
        resolve(response);
      }
    });
  });
  if (Number(reply.response_code) !== 200) {
    throw new Error(`Upstream returned response_code ${reply.response_code}`);
  }
  const completion = ChatCompletionWithModel(extractMessage(reply, model), model);
  return new Response(JSON.stringify(completion), {
    headers: {
      'Content-Type': 'application/json',
    },
  });
}

/**
 * Sends the request upstream and converts the reply to an OpenAI-style
 * response, retrying up to `config.MAX_RETRY_COUNT` times.
 *
 * @param {object} client - gRPC client exposing `Predict` and `PredictWithStream`.
 * @param {object} request - gRPC request message.
 * @param {string} model - Upstream model name, echoed in the response.
 * @param {boolean} stream - Stream the answer as server-sent events when truthy.
 * @returns {Promise<Response>} The completion, or a 500 response carrying the
 *   last failure once every attempt has failed.
 */
async function ConvertOpenai(client, request, model, stream) {
  let lastError;
  for (let attempt = 1; attempt <= config.MAX_RETRY_COUNT; attempt++) {
    try {
      if (stream) {
        return createStreamResponse(client, request, model);
      }
      return await predictOnce(client, request, model);
    } catch (err) {
      // Remember the failure: it must outlive the catch scope to be reported below.
      lastError = err;
      console.error(err);
      if (attempt < config.MAX_RETRY_COUNT) {
        await sleep(config.RETRY_DELAY);
      }
    }
  }
  return error(500, lastError?.message ?? 'Upstream request failed');
}
/**
 * Wraps generated text in an OpenAI-style `chat.completion` object.
 *
 * @param {string} message - Assistant text.
 * @param {string} model - Model name echoed back to the client.
 * @returns {object} Completion object with a single choice; token usage is
 *   always reported as zero.
 */
function ChatCompletionWithModel(message, model) {
  return {
    id: 'Chat-Nekohy',
    object: 'chat.completion',
    // Unix timestamp in seconds, not milliseconds.
    created: Math.floor(Date.now() / 1000),
    model,
    usage: {
      prompt_tokens: 0,
      completion_tokens: 0,
      total_tokens: 0,
    },
    choices: [
      {
        message: {
          content: message,
          role: 'assistant',
        },
        index: 0,
        finish_reason: 'stop',
      },
    ],
  };
}

/**
 * Wraps one piece of streamed text in an OpenAI-style `chat.completion.chunk`.
 *
 * @param {string} text - Text delta for this chunk.
 * @param {string} model - Model name echoed back to the client.
 * @returns {object} Chunk object with a single choice.
 */
function ChatCompletionStreamWithModel(text, model) {
  return {
    id: 'chatcmpl-Nekohy',
    object: 'chat.completion.chunk',
    // Unix timestamp in seconds instead of a hard-coded 0.
    created: Math.floor(Date.now() / 1000),
    model,
    choices: [
      {
        index: 0,
        delta: {
          content: text,
        },
        finish_reason: null,
      },
    ],
  };
}

/**
 * Handles POST /v1/chat/completions.
 *
 * @param {Request} request - Incoming request carrying a JSON body with
 *   `model`, `messages` and optionally `stream`, `temperature`, `top_p`.
 * @returns {Promise<Response>} The upstream completion, a 400 response for a
 *   malformed request, or a 500 response when processing fails.
 */
async function handleCompletion(request) {
  let payload;
  try {
    payload = await request.json();
  } catch {
    // A body that is not JSON is the client's mistake, not a server failure.
    return error(400, 'Invalid JSON body');
  }
  if (payload === null || typeof payload !== 'object') {
    return error(400, 'Request body must be a JSON object');
  }

  const { model: inputModel, messages, stream, temperature, top_p } = payload;

  // Own-property check so inherited keys such as "constructor" are rejected.
  const isKnownModel = Object.prototype.hasOwnProperty.call(MODEL_INFO, inputModel);
  if (!isKnownModel) {
    return error(400, `Unsupported model: ${inputModel}`);
  }
  if (!Array.isArray(messages) || messages.length === 0) {
    return error(400, 'messages must be a non-empty array');
  }

  try {
    // Use the upstream model name.
    const mappedModel = MODEL_INFO[inputModel].mapping;
    // Split into system text and user/assistant text.
    const { rules, message: content } = await messagesProcess(messages);
    return await GrpcToPieces(mappedModel, content, rules, stream, temperature, top_p);
  } catch (err) {
    return error(500, err.message);
  }
}

(async () => {
  // For Cloudflare Workers
  if (typeof addEventListener === 'function') return;
  // For Node.js
  const ittyServer = createServerAdapter(router.fetch);
  const httpServer = createServer(ittyServer);
  httpServer.listen(config.PORT, () => {
    // Logged from the callback, i.e. after listen() has taken effect.
    console.log(`Listening on http://localhost:${config.PORT}`);
  });
})();