smgc committed on
Commit
408b814
·
verified ·
1 Parent(s): 6303aa3

Create api/index.js

Browse files
Files changed (1) hide show
  1. api/index.js +299 -0
api/index.js ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import grpc from '@grpc/grpc-js';
2
+ import protoLoader from '@grpc/proto-loader';
3
+ import {AutoRouter, cors, error, json} from 'itty-router';
4
+ import dotenv from 'dotenv';
5
+ import path,{ dirname } from 'path';
6
+ import { fileURLToPath } from 'url';
7
+ import {createServerAdapter} from '@whatwg-node/server';
8
+ import {createServer} from 'http';
9
+
10
// Load environment variables from a .env file into process.env.
dotenv.config();
// Resolve this module's directory (ESM has no built-in __dirname).
const __dirname = dirname(fileURLToPath(import.meta.url));
14
// Runtime configuration, sourced from environment variables with defaults.
class Config {
  constructor() {
    // Route prefix under which the OpenAI-compatible endpoints are mounted.
    this.API_PREFIX = process.env.API_PREFIX || '/';
    // Optional bearer token; when empty, authentication is disabled.
    this.API_KEY = process.env.API_KEY || '';
    // Environment variables are always strings; coerce the numeric settings
    // so the retry loop, setTimeout and listen() receive real numbers.
    // A non-numeric or unset value falls back to the default.
    this.MAX_RETRY_COUNT = Number(process.env.MAX_RETRY_COUNT) || 3;
    this.RETRY_DELAY = Number(process.env.RETRY_DELAY) || 5000;
    // Upstream gRPC endpoints and their .proto definitions.
    this.COMMON_GRPC = 'runtime-native-io-vertex-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
    this.COMMON_PROTO = path.join(__dirname, '..', 'protos', 'VertexInferenceService.proto');
    this.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
    this.GPT_PROTO = path.join(__dirname, '..', 'protos', 'GPTInferenceService.proto');
    this.PORT = Number(process.env.PORT) || 8787;
  }
}
28
// Thin wrapper that parses a .proto file into a package definition
// consumable by grpc.loadPackageDefinition().
class GRPCHandler {
  constructor(protoFilePath) {
    const loaderOptions = {
      keepCase: true,
      longs: String,
      enums: String,
      defaults: true,
      oneofs: true,
    };
    // Synchronously load the .proto file given at construction time.
    this.packageDefinition = protoLoader.loadSync(protoFilePath, loaderOptions);
  }
}
40
const config = new Config();
// Middleware
// CORS: allow any origin/method and expose all headers so the API is
// callable from browsers.
const { preflight, corsify } = cors({
  origin: '*',
  allowMethods: '*',
  exposeHeaders: '*',
});
48
+
49
// Bearer-token authentication middleware. A falsy API_KEY disables auth.
// Returns an error Response to short-circuit the router, or undefined
// to let the request continue.
const withAuth = (request) => {
  if (!config.API_KEY) return;
  const authHeader = request.headers.get('Authorization');
  if (!authHeader || !authHeader.startsWith('Bearer ')) {
    return error(401, 'Unauthorized: Missing or invalid Authorization header');
  }
  if (authHeader.substring(7) !== config.API_KEY) {
    return error(403, 'Forbidden: Invalid API key');
  }
};
62
// Response logging middleware: method, status, URL and elapsed time in ms.
const logger = (res, req) => {
  const elapsed = Date.now() - req.start;
  console.log(req.method, res.status, req.url, elapsed, 'ms');
};
66
// Router: preflight + auth run before every handler; CORS headers and
// request logging are applied to every response.
const router = AutoRouter({
  before: [preflight, withAuth],
  missing: () => error(404, '404 not found.'),
  finally: [corsify, logger],
});

// Model ids advertised through the OpenAI-compatible /v1/models endpoint.
const MODEL_IDS = [
  'gpt-4o-mini',
  'gpt-4o',
  'gpt-4-turbo',
  'gpt-4',
  'gpt-3.5-turbo',
  'claude-3-sonnet@20240229',
  'claude-3-opus@20240229',
  'claude-3-haiku@20240307',
  'claude-3-5-sonnet@20240620',
  'gemini-1.5-flash',
  'gemini-1.5-pro',
  'chat-bison',
  'codechat-bison',
];

