dfa32412 commited on
Commit
97565d3
·
verified ·
1 Parent(s): cb367ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -73
app.py CHANGED
@@ -30,6 +30,7 @@ logging.basicConfig(
30
  )
31
  logger = logging.getLogger(__name__)
32
 
 
33
  #################################################
34
  # 模型定义
35
  #################################################
@@ -41,6 +42,7 @@ class ChatMessage(BaseModel):
41
  content: Optional[str] = None
42
  name: Optional[str] = None
43
 
 
44
  class ChatCompletionRequest(BaseModel):
45
  """OpenAI聊天完成API请求模型"""
46
  model: str
@@ -54,6 +56,7 @@ class ChatCompletionRequest(BaseModel):
54
  frequency_penalty: Optional[float] = 0
55
  user: Optional[str] = None
56
 
 
57
  # OpenAI API 响应模型
58
  class ChatCompletionResponseChoice(BaseModel):
59
  """OpenAI聊天完成API响应中的单个选择"""
@@ -61,12 +64,14 @@ class ChatCompletionResponseChoice(BaseModel):
61
  message: ChatMessage
62
  finish_reason: Optional[str] = None
63
 
 
64
  class Usage(BaseModel):
65
  """OpenAI API响应中的token使用信息"""
66
  prompt_tokens: int
67
  completion_tokens: int
68
  total_tokens: int
69
 
 
70
  class ChatCompletionResponse(BaseModel):
71
  """OpenAI聊天完成API响应模型"""
72
  id: str
@@ -76,6 +81,7 @@ class ChatCompletionResponse(BaseModel):
76
  choices: List[ChatCompletionResponseChoice]
77
  usage: Usage
78
 
 
79
  # OpenAI API 流式响应模型
80
  class ChatCompletionStreamResponseChoice(BaseModel):
81
  """OpenAI聊天完成流式API响应中的单个选择"""
@@ -83,6 +89,7 @@ class ChatCompletionStreamResponseChoice(BaseModel):
83
  delta: Dict[str, Any]
84
  finish_reason: Optional[str] = None
85
 
 
86
  class ChatCompletionStreamResponse(BaseModel):
87
  """OpenAI聊天完成流式API响应模型"""
88
  id: str
@@ -91,6 +98,7 @@ class ChatCompletionStreamResponse(BaseModel):
91
  model: str
92
  choices: List[ChatCompletionStreamResponseChoice]
93
 
 
94
  # 模型信息响应
95
  class ModelInfo(BaseModel):
96
  """OpenAI模型信息"""
@@ -99,11 +107,13 @@ class ModelInfo(BaseModel):
99
  created: int
100
  owned_by: str = "augment"
101
 
 
102
  class ModelListResponse(BaseModel):
103
  """OpenAI模型列表响应"""
104
  object: str = "list"
105
  data: List[ModelInfo]
106
 
 
107
  # Augment API 请求相关模型
108
  class AugmentResponseNode(BaseModel):
109
  """Augment API响应节点"""
@@ -112,6 +122,7 @@ class AugmentResponseNode(BaseModel):
112
  content: str
113
  tool_use: Optional[Any] = None
114
 
 
115
  class AugmentChatHistoryItem(BaseModel):
116
  """Augment API聊天历史记录条目"""
117
  request_message: str
@@ -120,20 +131,24 @@ class AugmentChatHistoryItem(BaseModel):
120
  request_nodes: List[Any] = []
121
  response_nodes: List[AugmentResponseNode] = []
122
 
 
123
  class AugmentBlobs(BaseModel):
124
  """Augment API Blobs对象"""
125
  checkpoint_id: Optional[str] = None
126
  added_blobs: List[Any] = []
127
  deleted_blobs: List[Any] = []
128
 
 
129
  class AugmentVcsChange(BaseModel):
130
  """Augment API VCS更改"""
131
  working_directory_changes: List[Any] = []
132
 
 
133
  class AugmentFeatureFlags(BaseModel):
134
  """Augment API功能标志"""
135
  support_raw_output: bool = True
136
 
 
137
  # 完整的Augment API请求模型
138
  class AugmentChatRequest(BaseModel):
139
  """Augment API聊天请求模型 - 基于抓包分析更新"""
@@ -157,10 +172,11 @@ class AugmentChatRequest(BaseModel):
157
  feature_detection_flags: AugmentFeatureFlags = AugmentFeatureFlags()
