smgc committed on
Commit
5f466a2
·
verified ·
1 Parent(s): ae00909

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -478
app.py DELETED
@@ -1,478 +0,0 @@
1
import asyncio
import json
import os
import time
from typing import List, Optional

import grpc
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from google.protobuf import descriptor
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool
from google.protobuf import symbol_database
from google.protobuf.compiler import plugin_pb2
from google.protobuf.json_format import MessageToDict
from pydantic import BaseModel
17
-
18
# Load environment variables from a local .env file, if present.
load_dotenv()
20
-
21
# Embedded proto definition for the GPT inference service.
# NOTE(review): this text is reference documentation only — it is never
# compiled; the descriptors are built programmatically in
# generate_proto_classes() below.
GPT_PROTO = """
syntax = "proto3";
package runtime.aot.machine_learning.parents.gpt;

service GPTInferenceService {
  rpc Predict (Request) returns (Response);
  rpc PredictWithStream (Request) returns (stream Response);
}

message Message {
  uint64 role = 1;
  string message = 2;
}

message Request {
  string models = 1;
  repeated Message messages = 2;
  double temperature = 3;
  double top_p = 4;
}

message Response {
  uint64 response_code = 2;
  optional Body body = 4;
}

message Body {
  string id = 1;
  string object = 2;
  uint64 time = 3;
  MessageWarpper message_warpper = 4;
  Unknown unknown = 5;
}

message MessageWarpper {
  int64 arg1 = 1;
  Message message = 2;
}

message Unknown {
  int64 arg1 = 1;
  int64 arg2 = 2;
  int64 arg3 = 3;
}
"""
67
-
68
# Embedded proto definition for the Vertex inference service.
# NOTE(review): like GPT_PROTO, this string is never compiled — and unlike
# the GPT side, no descriptors for it are built in generate_proto_classes().
VERTEX_PROTO = """
syntax = "proto3";
package runtime.aot.machine_learning.parents.vertex;

service VertexInferenceService {
  rpc Predict (Requests) returns (Response);
  rpc PredictWithStream (Requests) returns (stream Response);
}

message Messages {
  int64 unknown = 1;
  string message = 2;
}

message Requests {
  string models = 1;
  Args args = 2;
}

message Args {
  Messages messages = 2;
  string rules = 3;
}

message Response {
  int64 response_code = 2;
  Args1 args = 4;
}

message Args1 {
  Args2 args = 1;
}

message Args2 {
  Messages args = 2;
}
"""
105
-
106
# Application configuration, read once at import time.
class Config:
    """Settings sourced from environment variables with safe defaults."""

    # URL prefix read from the environment (note: the router below uses a
    # separate hard-coded prefix).
    API_PREFIX = os.getenv('API_PREFIX', '/')
    # Optional bearer token; empty string disables authentication.
    API_KEY = os.getenv('API_KEY', '')
    # Defaults are given as strings so int() always receives a str,
    # matching what os.getenv returns when the variable IS set.
    MAX_RETRY_COUNT = int(os.getenv('MAX_RETRY_COUNT', '3'))
    # Delay between upstream retries, in milliseconds.
    RETRY_DELAY = int(os.getenv('RETRY_DELAY', '5000'))
    # Hard-coded upstream gRPC hosts.
    COMMON_GRPC = 'runtime-native-io-vertex-inference-grpc-service-lmuw6mcn3q-ul.a.run.app'
    GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app'
    # TCP port for the uvicorn server.
    PORT = int(os.getenv('PORT', '8787'))

