// pcs / api / index.js
// Last change: "Update api/index.js" by smgc (commit e09949e, verified)
import grpc from '@huayue/grpc-js'
import protoLoader from '@grpc/proto-loader'
import { AutoRouter, cors, error, json } from 'itty-router'
import dotenv from 'dotenv'
import path, { dirname } from 'path'
import { fileURLToPath } from 'url'
import { createServerAdapter } from '@whatwg-node/server'
import { createServer } from 'http'
// Load environment variables (dotenv.config() with default options).
dotenv.config()

// ES modules have no built-in __dirname; derive this file's directory from import.meta.url.
const __dirname = path.dirname(fileURLToPath(import.meta.url))
// Configuration
/**
 * Runtime configuration, read from environment variables with built-in defaults.
 *
 * Fields:
 *  - API_PREFIX / API_KEY: route prefix and optional bearer token (empty string = no key set).
 *  - MAX_RETRY_COUNT / RETRY_DELAY: retry settings, stored as numbers (delay is passed to setTimeout).
 *  - COMMON_GRPC / GPT_GRPC: gRPC host names; COMMON_PROTO / GPT_PROTO: paths to the .proto files.
 *  - PORT: HTTP port for the Node.js server.
 *  - SUPPORTED_MODELS: array of accepted model ids (Claude ids in `name@YYYYMMDD` form).
 */
class Config {
  constructor() {
    this.API_PREFIX = process.env.API_PREFIX || '/'
    this.API_KEY = process.env.API_KEY || ''
    // Environment values are always strings; convert once so consumers get real numbers.
    this.MAX_RETRY_COUNT = Config.parseNonNegativeNumber(process.env.MAX_RETRY_COUNT, 3)
    this.RETRY_DELAY = Config.parseNonNegativeNumber(process.env.RETRY_DELAY, 5000)
    this.COMMON_GRPC = 'runtime-native-io-vertex-inference-grpc-service-lmuw6mcn3q-ul.a.run.app'
    this.COMMON_PROTO = path.join(__dirname, '..', 'protos', 'VertexInferenceService.proto')
    this.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app'
    this.GPT_PROTO = path.join(__dirname, '..', 'protos', 'GPTInferenceService.proto')
    this.PORT = process.env.PORT || 8787
    // Supported model list. SUPPORTED_MODELS may be overridden with a comma-separated
    // string; it is parsed into an array so that isValidModel() performs exact
    // matching (String.prototype.includes would accept any substring, e.g. "gpt").
    this.SUPPORTED_MODELS = Config.parseModelList(process.env.SUPPORTED_MODELS) ?? [
      'gpt-4o-mini',
      'gpt-4o',
      'gpt-4-turbo',
      'gpt-4',
      'gpt-3.5-turbo',
      'claude-3-sonnet@20240229',
      'claude-3-opus@20240229',
      'claude-3-haiku@20240307',
      'claude-3-5-sonnet@20240620',
      'gemini-1.5-flash',
      'gemini-1.5-pro',
      'chat-bison',
      'codechat-bison',
    ]
  }

  /**
   * Converts an environment string to a finite, non-negative number.
   * @param {string|undefined} raw - Raw environment value.
   * @param {number} fallback - Value used when `raw` is unset, empty, non-numeric or negative.
   * @returns {number}
   */
  static parseNonNegativeNumber(raw, fallback) {
    if (raw === undefined || raw === null || String(raw).trim() === '') return fallback
    const value = Number(raw)
    return Number.isFinite(value) && value >= 0 ? value : fallback
  }

  /**
   * Parses a comma-separated model list.
   * @param {string|undefined} raw - Raw environment value.
   * @returns {string[]|null} Trimmed, normalized model ids, or null when nothing usable was given.
   */
  static parseModelList(raw) {
    if (typeof raw !== 'string') return null
    const models = raw
      .split(',')
      .map((entry) => entry.trim())
      .filter((entry) => entry !== '')
      .map((entry) => Config.normalizeModel(entry))
    return models.length > 0 ? models : null
  }