158
  tool_definitions: List[Any] = []
159
  nodes: List[Any] = []
160
- mode: str = "CHAT"
161
  agent_memories: Optional[Any] = None
162
  system_prompt: Optional[str] = None # 保留此字段以兼容之前的代码
163
 
 
164
  # Augment API响应模型
165
  class AugmentResponseChunk(BaseModel):
166
  """Augment API响应块"""
@@ -171,6 +187,7 @@ class AugmentResponseChunk(BaseModel):
171
  incorporated_external_sources: List[Any] = []
172
  nodes: List[AugmentResponseNode] = []
173
 
 
174
  #################################################
175
  # 辅助函数
176
  #################################################
@@ -179,6 +196,7 @@ def generate_id():
179
  """生成唯一ID,类似于OpenAI的格式"""
180
  return str(uuid.uuid4()).replace("-", "")[:24]
181
 
 
182
  def estimate_tokens(text):
183
  """
184
  估计文本的token数量
@@ -192,31 +210,33 @@ def estimate_tokens(text):
192
  chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff') if text else 0
193
  return int(words * 1.3 + chinese_chars)
194
 
 
195
  def convert_to_augment_request(openai_request: ChatCompletionRequest) -> AugmentChatRequest:
196
  """
197
  将OpenAI API请求转换为Augment API请求
198
-
199
  Args:
200
  openai_request: OpenAI API请求对象
201
-
202
  Returns:
203
  转换后的Augment API请求对象
204
-
205
  Raises:
206
  HTTPException: 如果请求格式无效
207
  """
208
  chat_history = []
209
  system_message = None
210
-
211
  # 处理消息历史记录
212
  for i in range(len(openai_request.messages) - 1):
213
  msg = openai_request.messages[i]
214
  if msg.role == "system":
215
  system_message = msg.content
216
- elif msg.role == "user" and i + 1 < len(openai_request.messages) and openai_request.messages[i + 1].role == "assistant":
 
217
  user_msg = msg.content
218
  assistant_msg = openai_request.messages[i + 1].content
219
-
220
  # 创建历史记录条目,格式符合Augment API
221
  history_item = AugmentChatHistoryItem(
222
  request_message=user_msg,
@@ -232,34 +252,35 @@ def convert_to_augment_request(openai_request: ChatCompletionRequest) -> Augment
232
  ]
233
  )
234
  chat_history.append(history_item)
235
-
236
  # 获取当前用户消息
237
  current_message = None
238
  for msg in reversed(openai_request.messages):
239
  if msg.role == "user":
240
  current_message = msg.content
241
  break
242
-
243
  # 如果没有用户消息,则返回错误
244
  if current_message is None:
245
  raise HTTPException(
246
  status_code=400,
247
  detail="At least one user message is required"
248
  )
249
-
250
  # 准备Augment请求体
251
  augment_request = AugmentChatRequest(
252
  message=current_message,
253
  chat_history=chat_history,
254
  mode="CHAT"
255
  )
256
-
257
  # 如果有系统消息,设置为用户指南
258
  if system_message:
259
  augment_request.user_guidelines = system_message
260
-
261
  return augment_request
262
 
 
263
  #################################################
264
  # FastAPI应用
265
  #################################################
@@ -267,12 +288,12 @@ def convert_to_augment_request(openai_request: ChatCompletionRequest) -> Augment
267
  def create_app(augment_base_url, chat_endpoint, timeout):
268
  """
269
  创建并配置FastAPI应用
270
-
271
  Args:
272
  augment_base_url: Augment API基础URL
273
  chat_endpoint: 聊天端点路径
274
  timeout: 请求超时时间
275
-
276
  Returns:
277
  配置好的FastAPI应用
278
  """
@@ -317,19 +338,19 @@ def create_app(augment_base_url, chat_endpoint, timeout):
317
  async def verify_api_key(authorization: str = Header(...)):
318
  """
319
  验证API密钥
320
-
321
  Args:
322
  authorization: Authorization头部值
323
-
324
  Returns:
325
  提取的API密钥
326
-
327
  Raises:
328
  HTTPException: 如果API密钥格式无效或为空