config = Config()
117
-
118
# Dynamically build the proto descriptors.
def generate_proto_classes():
    """Programmatically build a DescriptorPool for the GPT service.

    Mirrors the message/service layout documented in GPT_PROTO above and
    registers a single file descriptor ("gpt_service.proto") in a fresh
    pool. Returns the populated DescriptorPool.
    """
    pool = descriptor_pool.DescriptorPool()

    # File descriptor for the GPT service.
    gpt_file = descriptor_pb2.FileDescriptorProto()
    gpt_file.name = "gpt_service.proto"
    gpt_file.package = "runtime.aot.machine_learning.parents.gpt"
    gpt_file.syntax = "proto3"

    # Message: { role: uint64, message: string }.
    message = gpt_file.message_type.add()
    message.name = "Message"
    field = message.field.add()
    field.name = "role"
    field.number = 1
    field.type = descriptor.FieldDescriptor.TYPE_UINT64
    field = message.field.add()
    field.name = "message"
    field.number = 2
    field.type = descriptor.FieldDescriptor.TYPE_STRING

    # Request: model id, repeated messages, sampling parameters.
    request = gpt_file.message_type.add()
    request.name = "Request"
    field = request.field.add()
    field.name = "models"
    field.number = 1
    field.type = descriptor.FieldDescriptor.TYPE_STRING
    field = request.field.add()
    field.name = "messages"
    field.number = 2
    field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
    field.type_name = ".runtime.aot.machine_learning.parents.gpt.Message"
    field.label = descriptor.FieldDescriptor.LABEL_REPEATED
    field = request.field.add()
    field.name = "temperature"
    field.number = 3
    field.type = descriptor.FieldDescriptor.TYPE_DOUBLE
    field = request.field.add()
    field.name = "top_p"
    field.number = 4
    field.type = descriptor.FieldDescriptor.TYPE_DOUBLE

    # Response: status code plus optional Body payload.
    response = gpt_file.message_type.add()
    response.name = "Response"
    field = response.field.add()
    field.name = "response_code"
    field.number = 2
    field.type = descriptor.FieldDescriptor.TYPE_UINT64
    field = response.field.add()
    field.name = "body"
    field.number = 4
    field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
    field.type_name = ".runtime.aot.machine_learning.parents.gpt.Body"
    field.label = descriptor.FieldDescriptor.LABEL_OPTIONAL

    # Body envelope.
    # NOTE(review): GPT_PROTO also declares `Unknown unknown = 5` on Body,
    # which is not reproduced here — confirm whether that field is needed.
    body = gpt_file.message_type.add()
    body.name = "Body"
    field = body.field.add()
    field.name = "id"
    field.number = 1
    field.type = descriptor.FieldDescriptor.TYPE_STRING
    field = body.field.add()
    field.name = "object"
    field.number = 2
    field.type = descriptor.FieldDescriptor.TYPE_STRING
    field = body.field.add()
    field.name = "time"
    field.number = 3
    field.type = descriptor.FieldDescriptor.TYPE_UINT64
    field = body.field.add()
    field.name = "message_warpper"
    field.number = 4
    field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
    field.type_name = ".runtime.aot.machine_learning.parents.gpt.MessageWarpper"

    # MessageWarpper wrapper ("warpper" spelling kept — it is the wire name).
    message_wrapper = gpt_file.message_type.add()
    message_wrapper.name = "MessageWarpper"
    field = message_wrapper.field.add()
    field.name = "arg1"
    field.number = 1
    field.type = descriptor.FieldDescriptor.TYPE_INT64
    field = message_wrapper.field.add()
    field.name = "message"
    field.number = 2
    field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
    field.type_name = ".runtime.aot.machine_learning.parents.gpt.Message"

    # GPT service: unary Predict plus server-streaming PredictWithStream.
    service = gpt_file.service.add()
    service.name = "GPTInferenceService"
    method = service.method.add()
    method.name = "Predict"
    method.input_type = ".runtime.aot.machine_learning.parents.gpt.Request"
    method.output_type = ".runtime.aot.machine_learning.parents.gpt.Response"
    method = service.method.add()
    method.name = "PredictWithStream"
    method.input_type = ".runtime.aot.machine_learning.parents.gpt.Request"
    method.output_type = ".runtime.aot.machine_learning.parents.gpt.Response"
    method.server_streaming = True

    # Register the file descriptor with the pool.
    pool.Add(gpt_file)

    # NOTE(review): the Vertex service descriptors (see VERTEX_PROTO) are
    # intentionally not built here per the original comments; the Vertex
    # request path therefore has no descriptors in this pool.

    return pool
230
-
231
# Build the descriptor pool once at import time.
proto_pool = generate_proto_classes()

# FastAPI application instance.
app = FastAPI()

# CORS middleware — fully permissive: any origin, method, and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prefix for the chat-completions router.
# NOTE(review): hard-coded to "/ai" and ignores config.API_PREFIX, which the
# /v1/models route defined later does use — confirm which prefix is intended.
API_PREFIX = "/ai"

# Router carrying the /v1/chat/completions endpoint.
router = APIRouter(prefix=API_PREFIX)
251
-
252
# Authentication dependency.
def verify_api_key(authorization: str = None):
    """Validate a Bearer token against the configured API key.

    No-op when no API key is configured. Raises HTTPException 401 for a
    missing or malformed Authorization header and 403 for a wrong token.
    """
    if not config.API_KEY:
        return
    # Guard: header must exist and carry the Bearer scheme.
    if not authorization or not authorization.startswith('Bearer '):
        raise HTTPException(status_code=401, detail='Unauthorized: Missing or invalid Authorization header')
    supplied = authorization.split(' ')[1]
    if supplied != config.API_KEY:
        raise HTTPException(status_code=403, detail='Forbidden: Invalid API key')