  /**
   * Rewrites Claude ids of the form `claude-3-<variant>-YYYYMMDD` to
   * `claude-3-<variant>@YYYYMMDD`; every other id is returned unchanged.
   * @param {string} model
   * @returns {string}
   */
  static normalizeModel(model) {
    const match = /^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$/.exec(model)
    return match ? `${match[1]}@${match[3]}` : model
  }

  /**
   * Checks whether a model id (in either Claude date format) is in SUPPORTED_MODELS.
   * @param {unknown} model - Requested model id.
   * @returns {boolean} false for non-string input instead of throwing.
   */
  isValidModel(model) {
    if (typeof model !== 'string') return false
    return this.SUPPORTED_MODELS.includes(Config.normalizeModel(model))
  }
}
/**
 * Loads a .proto file synchronously and exposes the result as `packageDefinition`.
 */
class GRPCHandler {
  /**
   * @param {string} protoFilePath - Path of the .proto file to load.
   */
  constructor(protoFilePath) {
    // Options handed to protoLoader.loadSync for every proto file.
    const loaderOptions = {
      keepCase: true,
      longs: String,
      enums: String,
      defaults: true,
      oneofs: true,
    }
    this.packageDefinition = protoLoader.loadSync(protoFilePath, loaderOptions)
  }
}
// Single configuration instance shared by the rest of this module.
const config = new Config()

// Middleware
// CORS handlers: every origin and method is allowed, every header is exposed.
const corsOptions = {
  origin: '*',
  allowMethods: '*',
  exposeHeaders: '*',
}
const { preflight, corsify } = cors(corsOptions)
// Authentication middleware.
// Returns nothing when the request may proceed (no API key configured, or the
// bearer token equals config.API_KEY); otherwise returns a 401/403 error response.
const withAuth = (request) => {
  if (!config.API_KEY) return

  const header = request.headers.get('Authorization')
  const hasBearerToken = Boolean(header) && header.startsWith('Bearer ')
  if (!hasBearerToken) {
    return error(401, 'Unauthorized: Missing or invalid Authorization header')
  }

  // Everything after the "Bearer " prefix is the presented key.
  const presentedKey = header.slice('Bearer '.length)
  return presentedKey === config.API_KEY ? undefined : error(403, 'Forbidden: Invalid API key')
}
// Request timing + logging.
// Records when handling of the request started; logger() reads req.start to
// compute the elapsed time (without this stamp the duration is NaN).
const withStartTime = (req) => {
  req.start = Date.now()
}
// Logs method, response status, URL and elapsed milliseconds for each request.
const logger = (res, req) => {
  console.log(req.method, res.status, req.url, Date.now() - req.start, 'ms')
}
const router = AutoRouter({
  // The timestamp is taken first so that it is set before preflight runs.
  before: [withStartTime, preflight],
  missing: () => error(404, '404 not found.'),
  finally: [corsify, logger],
})
// Routes
// Model ids advertised to clients by the info page and the model listing.
const ADVERTISED_MODEL_IDS = [
  'gpt-4o-mini',
  'gpt-4o',
  'gpt-4-turbo',
  'gpt-4',
  'gpt-3.5-turbo',
  'claude-3-sonnet-20240229',
  'claude-3-opus-20240229',
  'claude-3-haiku-20240307',
  'claude-3-5-sonnet-20240620',
  'gemini-1.5-flash',
  'gemini-1.5-pro',
  'chat-bison',
  'codechat-bison',
]

// GET / — unauthenticated service description with a usage example.
router.get('/', () => {
  const exampleBody = {
    model: `One of: ${ADVERTISED_MODEL_IDS.join(', ')}`,
    messages: [
      { role: 'system', content: 'You are a helpful assistant.' },
      { role: 'user', content: 'Hello, who are you?' },
    ],
    stream: false,
    temperature: 0.7,
    top_p: 1,
  }
  return json({
    service: 'AI Chat Completion Proxy',
    usage: {
      endpoint: '/v1/chat/completions',
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        Authorization: 'Bearer YOUR_API_KEY',
      },
      body: exampleBody,
    },
    availableModels: [...ADVERTISED_MODEL_IDS],
    note: 'Replace YOUR_API_KEY with your actual API key.',
  })
})