329
  """
330
  if not authorization.startswith("Bearer "):
331
  raise HTTPException(
332
- status_code=401,
333
  detail={
334
  "error": {
335
  "message": "Invalid API key format. Expected 'Bearer YOUR_API_KEY'",
@@ -381,16 +402,16 @@ def create_app(augment_base_url, chat_endpoint, timeout):
381
 
382
  @app.post("/v1/chat/completions")
383
  async def chat_completions(
384
- request: ChatCompletionRequest,
385
- api_key: str = Depends(verify_api_key)
386
  ):
387
  """
388
  聊天完成端点 - 将OpenAI API请求转换为Augment API请求
389
-
390
  Args:
391
  request: OpenAI格式的聊天完成请求
392
  api_key: 通过验证的API密钥
393
-
394
  Returns:
395
  OpenAI格式的聊天完成响应或流式响应
396
  """
@@ -399,7 +420,6 @@ def create_app(augment_base_url, chat_endpoint, timeout):
399
  augment_request = convert_to_augment_request(request)
400
  logger.debug(f"Converted request: {augment_request.dict()}")
401
 
402
-
403
  if ":" in api_key:
404
  tenant_id, api_key = api_key.split(":")
405
  augment_base_url = f"https://{tenant_id}.api.augmentcode.com/"
@@ -407,13 +427,15 @@ def create_app(augment_base_url, chat_endpoint, timeout):
407
  # 决定是否使用流式响应
408
  if request.stream:
409
  return StreamingResponse(
410
- stream_augment_response(augment_base_url, api_key, augment_request, request.model, chat_endpoint, timeout),
 
411
  media_type="text/event-stream"
412
  )
413
  else:
414
  # 同步请求处理
415
- return await handle_sync_request(augment_base_url, api_key, augment_request, request.model, chat_endpoint, timeout)
416
-
 
417
  except httpx.TimeoutException:
418
  logger.error("Request to Augment API timed out")
419
  raise HTTPException(
@@ -459,10 +481,11 @@ def create_app(augment_base_url, chat_endpoint, timeout):
459
 
460
  return app
461
 
 
462
  async def handle_sync_request(base_url, api_key, augment_request, model_name, chat_endpoint, timeout):
463
  """
464
  处理同步请求
465
-
466
  Args:
467
  base_url: Augment API基础URL
468
  api_key: API密钥