260
-
261
# Model catalogue advertised by GET {API_PREFIX}/v1/models. Entries mimic
# the OpenAI model-list shape; all are attributed to "pieces-os".
MODELS = [
    {"id": "gpt-4o-mini", "object": "model", "owned_by": "pieces-os"},
    {"id": "gpt-4o", "object": "model", "owned_by": "pieces-os"},
    {"id": "gpt-4-turbo", "object": "model", "owned_by": "pieces-os"},
    {"id": "gpt-4", "object": "model", "owned_by": "pieces-os"},
    {"id": "gpt-3.5-turbo", "object": "model", "owned_by": "pieces-os"},
    {"id": "claude-3-sonnet@20240229", "object": "model", "owned_by": "pieces-os"},
    {"id": "claude-3-opus@20240229", "object": "model", "owned_by": "pieces-os"},
    {"id": "claude-3-haiku@20240307", "object": "model", "owned_by": "pieces-os"},
    {"id": "claude-3-5-sonnet@20240620", "object": "model", "owned_by": "pieces-os"},
    {"id": "gemini-1.5-flash", "object": "model", "owned_by": "pieces-os"},
    {"id": "gemini-1.5-pro", "object": "model", "owned_by": "pieces-os"},
    {"id": "chat-bison", "object": "model", "owned_by": "pieces-os"},
    {"id": "codechat-bison", "object": "model", "owned_by": "pieces-os"},
]
277
-
278
# Basic API routes.
@app.get("/")
async def root():
    """Landing route confirming the service is running."""
    return {"message": "API 服务运行中~"}

@app.get("/ping")
async def ping():
    """Simple health-check endpoint."""
    return {"message": "pong"}

@app.get(f"{config.API_PREFIX}/v1/models")
async def get_models():
    """OpenAI-compatible model listing from the static MODELS table."""
    return {"object": "list", "data": MODELS}
290
-
291
# Request schemas (OpenAI chat-completions shape).
class Message(BaseModel):
    # Chat role, e.g. "system", "user", or "assistant".
    role: str
    # Message text.
    content: str

class ChatCompletionRequest(BaseModel):
    # Target model id (see MODELS for the advertised list).
    model: str
    messages: List[Message]
    # When true, reply as a server-sent-event stream.
    stream: Optional[bool] = False
    temperature: Optional[float] = 0.1
    top_p: Optional[float] = 1.0