// GET <prefix>/v1/models — authenticated model listing.
router.get(`${config.API_PREFIX}/v1/models`, withAuth, () => {
  const data = ADVERTISED_MODEL_IDS.map((id) => ({ id, object: 'model', owned_by: 'pieces-os' }))
  return json({ object: 'list', data })
})

// POST <prefix>/v1/chat/completions — authenticated chat completion endpoint.
router.post(`${config.API_PREFIX}/v1/chat/completions`, withAuth, (req) => handleCompletion(req))
/**
 * Builds the gRPC client and request for the chosen backend and forwards them
 * to ConvertOpenai.
 *
 * Models whose id contains "gpt" use the GPT proto/service; all others use the
 * Vertex proto/service.
 *
 * @param {string} inputModel - Model id sent to the backend.
 * @param {string} OriginModel - Model id as requested by the client.
 * @param {string} message - Flattened user/assistant conversation text.
 * @param {string} rules - Flattened system prompt text.
 * @param {boolean} stream - Whether a streaming response was requested.
 * @param {number} [temperature] - Only sent to the GPT service.
 * @param {number} [top_p] - Only sent to the GPT service.
 * @returns {Promise<*>} Whatever ConvertOpenai resolves to.
 */
async function GrpcToPieces(inputModel, OriginModel, message, rules, stream, temperature, top_p) {
  // temperature and top_p have no effect for non-GPT models.
  // Use the system root certificates.
  const credentials = grpc.credentials.createSsl()
  const isGptModel = inputModel.includes('gpt')

  // Load the proto file matching the backend.
  const protoPath = isGptModel ? config.GPT_PROTO : config.COMMON_PROTO
  const { packageDefinition } = new GRPCHandler(protoPath)
  const parents = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents

  let client
  let request
  if (isGptModel) {
    const systemMessage = { role: 0, message: rules }
    const userMessage = { role: 1, message: message }
    request = {
      models: inputModel,
      messages: [systemMessage, userMessage],
      // NOTE(review): `||` turns an explicit temperature of 0 into 0.1 — confirm this is intended.
      temperature: temperature || 0.1,
      top_p: top_p ?? 1,
    }
    client = new parents.gpt.GPTInferenceService(config.GPT_GRPC, credentials)
  } else {
    const messages = { unknown: 1, message: message }
    request = {
      models: inputModel,
      args: { messages, rules: rules },
    }
    client = new parents.vertex.VertexInferenceService(config.COMMON_GRPC, credentials)
  }

  return await ConvertOpenai(client, request, inputModel, OriginModel, stream)
}
/**
 * Flattens OpenAI-style chat messages into two strings.
 *
 * System messages are appended to `rules` as `system:<text>;\r\n`; user and
 * assistant messages are appended to `message` as `<role>:<text>;\r\n`.
 * Messages with any other role are ignored.
 *
 * @param {Array<{role: string, content: (string|Array<{text?: string}>|null|undefined)}>} messages
 * @returns {Promise<{rules: string, message: string}>}
 */
async function messagesProcess(messages) {
  let rules = ''
  let message = ''
  for (const msg of messages) {
    if (!msg) continue
    const { role, content } = msg
    // Array content: concatenate the `text` of every part that has one
    // (null parts and parts without text are skipped).
    // Missing content becomes an empty string rather than the literal text
    // "null"/"undefined".
    const contentStr = Array.isArray(content)
      ? content
          .map((item) => item?.text)
          .filter(Boolean)
          .join('')
      : content ?? ''
    if (role === 'system') {
      rules += `system:${contentStr};\r\n`
    } else if (['user', 'assistant'].includes(role)) {
      message += `${role}:${contentStr};\r\n`
    }
  }
  return { rules, message }
}
// Reads the generated text from a backend response; the GPT service and the
// other service nest it under different fields.
function extractPiecesMessage(response, inputModel) {
  return inputModel.includes('gpt')
    ? response.body.message_warpper.message.message
    : response.args.args.args.message
}