@@ -470,7 +493,7 @@ async def handle_sync_request(base_url, api_key, augment_request, model_name, ch
470
  model_name: 模型名称
471
  chat_endpoint: 聊天端点
472
  timeout: 请求超时时间
473
-
474
  Returns:
475
  OpenAI格式的聊天完成响应
476
  """
@@ -485,7 +508,7 @@ async def handle_sync_request(base_url, api_key, augment_request, model_name, ch
485
  "Accept": "*/*"
486
  }
487
  )
488
-
489
  if response.status_code != 200:
490
  logger.error(f"Augment API error: {response.status_code} - {response.text}")
491
  raise HTTPException(
@@ -499,7 +522,7 @@ async def handle_sync_request(base_url, api_key, augment_request, model_name, ch
499
  }
500
  }
501
  )
502
-
503
  # 处理流式响应,合并为完整响应
504
  full_response = ""
505
  for line in response.text.split("\n"):
@@ -510,11 +533,11 @@ async def handle_sync_request(base_url, api_key, augment_request, model_name, ch
510
  full_response += data["text"]
511
  except json.JSONDecodeError:
512
  logger.warning(f"Failed to parse JSON: {line}")
513
-
514
  # 估算token使用情况
515
  prompt_tokens = estimate_tokens(augment_request.message)
516
  completion_tokens = estimate_tokens(full_response)
517
-
518
  # 构建OpenAI格式响应
519
  return ChatCompletionResponse(
520
  id=f"chatcmpl-{generate_id()}",
@@ -537,10 +560,11 @@ async def handle_sync_request(base_url, api_key, augment_request, model_name, ch
537
  )
538
  )
539
 
 
540
  async def stream_augment_response(base_url, api_key, augment_request, model_name, chat_endpoint, timeout):
541
  """
542
  处理流式响应
543
-
544
  Args:
545
  base_url: Augment API基础URL
546
  api_key: API密钥
@@ -548,22 +572,22 @@ async def stream_augment_response(base_url, api_key, augment_request, model_name
548
  model_name: 模型名称
549
  chat_endpoint: 聊天端点
550
  timeout: 请求超时时间
551
-
552
  Yields:
553
  流式响应的数据块
554
  """
555
  async with httpx.AsyncClient(timeout=timeout) as client:
556
  try:
557
  async with client.stream(
558
- "POST",
559
- f"{base_url.rstrip('/')}/{chat_endpoint}",
560
- json=augment_request.dict(),
561
- headers={
562
- "Content-Type": "application/json",
563
- "Authorization": f"Bearer {api_key}",
564
- "User-Agent": "chrome",
565
- "Accept": "*/*"
566
- }
567
  ) as response:
568
 
569
  if response.status_code != 200:
@@ -572,11 +596,11 @@ async def stream_augment_response(base_url, api_key, augment_request, model_name
572
  error_message = f"Error from Augment API: {error_detail.decode('utf-8', errors='replace')}"
573
  yield f"data: {json.dumps({'error': error_message})}\n\n"
574
  return
575
-
576
  # 生成唯一ID
577
  chat_id = f"chatcmpl-{generate_id()}"
578
  created_time = int(time.time())
579
-
580
  # 初始化响应
581
  init_response = ChatCompletionStreamResponse(
582
  id=chat_id,
@@ -592,19 +616,19 @@ async def stream_augment_response(base_url, api_key, augment_request, model_name
592
  )
593
  init_data = json.dumps(init_response.dict())
594
  yield f"data: {init_data}\n\n"
595
-
596
  # 处理流式响应
597
  buffer = ""
598
  async for line in response.aiter_lines():
599
  if not line.strip():
600
  continue
601
-
602
  try:
603
  # 解析Augment响应格式
604
  chunk = json.loads(line)
605
  if "text" in chunk and chunk["text"]:
606
  content = chunk["text"]
607
-
608
  # 发送增量更新
609
  stream_response = ChatCompletionStreamResponse(
610
  id=chat_id,
@@ -622,7 +646,7 @@ async def stream_augment_response(base_url, api_key, augment_request, model_name
622
  yield f"data: {response_data}\n\n"
623
  except json.JSONDecodeError:
624
  logger.warning(f"Failed to parse JSON: {line}")
625
-
626
  # 发送完成信号
627
  final_response = ChatCompletionStreamResponse(
628
  id=chat_id,
@@ -638,10 +662,10 @@ async def stream_augment_response(base_url, api_key, augment_request, model_name
638
  )
639
  final_data = json.dumps(final_response.dict())
640
  yield f"data: {final_data}\n\n"
641
-
642
  # 发送[DONE]标记
643
  yield "data: [DONE]\n\n"
644
-
645
  except httpx.TimeoutException:
646
  logger.error("Request to Augment API timed out")
647
  yield f"data: {json.dumps({'error': 'Request to Augment API timed out'})}\n\n"
@@ -652,47 +676,48 @@ async def stream_augment_response(base_url, api_key, augment_request, model_name
652
  logger.exception("Unexpected error")
653
  yield f"data: {json.dumps({'error': f'Internal server error: {str(e)}'})}\n\n"
654
 
 
655
  def parse_args():
656
  """解析命令行参数"""
657
  parser = argparse.ArgumentParser(
658
  description="OpenAI to Augment API Adapter",
659
  formatter_class=argparse.ArgumentDefaultsHelpFormatter
660
  )
661
-
662
  parser.add_argument(
663
- "--augment-url",
664
  default="https://d6.api.augmentcode.com/",
665
  help="Augment API基础URL"
666
  )
667
-
668
  parser.add_argument(
669
- "--chat-endpoint",
670
  default="chat-stream",
671
  help="Augment聊天端点路径"
672
  )
673
-
674
  parser.add_argument(
675
- "--host",
676
  default="0.0.0.0",
677
  help="服务器主机地址"
678
  )
679
-
680
  parser.add_argument(
681
- "--port",
682
- type=int,
683
  default=8686,
684
  help="服务器端口"
685
  )
686
-
687
  parser.add_argument(
688
- "--timeout",
689
- type=int,
690
  default=120,
691
  help="API请求超时时间(秒)"
692
  )
693
-
694
  parser.add_argument(
695
- "--debug",
696
  action="store_true",
697
  help="启用调试模式"
698
  )
@@ -702,9 +727,10 @@ def parse_args():
702
  default="d18",
703
  help="Augment API租户ID (域名前缀)"
704
  )
705
-
706
  return parser.parse_args()
707
 
 
708
  #################################################
709
  # 主程序
710
  #################################################
@@ -712,11 +738,11 @@ def parse_args():
712
  def main():
713
  """主函数"""
714
  args = parse_args()
715
-
716
  # 配置日志级别
717
  if args.debug:
718
  logging.getLogger().setLevel(logging.DEBUG)
719
-
720
  # 构建完整的Augment URL
721
  if args.augment_url == "https://d18.api.augmentcode.com/":
722
  # 如果使用默认URL,则应用tenant-id参数
@@ -725,25 +751,26 @@ def main():
725
  else:
726
  # 否则使用提供的URL
727
  augment_base_url = args.augment_url
728
-
729
  # 创建应用
730
  app = create_app(
731
  augment_base_url=augment_base_url,
732
  chat_endpoint=args.chat_endpoint,
733
  timeout=args.timeout
734
  )
735
-
736
  # 启动应用
737
  logger.info(f"Starting server on {args.host}:7860")
738
  logger.info(f"Using Augment base URL: {augment_base_url}")
739
  logger.info(f"Using Augment chat endpoint: {args.chat_endpoint}")
740
-
741
  uvicorn.run(
742
- app,
743
- host=args.host,
744
  port=7860,
745
  log_level="info" if not args.debug else "debug"
746
  )
747
 
 
748
  if __name__ == "__main__":
749
  main()
 
30
  )
31
  logger = logging.getLogger(__name__)
32
 
33
+
34
  #################################################
35
  # 模型定义
36
  #################################################
 
42
  content: Optional[str] = None
43
  name: Optional[str] = None
44
 
45
+
46
  class ChatCompletionRequest(BaseModel):
47
  """OpenAI聊天完成API请求模型"""
48
  model: str
 
56
  frequency_penalty: Optional[float] = 0
57
  user: Optional[str] = None
58
 
59
+
60
  # OpenAI API 响应模型
61
  class ChatCompletionResponseChoice(BaseModel):
62
  """OpenAI聊天完成API响应中的单个选择"""
 
64
  message: ChatMessage
65
  finish_reason: Optional[str] = None
66
 
67
+
68
  class Usage(BaseModel):
69
  """OpenAI API响应中的token使用信息"""
70
  prompt_tokens: int
71
  completion_tokens: int
72
  total_tokens: int
73
 
74
+
75
  class ChatCompletionResponse(BaseModel):
76
  """OpenAI聊天完成API响应模型"""
77
  id: str
 
81
  choices: List[ChatCompletionResponseChoice]
82
  usage: Usage
83
 
84
+
85
  # OpenAI API 流式响应模型
86
  class ChatCompletionStreamResponseChoice(BaseModel):
87
  """OpenAI聊天完成流式API响应中的单个选择"""
 
89
  delta: Dict[str, Any]
90
  finish_reason: Optional[str] = None
91
 
92
+
93
  class ChatCompletionStreamResponse(BaseModel):
94
  """OpenAI聊天完成流式API响应模型"""
95
  id: str
 
98
  model: str
99
  choices: List[ChatCompletionStreamResponseChoice]
100
 
101
+
102
  # 模型信息响应
103
  class ModelInfo(BaseModel):
104
  """OpenAI模型信息"""
 
107
  created: int
108
  owned_by: str = "augment"
109
 
110
+
111
  class ModelListResponse(BaseModel):
112
  """OpenAI模型列表响应"""
113
  object: str = "list"
114
  data: List[ModelInfo]
115
 
116
+
117
  # Augment API 请求相关模型
118
  class AugmentResponseNode(BaseModel):
119
  """Augment API响应节点"""
 
122
  content: str
123
  tool_use: Optional[Any] = None
124
 
125
+
126
  class AugmentChatHistoryItem(BaseModel):
127
  """Augment API聊天历史记录条目"""
128
  request_message: str
 
131
  request_nodes: List[Any] = []
132
  response_nodes: List[AugmentResponseNode] = []
133
 
134
+
135
  class AugmentBlobs(BaseModel):
136
  """Augment API Blobs对象"""
137
  checkpoint_id: Optional[str] = None
138
  added_blobs: List[Any] = []
139
  deleted_blobs: List[Any] = []
140
 
141
+
142
  class AugmentVcsChange(BaseModel):
143
  """Augment API VCS更改"""
144
  working_directory_changes: List[Any] = []
145
 
146
+
147
  class AugmentFeatureFlags(BaseModel):
148
  """Augment API功能标志"""
149
  support_raw_output: bool = True
150
 
151
+
152
  # 完整的Augment API请求模型
153
  class AugmentChatRequest(BaseModel):
154
  """Augment API聊天请求模型 - 基于抓包分析更新"""
 
172
  feature_detection_flags: AugmentFeatureFlags = AugmentFeatureFlags()
173
  tool_definitions: List[Any] = []
174
  nodes: List[Any] = []
175
+ mode: str = "AGENT"
176
  agent_memories: Optional[Any] = None
177
  system_prompt: Optional[str] = None # 保留此字段以兼容之前的代码
178
 
179
+
180
  # Augment API响应模型
181
  class AugmentResponseChunk(BaseModel):
182
  """Augment API响应块"""
 
187
  incorporated_external_sources: List[Any] = []
188
  nodes: List[AugmentResponseNode] = []
189
 
190
+
191
  #################################################
192
  # 辅助函数
193
  #################################################
 
196
  """生成唯一ID,类似于OpenAI的格式"""
197
  return str(uuid.uuid4()).replace("-", "")[:24]
198
 
199
+
200
  def estimate_tokens(text):
201
  """
202
  估计文本的token数量
 
210
  chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff') if text else 0
211
  return int(words * 1.3 + chinese_chars)
212
 
213
+
214
  def convert_to_augment_request(openai_request: ChatCompletionRequest) -> AugmentChatRequest:
215
  """
216
  将OpenAI API请求转换为Augment API请求
217
+
218
  Args:
219
  openai_request: OpenAI API请求对象
220
+
221
  Returns:
222
  转换后的Augment API请求对象
223
+
224
  Raises:
225
  HTTPException: 如果请求格式无效
226
  """
227
  chat_history = []
228
  system_message = None
229
+
230
  # 处理消息历史记录
231
  for i in range(len(openai_request.messages) - 1):
232
  msg = openai_request.messages[i]
233
  if msg.role == "system":
234
  system_message = msg.content
235
+ elif msg.role == "user" and i + 1 < len(openai_request.messages) and openai_request.messages[
236
+ i + 1].role == "assistant":
237
  user_msg = msg.content
238
  assistant_msg = openai_request.messages[i + 1].content
239
+
240
  # 创建历史记录条目,格式符合Augment API
241
  history_item = AugmentChatHistoryItem(
242
  request_message=user_msg,
 
252
  ]
253
  )
254
  chat_history.append(history_item)
255
+
256
  # 获取当前用户消息
257
  current_message = None
258
  for msg in reversed(openai_request.messages):
259
  if msg.role == "user":
260
  current_message = msg.content
261
  break
262
+
263
  # 如果没有用户消息,则返回错误
264
  if current_message is None:
265
  raise HTTPException(
266
  status_code=400,
267
  detail="At least one user message is required"
268
  )
269
+
270
  # 准备Augment请求体
271
  augment_request = AugmentChatRequest(
272
  message=current_message,
273
  chat_history=chat_history,
274
  mode="CHAT"
275
  )
276
+
277
  # 如果有系统消息,设置为用户指南
278
  if system_message:
279
  augment_request.user_guidelines = system_message
280
+
281
  return augment_request
282
 
283
+
284
  #################################################
285
  # FastAPI应用
286
  #################################################
 
288
  def create_app(augment_base_url, chat_endpoint, timeout):
289
  """
290
  创建并配置FastAPI应用
291
+
292
  Args:
293
  augment_base_url: Augment API基础URL
294
  chat_endpoint: 聊天端点路径
295
  timeout: 请求超时时间
296
+
297
  Returns:
298
  配置好的FastAPI应用
299
  """
 
338
  async def verify_api_key(authorization: str = Header(...)):
339
  """
340
  验证API密钥
341
+
342
  Args:
343
  authorization: Authorization头部值
344
+
345
  Returns:
346
  提取的API密钥
347
+
348
  Raises:
349
  HTTPException: 如果API密钥格式无效或为空
350
  """
351
  if not authorization.startswith("Bearer "):
352
  raise HTTPException(
353
+ status_code=401,
354
  detail={
355
  "error": {
356
  "message": "Invalid API key format. Expected 'Bearer YOUR_API_KEY'",
 
402
 
403
  @app.post("/v1/chat/completions")
404
  async def chat_completions(
405
+ request: ChatCompletionRequest,
406
+ api_key: str = Depends(verify_api_key)
407
  ):
408
  """
409
  聊天完成端点 - 将OpenAI API请求转换为Augment API请求
410
+
411
  Args:
412
  request: OpenAI格式的聊天完成请求
413
  api_key: 通过验证的API密钥
414
+
415
  Returns:
416
  OpenAI格式的聊天完成响应或流式响应
417
  """
 
420
  augment_request = convert_to_augment_request(request)
421
  logger.debug(f"Converted request: {augment_request.dict()}")
422
 
 
423
  if ":" in api_key:
424
  tenant_id, api_key = api_key.split(":")
425
  augment_base_url = f"https://{tenant_id}.api.augmentcode.com/"
 
427
  # 决定是否使用流式响应
428
  if request.stream:
429
  return StreamingResponse(
430
+ stream_augment_response(augment_base_url, api_key, augment_request, request.model, chat_endpoint,
431
+ timeout),
432
  media_type="text/event-stream"
433
  )
434
  else:
435
  # 同步请求处理
436
+ return await handle_sync_request(augment_base_url, api_key, augment_request, request.model,
437
+ chat_endpoint, timeout)
438
+
439
  except httpx.TimeoutException:
440
  logger.error("Request to Augment API timed out")
441
  raise HTTPException(
 
481
 
482
  return app
483
 
484
+
485
  async def handle_sync_request(base_url, api_key, augment_request, model_name, chat_endpoint, timeout):
486
  """
487
  处理同步请求
488
+
489
  Args:
490
  base_url: Augment API基础URL
491
  api_key: API密钥
 
493
  model_name: 模型名称
494
  chat_endpoint: 聊天端点
495
  timeout: 请求超时时间
496
+
497
  Returns:
498
  OpenAI格式的聊天完成响应
499
  """
 
508
  "Accept": "*/*"
509
  }
510
  )
511
+
512
  if response.status_code != 200:
513
  logger.error(f"Augment API error: {response.status_code} - {response.text}")
514
  raise HTTPException(
 
522
  }
523
  }
524
  )
525
+
526
  # 处理流式响应,合并为完整响应
527
  full_response = ""
528
  for line in response.text.split("\n"):
 
533
  full_response += data["text"]
534
  except json.JSONDecodeError:
535
  logger.warning(f"Failed to parse JSON: {line}")
536
+
537
  # 估算token使用情况
538
  prompt_tokens = estimate_tokens(augment_request.message)
539
  completion_tokens = estimate_tokens(full_response)
540
+
541
  # 构建OpenAI格式响应
542
  return ChatCompletionResponse(
543
  id=f"chatcmpl-{generate_id()}",
 
560
  )
561
  )