302
-
303
@router.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat-completions endpoint.

    Splits the incoming messages into system rules and conversation
    content, then forwards both to the upstream gRPC service.
    """
    try:
        rules, content = process_messages(request.messages)
        return await grpc_to_pieces(request.model, content, rules, request.stream, request.temperature, request.top_p)
    except HTTPException:
        # Preserve deliberate HTTP errors raised downstream instead of
        # re-wrapping them as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

app.include_router(router)
312
-
313
def process_messages(messages):
    """Split chat messages into system rules and conversation content.

    Each message contributes a "role:content;\r\n" line: system messages
    accumulate into the rules string, user/assistant messages into the
    content string, and any other role is silently dropped. Returns the
    (rules, content) pair.
    """
    rule_lines = []
    content_lines = []
    for entry in messages:
        line = f"{entry.role}:{entry.content};\r\n"
        if entry.role == 'system':
            rule_lines.append(line)
        elif entry.role in ('user', 'assistant'):
            content_lines.append(line)
    return ''.join(rule_lines), ''.join(content_lines)
322
-
323
async def grpc_to_pieces(model, message, rules, stream, temperature, top_p):
    """Forward one chat request to the upstream gRPC service with retries.

    Chooses the GPT endpoint when "gpt" appears in the model id, otherwise
    the Vertex endpoint, builds the matching request dict, and delegates to
    stream_response/single_response. Retries up to MAX_RETRY_COUNT times on
    any exception, sleeping RETRY_DELAY ms between attempts; raises
    HTTPException(500) when all attempts fail. The channel is always closed
    on exit.
    """
    credentials = grpc.ssl_channel_credentials()

    try:
        if 'gpt' in model:
            channel = grpc.secure_channel(config.GPT_GRPC, credentials)
            stub = GPTInferenceServiceStub(channel)

            # GPT request: role 0 carries the system rules, role 1 the
            # conversation content.
            request = {
                'models': model,
                'messages': [
                    {'role': 0, 'message': rules},
                    {'role': 1, 'message': message}
                ],
                'temperature': temperature,
                'top_p': top_p
            }
        else:
            channel = grpc.secure_channel(config.COMMON_GRPC, credentials)
            stub = VertexInferenceServiceStub(channel)

            # Vertex request: rules travel separately from the message.
            request = {
                'models': model,
                'args': {
                    'messages': {'unknown': 1, 'message': message},
                    'rules': rules
                }
            }

        # NOTE(review): this retry loop also retries on HTTPExceptions
        # raised by single_response for non-200 upstream codes — confirm
        # that is intended rather than failing fast.
        for _ in range(config.MAX_RETRY_COUNT):
            try:
                if stream:
                    return await stream_response(stub, request, model)
                else:
                    return await single_response(stub, request, model)
            except Exception as e:
                print(f"Error: {e}")
                await asyncio.sleep(config.RETRY_DELAY / 1000)
                continue

        raise HTTPException(status_code=500, detail="Max retry count reached")

    finally:
        channel.close()
369
-
370
async def stream_response(stub, request, model):
    """Return a StreamingResponse that relays upstream chunks as SSE.

    Iterates the server-streaming Predict call, translating each 200 chunk
    into an OpenAI "chat.completion.chunk" SSE event; a 204 code ends the
    stream. Always finishes with a "data: [DONE]" event.
    """
    async def generate():
        try:
            responses = stub.PredictWithStream(request)
            # NOTE(review): synchronous iteration over the gRPC stream
            # inside an async generator will block the event loop while
            # waiting for chunks — confirm this is acceptable.
            for response in responses:
                response_code = response.response_code
                if response_code == 204:
                    # Upstream signals end-of-stream.
                    break
                elif response_code == 200:
                    # The reply envelope differs between GPT and Vertex.
                    if 'gpt' in model:
                        message = response.body.message_warpper.message.message
                    else:
                        message = response.args.args.args.message

                    chunk = {
                        "id": "chatcmpl-Nekohy",
                        "object": "chat.completion.chunk",
                        "created": 0,
                        "model": model,
                        "choices": [{
                            "index": 0,
                            "delta": {
                                "content": message,
                            },
                            "finish_reason": None,
                        }],
                    }
                    yield f"data: {json.dumps(chunk)}\n\n"
        except Exception as e:
            # NOTE(review): raising HTTPException after streaming has
            # started cannot change the HTTP status already sent — confirm
            # desired error behavior mid-stream.
            raise HTTPException(status_code=500, detail=str(e))

        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
411
-
412
async def single_response(stub, request, model):
    """Run one blocking Predict call off the event loop and convert the
    reply into an OpenAI-style chat.completion payload.

    Raises HTTPException(500) for any non-200 upstream response code.
    """
    loop = asyncio.get_event_loop()
    response = await loop.run_in_executor(None, stub.Predict, request)

    # Fail fast on anything but a successful upstream reply.
    if response.response_code != 200:
        raise HTTPException(
            status_code=500,
            detail=f"Error response code: {response.response_code}"
        )

    # The reply envelope differs between the GPT and Vertex services.
    if 'gpt' in model:
        message = response.body.message_warpper.message.message
    else:
        message = response.args.args.args.message

    return {
        "id": "Chat-Nekohy",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
        "choices": [{
            "message": {
                "content": message,
                "role": "assistant",
            },
            "index": 0,
        }],
    }
446
-
447
# Hand-written gRPC stub wrapper for the GPT service.
class GPTInferenceServiceStub:
    # NOTE(review): `grpc.ProtoRPC` is not a documented attribute of the
    # grpc Python package — this constructor looks like it would raise
    # AttributeError at runtime; confirm against the grpc version in use.
    def __init__(self, channel):
        # Underlying secure channel to the GPT endpoint.
        self.channel = channel
        self.stub = grpc.ProtoRPC(channel)

    def Predict(self, request):
        """Unary predict call, delegated to the inner stub."""
        return self.stub.Predict(request)

    def PredictWithStream(self, request):
        """Server-streaming predict call, delegated to the inner stub."""
        return self.stub.PredictWithStream(request)
458
-
459
# Hand-written gRPC stub wrapper for the Vertex service.
class VertexInferenceServiceStub:
    # NOTE(review): same concern as GPTInferenceServiceStub — grpc exposes
    # no public `ProtoRPC`; confirm how these stubs are meant to be built.
    def __init__(self, channel):
        # Underlying secure channel to the Vertex endpoint.
        self.channel = channel
        self.stub = grpc.ProtoRPC(channel)

    def Predict(self, request):
        """Unary predict call, delegated to the inner stub."""
        return self.stub.Predict(request)

    def PredictWithStream(self, request):
        """Server-streaming predict call, delegated to the inner stub."""
        return self.stub.PredictWithStream(request)
469
-
470
- if __name__ == "__main__":
471
- import uvicorn
472
- uvicorn.run(
473
- app,
474
- host="0.0.0.0",
475
- port=config.PORT,
476
- log_level="info"
477
- )
478
-