// Wraps a streaming gRPC call in a server-sent-events Response.
function createPiecesStreamResponse(call, inputModel, OriginModel) {
  const encoder = new TextEncoder()
  // Set once the stream has been closed, errored or cancelled. Guards against
  // touching the controller twice (a second close() throws).
  let settled = false

  const body = new ReadableStream({
    start(controller) {
      const finish = () => {
        if (settled) return
        settled = true
        // OpenAI-style streams are terminated with a [DONE] sentinel.
        controller.enqueue(encoder.encode('data: [DONE]\n\n'))
        controller.close()
      }
      const fail = (err) => {
        if (settled) return
        settled = true
        controller.error(err)
      }

      // Data handling
      call.on('data', (response) => {
        if (settled) return
        try {
          const response_code = Number(response.response_code)
          if (response_code === 204) {
            // 204 marks the end of the generation.
            finish()
            call.destroy()
          } else if (response_code === 200) {
            const text = extractPiecesMessage(response, inputModel)
            controller.enqueue(
              encoder.encode(`data: ${JSON.stringify(ChatCompletionStreamWithModel(text, OriginModel))}\n\n`),
            )
          } else {
            console.error(`Invalid response code: ${response_code}`)
            fail(new Error(`Invalid response code: ${response_code}`))
            call.destroy()
          }
        } catch (err) {
          console.error('Error processing stream data:', err)
          fail(err)
        }
      })

      // Error handling
      call.on('error', (err) => {
        console.error('Stream error:', err)
        // An INTERNAL error (code 13) mentioning RST_STREAM may be a normal end of stream.
        if (err.code === 13 && err.details?.includes('RST_STREAM')) {
          finish()
        } else {
          fail(err)
        }
        call.destroy()
      })

      // End of stream
      call.on('end', finish)
    },
    // Called when the consumer cancels the stream; stop the gRPC call as well.
    cancel() {
      settled = true
      call.destroy()
    },
  })

  return new Response(body, {
    headers: {
      'Content-Type': 'text/event-stream',
      Connection: 'keep-alive',
      'Cache-Control': 'no-cache',
      'Transfer-Encoding': 'chunked',
    },
  })
}

/**
 * Sends the request to the backend and converts the reply to an OpenAI-style
 * HTTP response.
 *
 * Streaming requests return a text/event-stream Response as soon as the call is
 * opened. Non-streaming requests are attempted up to config.MAX_RETRY_COUNT
 * times, waiting config.RETRY_DELAY ms between attempts.
 *
 * @param {object} client - gRPC client exposing Predict / PredictWithStream.
 * @param {object} request - Request message for the backend.
 * @param {string} inputModel - Model id sent to the backend (selects the response layout).
 * @param {string} OriginModel - Model id reported back to the client.
 * @param {boolean} stream - Whether to stream the response.
 * @returns {Promise<Response>} The completion, or a 500 JSON error once all attempts failed.
 */
async function ConvertOpenai(client, request, inputModel, OriginModel, stream) {
  const metadata = new grpc.Metadata()
  metadata.set('User-Agent', 'dart-grpc/2.0.0')

  for (let i = 0; i < config.MAX_RETRY_COUNT; i++) {
    try {
      if (stream) {
        const call = client.PredictWithStream(request, metadata)
        return createPiecesStreamResponse(call, inputModel, OriginModel)
      }

      const reply = await new Promise((resolve, reject) => {
        client.Predict(request, metadata, (err, response) => {
          if (err) reject(err)
          else resolve(response)
        })
      })
      const response_code = Number(reply.response_code)
      if (response_code !== 200) {
        throw new Error(`Invalid response code: ${response_code}`)
      }
      const response_message = extractPiecesMessage(reply, inputModel)
      return new Response(JSON.stringify(ChatCompletionWithModel(response_message, OriginModel)), {
        headers: {
          'Content-Type': 'application/json',
        },
      })
    } catch (err) {
      console.error(`Attempt ${i + 1} failed:`, err)
      // Only wait when another attempt will follow.
      if (i + 1 < config.MAX_RETRY_COUNT) {
        await new Promise((resolve) => setTimeout(resolve, config.RETRY_DELAY))
      }
    }
  }

  return new Response(
    JSON.stringify({
      error: {
        message: 'An error occurred while processing your request',
        type: 'server_error',
        code: 'internal_error',
        param: null,
      },
    }),
    {
      status: 500,
      headers: {
        'Content-Type': 'application/json',
      },
    },
  )
}
/**
 * Builds a non-streaming OpenAI-style `chat.completion` payload.
 *
 * @param {string} message - Assistant reply text.
 * @param {string} model - Model id to report back to the client.
 * @returns {object} Completion object; token usage is always reported as 0.
 */