562
 
563
+
564
  async def stream_augment_response(base_url, api_key, augment_request, model_name, chat_endpoint, timeout):
565
  """
566
  处理流式响应
567
+
568
  Args:
569
  base_url: Augment API基础URL
570
  api_key: API密钥
 
572
  model_name: 模型名称
573
  chat_endpoint: 聊天端点
574
  timeout: 请求超时时间
575
+
576
  Yields:
577
  流式响应的数据块
578
  """
579
  async with httpx.AsyncClient(timeout=timeout) as client:
580
  try:
581
  async with client.stream(
582
+ "POST",
583
+ f"{base_url.rstrip('/')}/{chat_endpoint}",
584
+ json=augment_request.dict(),
585
+ headers={
586
+ "Content-Type": "application/json",
587
+ "Authorization": f"Bearer {api_key}",
588
+ "User-Agent": "chrome",
589
+ "Accept": "*/*"
590
+ }
591
  ) as response:
592
 
593
  if response.status_code != 200:
 
596
  error_message = f"Error from Augment API: {error_detail.decode('utf-8', errors='replace')}"
597
  yield f"data: {json.dumps({'error': error_message})}\n\n"
598
  return
599
+
600
  # 生成唯一ID
601
  chat_id = f"chatcmpl-{generate_id()}"