// Routes
router.get('/', () => json({ message: 'API 服务运行中~' }));
router.get('/ping', () => json({ message: 'pong' }));
router.get(config.API_PREFIX + '/v1/models', () =>
  json({
    object: 'list',
    data: MODEL_IDS.map((id) => ({ id, object: 'model', owned_by: 'pieces-os' })),
  })
);
router.post(config.API_PREFIX + '/v1/chat/completions', (req) => handleCompletion(req));
95
+
96
// Build and fire a gRPC inference request against the Pieces OS backends,
// choosing the GPT service for model ids containing "gpt" and the Vertex
// service otherwise. Returns the adapted Response from ConvertOpenai.
// NOTE(review): temperature/top_p are only wired into the GPT request shape;
// the Vertex request carries no field for them.
async function GrpcToPieces(models, message, rules, stream, temperature, top_p) {
  // Use the system root certificates for TLS.
  const credentials = grpc.credentials.createSsl();
  let client,request;
  if (models.includes('gpt')){
    // Load the GPT .proto definition.
    const packageDefinition = new GRPCHandler(config.GPT_PROTO).packageDefinition;
    // Build the request message (role 0 = system, role 1 = user).
    request = {
      models: models,
      messages: [
        {role: 0, message: rules}, // system
        {role: 1, message: message} // user
      ],
      temperature:temperature || 0.1,
      top_p:top_p ?? 1,
    }
    // Resolve the generated service constructor from the loaded package.
    const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.gpt;
    client = new GRPCobjects.GPTInferenceService(config.GPT_GRPC, credentials);
  } else {
    // Load the Vertex .proto definition.
    const packageDefinition = new GRPCHandler(config.COMMON_PROTO).packageDefinition;
    // Build the request message. `unknown: 1` is presumably a required proto
    // field — verify against VertexInferenceService.proto.
    request = {
      models: models,
      args: {
        messages: {
          unknown: 1,
          message: message
        },
        rules: rules
      }
    };
    // Resolve the generated service constructor from the loaded package.
    const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.vertex;
    client = new GRPCobjects.VertexInferenceService(config.COMMON_GRPC, credentials);
  }
  return await ConvertOpenai(client,request,models,stream);
}
137
+
138
// Flatten an OpenAI-style messages array into two strings:
// `rules`   — concatenated system prompts ("system:<text>;\r\n" per entry)
// `message` — concatenated user/assistant turns ("<role>:<text>;\r\n" per entry).
// Array-valued content (multi-part messages) is reduced to its text parts.
async function messagesProcess(messages) {
  let rules = '';
  let message = '';

  for (const { role, content } of messages) {
    // Normalize the content to a single string.
    let text;
    if (Array.isArray(content)) {
      const parts = [];
      for (const part of content) {
        if (part.text) parts.push(part.text);
      }
      text = parts.join('') || '';
    } else {
      text = content;
    }
    // Route by role; any other role is silently dropped.
    if (role === 'system') {
      rules += `system:${text};\r\n`;
    } else if (role === 'user' || role === 'assistant') {
      message += `${role}:${text};\r\n`;
    }
  }

  return { rules, message };
}
161
+
162
// Invoke the gRPC client and adapt the reply into OpenAI response format.
// Retries up to MAX_RETRY_COUNT times, waiting RETRY_DELAY ms between
// attempts. Returns an SSE Response when `stream` is truthy, a JSON
// Response otherwise, or an itty-router error(500) after the last failure.
async function ConvertOpenai(client, request, model, stream) {
  // Original returned `error(500, err.message)` after the loop, but `err`
  // only exists inside the catch block -> ReferenceError. Track it here.
  let lastError = new Error('Max retry count exceeded');
  for (let i = 0; i < config.MAX_RETRY_COUNT; i++) {
    try {
      if (stream) {
        const call = client.PredictWithStream(request);
        const encoder = new TextEncoder();
        const ReturnStream = new ReadableStream({
          start(controller) {
            call.on('data', (response) => {
              const response_code = Number(response.response_code);
              if (response_code === 204) {
                // 204 marks end-of-stream: close it and drop the call.
                controller.close();
                call.destroy();
              } else if (response_code === 200) {
                // The two proto families nest the text differently.
                let response_message;
                if (model.includes('gpt')) {
                  response_message = response.body.message_warpper.message.message;
                } else {
                  response_message = response.args.args.args.message;
                }
                // Forward the chunk as an OpenAI-style SSE event.
                controller.enqueue(encoder.encode(`data: ${JSON.stringify(ChatCompletionStreamWithModel(response_message, model))}\n\n`));
              } else {
                // Fail the stream. Do NOT also close(): calling close()
                // after error() throws on an already-errored controller.
                controller.error(new Error(`Error: stream chunk is not success`));
                call.destroy();
              }
            });
            // Surface transport errors instead of leaving the stream hanging.
            call.on('error', (err) => {
              controller.error(err);
            });
          },
        });
        return new Response(ReturnStream, {
          headers: {
            'Content-Type': 'text/event-stream',
          },
        });
      } else {
        // Promisify the unary callback API.
        const call = await new Promise((resolve, reject) => {
          client.Predict(request, (err, response) => {
            if (err) reject(err);
            else resolve(response);
          });
        });
        const response_code = Number(call.response_code);
        if (response_code === 200) {
          let response_message;
          if (model.includes('gpt')) {
            response_message = call.body.message_warpper.message.message;
          } else {
            response_message = call.args.args.args.message;
          }
          return new Response(JSON.stringify(ChatCompletionWithModel(response_message, model)), {
            headers: {
              'Content-Type': 'application/json',
            },
          });
        }
        // Non-200 unary replies previously fell through and retried silently;
        // throw so the failure is logged and reported after the last attempt.
        throw new Error(`Error: response code ${response_code}`);
      }
    } catch (err) {
      lastError = err;
      console.error(err);
      await new Promise((resolve) => setTimeout(resolve, config.RETRY_DELAY));
    }
  }
  return error(500, lastError.message);
}
226
+
227
// Shape a complete (non-streaming) reply as an OpenAI chat.completion object.
function ChatCompletionWithModel(message, model) {
  const choice = {
    message: {
      content: message,
      role: 'assistant',
    },
    index: 0,
  };
  return {
    id: 'Chat-Nekohy',
    object: 'chat.completion',
    created: Date.now(),
    model,
    // Token accounting is not available from the upstream service.
    usage: {
      prompt_tokens: 0,
      completion_tokens: 0,
      total_tokens: 0,
    },
    choices: [choice],
  };
}
249
+
250
// Shape one streamed delta as an OpenAI chat.completion.chunk object.
function ChatCompletionStreamWithModel(text, model) {
  const delta = { content: text };
  return {
    id: 'chatcmpl-Nekohy',
    object: 'chat.completion.chunk',
    created: 0,
    model,
    choices: [
      {
        index: 0,
        delta,
        finish_reason: null,
      },
    ],
  };
}
267
+
268
// Entry point for POST /v1/chat/completions: parse the OpenAI-format body,
// flatten the messages, and delegate to the gRPC bridge. Any failure is
// converted into an error(500) Response.
async function handleCompletion(request) {
  try {
    const body = await request.json();
    const { model: inputModel, messages, stream, temperature, top_p } = body;
    console.log(inputModel, messages, stream);
    // Split system prompts (rules) from the user/assistant transcript.
    const { rules, message: content } = await messagesProcess(messages);
    console.log(rules, content);
    return await GrpcToPieces(inputModel, content, rules, stream, temperature, top_p);
  } catch (err) {
    return error(500, err.message);
  }
}
283
+
284
// Start a local HTTP server. Skipped when a global addEventListener exists,
// which presumably indicates a worker/edge runtime that drives router.fetch
// itself — TODO confirm against the deployment target.
(async () => {
  if (typeof addEventListener === 'function') return;

  // Adapt the WHATWG fetch-style router to Node's request/response model.
  const ittyServer = createServerAdapter(router.fetch);
  const httpServer = createServer(ittyServer);

  // Surface server-level errors instead of crashing silently.
  httpServer.on('error', (error) => {
    console.error('Server error:', error);
  });

  // Bind to all interfaces so the server is reachable inside containers.
  httpServer.listen(config.PORT, '0.0.0.0', () => {
    console.log(`Server is running on http://0.0.0.0:${config.PORT}`);
  });
})();