function ChatCompletionWithModel(message, model) {
  return {
    id: 'Chat-Nekohy',
    object: 'chat.completion',
    // OpenAI reports `created` in Unix seconds, not milliseconds.
    created: Math.floor(Date.now() / 1000),
    model,
    usage: {
      prompt_tokens: 0,
      completion_tokens: 0,
      total_tokens: 0,
    },
    choices: [
      {
        message: {
          content: message,
          role: 'assistant',
        },
        index: 0,
        // The reply is complete by the time this object is built.
        finish_reason: 'stop',
      },
    ],
  }
}
/**
 * Builds one OpenAI-style `chat.completion.chunk` payload for a streamed reply.
 *
 * @param {string} text - Text fragment carried by this chunk.
 * @param {string} model - Model id to report back to the client.
 * @returns {object} Chunk with a single choice whose delta holds `text`.
 */
function ChatCompletionStreamWithModel(text, model) {
  const choice = {
    index: 0,
    delta: { content: text },
    finish_reason: null,
  }
  // NOTE(review): `created` is a fixed 0 rather than a timestamp — confirm this is intended.
  const chunk = {
    id: 'chatcmpl-Nekohy',
    object: 'chat.completion.chunk',
    created: 0,
    model,
    choices: [choice],
  }
  return chunk
}
/**
 * Handles POST /v1/chat/completions.
 *
 * Parses the OpenAI-style request body, validates it, flattens the messages
 * and forwards everything to GrpcToPieces.
 *
 * @param {Request} request - Incoming HTTP request.
 * @returns {Promise<Response>} 400 for a malformed body, 404 for an unknown
 *   model, 500 when processing throws, otherwise the GrpcToPieces result.
 */
async function handleCompletion(request) {
  // Parse the OpenAI-format request. A body that is not valid JSON (or not a
  // JSON object) is the client's fault, so answer 400 instead of 500.
  let body
  try {
    body = await request.json()
  } catch (err) {
    return error(400, 'Invalid JSON body')
  }
  if (body === null || typeof body !== 'object') {
    return error(400, 'Request body must be a JSON object')
  }

  const { model: OriginModel, messages, stream, temperature, top_p } = body
  if (typeof OriginModel !== 'string' || OriginModel === '') {
    return error(400, "Missing or invalid 'model'")
  }
  if (!Array.isArray(messages)) {
    return error(400, "Missing or invalid 'messages'")
  }

  try {
    // Claude ids arrive as `...-YYYYMMDD`; the backend id uses `...@YYYYMMDD`.
    const RegexInput = /^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$/
    const matchInput = OriginModel.match(RegexInput)
    const inputModel = matchInput ? `${matchInput[1]}@${matchInput[3]}` : OriginModel

    // Model validation
    if (!config.isValidModel(inputModel)) {
      return new Response(
        JSON.stringify({
          error: {
            message: `Model '${OriginModel}' does not exist`,
            type: 'invalid_request_error',
            param: 'model',
            code: 'model_not_found',
          },
        }),
        {
          status: 404,
          headers: {
            'Content-Type': 'application/json',
          },
        },
      )
    }
    console.log(inputModel, messages, stream)
    // Split the conversation into system rules and user/assistant text.
    const { rules, message: content } = await messagesProcess(messages)
    console.log(rules, content)
    return await GrpcToPieces(inputModel, OriginModel, content, rules, stream, temperature, top_p)
  } catch (err) {
    return error(500, err.message)
  }
}
// Entry point: start a Node.js HTTP server unless a global addEventListener
// function exists (the original comment associates that case with Cloudflare Workers).
;(async () => {
  const hasGlobalEventListener = typeof addEventListener === 'function'
  if (!hasGlobalEventListener) {
    // Node.js: adapt the router's fetch handler to the http module.
    const ittyServer = createServerAdapter(router.fetch)
    console.log(`Listening on http://localhost:${config.PORT}`)
    createServer(ittyServer).listen(config.PORT)
  }
})()