602
  created_time = int(time.time())
603
+
604
  # 初始化响应
605
  init_response = ChatCompletionStreamResponse(
606
  id=chat_id,
 
616
  )
617
  init_data = json.dumps(init_response.dict())
618
  yield f"data: {init_data}\n\n"
619
+
620
  # 处理流式响应
621
  buffer = ""
622
  async for line in response.aiter_lines():
623
  if not line.strip():
624
  continue
625
+
626
  try:
627
  # 解析Augment响应格式
628
  chunk = json.loads(line)
629
  if "text" in chunk and chunk["text"]:
630
  content = chunk["text"]
631
+
632
  # 发送增量更新
633
  stream_response = ChatCompletionStreamResponse(
634
  id=chat_id,
 
646
  yield f"data: {response_data}\n\n"
647
  except json.JSONDecodeError:
648
  logger.warning(f"Failed to parse JSON: {line}")
649
+
650
  # 发送完成信号
651
  final_response = ChatCompletionStreamResponse(
652
  id=chat_id,
 
662
  )
663
  final_data = json.dumps(final_response.dict())
664
  yield f"data: {final_data}\n\n"
665
+
666
  # 发送[DONE]标记
667
  yield "data: [DONE]\n\n"
668
+
669
  except httpx.TimeoutException:
670
  logger.error("Request to Augment API timed out")
671
  yield f"data: {json.dumps({'error': 'Request to Augment API timed out'})}\n\n"
 
676
  logger.exception("Unexpected error")
677
  yield f"data: {json.dumps({'error': f'Internal server error: {str(e)}'})}\n\n"
678
 
679
+
680
  def parse_args():
681
  """解析命令行参数"""
682
  parser = argparse.ArgumentParser(
683
  description="OpenAI to Augment API Adapter",
684
  formatter_class=argparse.ArgumentDefaultsHelpFormatter
685
  )
686
+
687
  parser.add_argument(
688
+ "--augment-url",
689
  default="https://d6.api.augmentcode.com/",
690
  help="Augment API基础URL"
691
  )
692
+
693
  parser.add_argument(
694
+ "--chat-endpoint",
695
  default="chat-stream",
696
  help="Augment聊天端点路径"
697
  )
698
+
699
  parser.add_argument(
700
+ "--host",
701
  default="0.0.0.0",
702
  help="服务器主机地址"
703
  )
704
+
705
  parser.add_argument(
706
+ "--port",
707
+ type=int,
708
  default=8686,
709
  help="服务器端口"
710
  )
711
+
712
  parser.add_argument(
713
+ "--timeout",
714
+ type=int,
715
  default=120,
716
  help="API请求超时时间(秒)"
717
  )
718
+
719
  parser.add_argument(
720
+ "--debug",
721
  action="store_true",
722
  help="启用调试模式"
723
  )
 
727
  default="d18",
728
  help="Augment API租户ID (域名前缀)"
729
  )
