smgc committed
Commit 8391cb0 · verified · 1 Parent(s): c320fe9

Create app.py

Files changed (1)
  1. app.py +269 -0
app.py ADDED
@@ -0,0 +1,269 @@
+ import os
+ import re
+ import time
+ import json
+ import grpc
+ from typing import List, Optional, Union
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from pydantic import BaseModel
+ from dotenv import load_dotenv
+ from grpc_tools import protoc
+
+ # Load environment variables
+ load_dotenv()
+
+ # Configuration
+ class Config:
+     def __init__(self):
+         # Empty prefix by default, so the routes below are served at /v1/... directly
+         self.API_PREFIX = os.getenv('API_PREFIX', '')
+         self.API_KEY = os.getenv('API_KEY', '')
+         # Retry settings (read from the environment; not yet wired into the gRPC calls)
+         self.MAX_RETRY_COUNT = int(os.getenv('MAX_RETRY_COUNT', 3))
+         self.RETRY_DELAY = int(os.getenv('RETRY_DELAY', 5000))
+         self.COMMON_GRPC = 'runtime-native-io-vertex-inference-grpc-service-lmuw6mcn3q-ul.a.run.app'
+         self.COMMON_PROTO = 'protos/VertexInferenceService.proto'
+         self.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app'
+         self.GPT_PROTO = 'protos/GPTInferenceService.proto'
+         self.PORT = int(os.getenv('PORT', 8787))
+         self.SUPPORTED_MODELS = [
+             "gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo",
+             "claude-3-sonnet@20240229", "claude-3-opus@20240229", "claude-3-haiku@20240307",
+             "claude-3-5-sonnet@20240620", "gemini-1.5-flash", "gemini-1.5-pro",
+             "chat-bison", "codechat-bison"
+         ]
+
+     def is_valid_model(self, model):
+         # Normalize Claude names from "claude-3-5-sonnet-20240620" to the
+         # "claude-3-5-sonnet@20240620" form used in SUPPORTED_MODELS
+         regex_input = r'^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$'
+         match_input = re.match(regex_input, model)
+         normalized_model = f"{match_input.group(1)}@{match_input.group(3)}" if match_input else model
+         return normalized_model in self.SUPPORTED_MODELS
+
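+ # Example (illustrative): is_valid_model("claude-3-5-sonnet-20240620") -> True,
+ # is_valid_model("claude-2") -> False.
+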
+ # gRPC handler
+ class GRPCHandler:
+     def __init__(self, proto_file):
+         self.proto_file = proto_file
+         self._compile_proto()
+         self._load_proto()
+
+     def _compile_proto(self):
+         # Generate the *_pb2.py / *_pb2_grpc.py modules in the working directory
+         proto_dir = os.path.dirname(self.proto_file)
+         proto_file = os.path.basename(self.proto_file)
+         protoc.main((
+             '',
+             f'-I{proto_dir}',
+             '--python_out=.',
+             '--grpc_python_out=.',
+             os.path.join(proto_dir, proto_file)
+         ))
+
+     def _load_proto(self):
+         base_name = os.path.splitext(os.path.basename(self.proto_file))[0]
+         # Request messages live in *_pb2; the stub lives in *_pb2_grpc
+         pb2_module = __import__(f'{base_name}_pb2')
+         grpc_module = __import__(f'{base_name}_pb2_grpc')
+         self.request_class = pb2_module.Request  # assumes the proto defines a top-level Request message
+         self.stub_class = getattr(grpc_module, f'{base_name}Stub')
+
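+     # For example, protos/VertexInferenceService.proto yields
+     # VertexInferenceService_pb2 / VertexInferenceService_pb2_grpc,
+     # and the stub resolves to VertexInferenceServiceStub.
+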
+     async def grpc_to_pieces(self, model, content, rules, temperature, top_p):
+         # GPT models go to the dedicated GPT backend; everything else to the common one
+         channel = grpc.aio.secure_channel(
+             config.GPT_GRPC if model.startswith('gpt') else config.COMMON_GRPC,
+             grpc.ssl_channel_credentials()
+         )
+         stub = self.stub_class(channel)
+
+         try:
+             request = self._build_request(model, content, rules, temperature, top_p)
+             response = await stub.Predict(request)
+             return self._process_response(response, model)
+         except grpc.RpcError as e:
+             print(f"RPC failed: {e}")
+             return {"error": str(e)}
+         finally:
+             await channel.close()
+
+     async def grpc_to_pieces_stream(self, model, content, rules, temperature, top_p):
+         channel = grpc.aio.secure_channel(
+             config.GPT_GRPC if model.startswith('gpt') else config.COMMON_GRPC,
+             grpc.ssl_channel_credentials()
+         )
+         stub = self.stub_class(channel)
+
+         try:
+             request = self._build_request(model, content, rules, temperature, top_p)
+             async for response in stub.PredictWithStream(request):
+                 result = self._process_stream_response(response, model)
+                 if result:
+                     yield f"data: {json.dumps(result)}\n\n"
+         except grpc.RpcError as e:
+             print(f"Stream RPC failed: {e}")
+             yield f"data: {json.dumps({'error': str(e)})}\n\n"
+         finally:
+             await channel.close()
+
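+     # The two backends expect different payload shapes: the GPT service takes an
+     # OpenAI-style message list, while the Vertex service takes a nested `args`
+     # payload. Field names below mirror the bundled .proto definitions.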
+     def _build_request(self, model, content, rules, temperature, top_p):
+         if model.startswith('gpt'):
+             return self.request_class(
+                 models=model,
+                 messages=[
+                     {"role": 0, "message": rules},
+                     {"role": 1, "message": content}
+                 ],
+                 temperature=temperature if temperature is not None else 0.1,
+                 top_p=top_p if top_p is not None else 1.0
+             )
+         else:
+             return self.request_class(
+                 models=model,
+                 args={
+                     "messages": {
+                         "unknown": 1,
+                         "message": content
+                     },
+                     "rules": rules
+                 }
+             )
+
+     def _process_response(self, response, model):
+         if response.response_code == 200:
+             if model.startswith('gpt'):
+                 message = response.body.message_warpper.message.message  # "warpper" is the proto's spelling
+             else:
+                 message = response.args.args.args.message
+             return chat_completion_with_model(message, model)
+         return {"error": f"Invalid response code: {response.response_code}"}
+
+     def _process_stream_response(self, response, model):
+         # 204 marks end-of-stream; 200 carries a delta
+         if response.response_code == 204:
+             return None
+         elif response.response_code == 200:
+             if model.startswith('gpt'):
+                 message = response.body.message_warpper.message.message
+             else:
+                 message = response.args.args.args.message
+             return chat_completion_stream_with_model(message, model)
+         else:
+             return {"error": f"Invalid response code: {response.response_code}"}
+
+ # Utility functions
+ def messages_process(messages):
+     # Flatten OpenAI-style messages into a "rules" string (system prompts)
+     # and a single dialogue string (user/assistant turns)
+     rules = ''
+     message = ''
+
+     for msg in messages:
+         role = msg.role
+         content = msg.content
+
+         if isinstance(content, list):
+             content = ''.join([item.get('text', '') for item in content if item.get('text')])
+
+         if role == 'system':
+             rules += f"system:{content};\r\n"
+         elif role in ['user', 'assistant']:
+             message += f"{role}:{content};\r\n"
+
+     return rules, message
+
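+ # Example (illustrative): [system "Be terse", user "Hi"] becomes
+ #   rules   = "system:Be terse;\r\n"
+ #   message = "user:Hi;\r\n"
+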
+ def chat_completion_with_model(message: str, model: str):
+     # Wrap a backend reply in an OpenAI-compatible chat.completion object
+     # (the backend reports no token counts, so usage is zeroed)
+     return {
+         "id": "Chat-Nekohy",
+         "object": "chat.completion",
+         "created": int(time.time()),
+         "model": model,
+         "usage": {
+             "prompt_tokens": 0,
+             "completion_tokens": 0,
+             "total_tokens": 0,
+         },
+         "choices": [
+             {
+                 "message": {
+                     "content": message,
+                     "role": "assistant",
+                 },
+                 "index": 0,
+             },
+         ],
+     }
+
+ def chat_completion_stream_with_model(text: str, model: str):
+     return {
+         "id": "chatcmpl-Nekohy",
+         "object": "chat.completion.chunk",
+         "created": int(time.time()),
+         "model": model,
+         "choices": [
+             {
+                 "index": 0,
+                 "delta": {
+                     "content": text,
+                 },
+                 "finish_reason": None,
+             },
+         ],
+     }
+
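+ # Each chunk above is serialized by grpc_to_pieces_stream as one SSE event, e.g.
+ #   data: {"id": "chatcmpl-Nekohy", "object": "chat.completion.chunk", ...}
+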
+ # Initialize configuration
+ config = Config()
+
+ # Initialize the FastAPI app
+ app = FastAPI()
+
+ # Request models
+ class ChatMessage(BaseModel):
+     role: str
+     # Content may be a plain string or a list of content parts
+     content: Union[str, List[dict]]
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatMessage]
+     stream: Optional[bool] = False
+     temperature: Optional[float] = None
+     top_p: Optional[float] = None
+
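+ # Example (illustrative) body for POST /v1/chat/completions:
+ #   {"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello"}], "stream": false}
+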
+ # Routes
+ @app.get("/")
+ async def root():
+     return {"message": "API service is running~"}
+
+ @app.get("/ping")
+ async def ping():
+     return {"message": "pong"}
+
+ @app.get(config.API_PREFIX + "/v1/models")
+ async def list_models():
+     # cloud_model.json is expected alongside app.py and lists the available models
+     with open('cloud_model.json', 'r') as f:
+         cloud_models = json.load(f)
+
+     models = [
+         {"id": model["unique"], "object": "model", "owned_by": "pieces-os"}
+         for model in cloud_models["iterable"]
+     ]
+
+     return JSONResponse({
+         "object": "list",
+         "data": models
+     })
+
+ @app.post(config.API_PREFIX + "/v1/chat/completions")
+ async def chat_completions(request: ChatCompletionRequest):
+     if not config.is_valid_model(request.model):
+         raise HTTPException(status_code=404, detail=f"Model '{request.model}' does not exist")
+
+     rules, content = messages_process(request.messages)
+
+     # Pick the proto matching the backend that will serve this model
+     grpc_handler = GRPCHandler(config.GPT_PROTO if request.model.startswith('gpt') else config.COMMON_PROTO)
+
+     if request.stream:
+         return StreamingResponse(
+             grpc_handler.grpc_to_pieces_stream(
+                 request.model, content, rules, request.temperature, request.top_p
+             ),
+             media_type="text/event-stream"
+         )
+     else:
+         response = await grpc_handler.grpc_to_pieces(
+             request.model, content, rules, request.temperature, request.top_p
+         )
+         return JSONResponse(content=response)
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=config.PORT)
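+
+ # Example request, assuming the default port and an empty API_PREFIX:
+ #   curl http://localhost:8787/v1/chat/completions \
+ #     -H "Content-Type: application/json" \
+ #     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello"}]}'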