730
+
731
  return parser.parse_args()
732
 
733
+
734
  #################################################
735
  # 主程序
736
  #################################################
 
738
  def main():
739
  """主函数"""
740
  args = parse_args()
741
+
742
  # 配置日志级别
743
  if args.debug:
744
  logging.getLogger().setLevel(logging.DEBUG)
745
+
746
  # 构建完整的Augment URL
747
  if args.augment_url == "https://d18.api.augmentcode.com/":
748
  # 如果使用默认URL,则应用tenant-id参数
 
751
  else:
752
  # 否则使用提供的URL
753
  augment_base_url = args.augment_url
754
+
755
  # 创建应用
756
  app = create_app(
757
  augment_base_url=augment_base_url,
758
  chat_endpoint=args.chat_endpoint,
759
  timeout=args.timeout
760
  )
761
+
762
  # 启动应用
763
  logger.info(f"Starting server on {args.host}:7860")
764
  logger.info(f"Using Augment base URL: {augment_base_url}")
765
  logger.info(f"Using Augment chat endpoint: {args.chat_endpoint}")
766
+
767
  uvicorn.run(
768
+ app,
769
+ host=args.host,
770
  port=7860,
771
  log_level="info" if not args.debug else "debug"
772
  )
773
 
774
+
775
  if __name__ == "__main__":
776
  main()