File size: 7,095 Bytes
76b9762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191

import datetime
import json
import re
import time
from typing import Any, AsyncGenerator, Dict, Union

from app.config.config import settings
from app.database.services import (
    add_error_log,
    add_request_log,
)
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.service.client.api_client import OpenaiApiClient
from app.service.key.key_manager import KeyManager
from app.log.logger import get_openai_compatible_logger

# Module-level logger shared by every method of the service class below.
logger = get_openai_compatible_logger()

class OpenAICompatiableService:
    """Service layer in front of an OpenAI-compatible upstream API.

    Wraps an ``OpenaiApiClient`` and, for chat completions, records a request
    log entry for every attempt and an error log entry for every failure.
    Streaming completions additionally ask the optional ``KeyManager`` for a
    replacement API key after a failed attempt.
    """

    def __init__(self, base_url: str, key_manager: Union["KeyManager", None] = None):
        """Create the service.

        Args:
            base_url: Base URL handed to the underlying ``OpenaiApiClient``.
            key_manager: Optional key manager used to obtain a replacement
                key when a streaming attempt fails. Without it, streaming
                requests are attempted only once.
        """
        self.key_manager = key_manager
        self.base_url = base_url
        self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)

    @staticmethod
    def _extract_status_code(error_message: str, default: int = 500) -> int:
        """Return the HTTP status embedded in *error_message*, else *default*.

        The code is taken from the first ``status code <digits>`` occurrence
        in the message text.
        """
        match = re.search(r"status code (\d+)", error_message)
        return int(match.group(1)) if match else default

    async def get_models(self, api_key: str) -> Dict[str, Any]:
        """Return the model listing fetched by the API client with *api_key*."""
        return await self.api_client.get_models(api_key)

    async def create_chat_completion(
        self,
        request: "ChatRequest",
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion.

        Args:
            request: The chat request; ``None``-valued fields are dropped
                before the payload is sent.
            api_key: Key used for the (first) upstream attempt.

        Returns:
            An async generator of SSE lines when ``request.stream`` is truthy,
            otherwise the response returned by the API client.
        """
        # Drop fields whose value is None.
        request_dict = {
            k: v for k, v in request.model_dump().items() if v is not None
        }
        # top_k is currently not supported upstream. Use pop() with a default:
        # the key is already gone when the caller left top_k as None, and a
        # plain `del` would raise KeyError in that case.
        request_dict.pop("top_k", None)
        if request.stream:
            return self._handle_stream_completion(request.model, request_dict, api_key)
        return await self._handle_normal_completion(request.model, request_dict, api_key)

    async def generate_images(
        self,
        request: "ImageGenerationRequest",
    ) -> Dict[str, Any]:
        """Generate images.

        ``None``-valued request fields are dropped. The call always uses
        ``settings.PAID_KEY`` rather than a caller-supplied key.
        """
        request_dict = {
            k: v for k, v in request.model_dump().items() if v is not None
        }
        api_key = settings.PAID_KEY
        return await self.api_client.generate_images(request_dict, api_key)

    async def create_embeddings(
        self,
        input_text: str,
        model: str,
        api_key: str,
    ) -> Dict[str, Any]:
        """Create embeddings for *input_text* with *model* via the API client."""
        return await self.api_client.create_embeddings(input_text, model, api_key)

    async def _handle_normal_completion(
        self, model: str, request: dict, api_key: str
    ) -> Dict[str, Any]:
        """Run a non-streaming chat completion and log the outcome.

        Args:
            model: Model name, used for logging only.
            request: Payload passed to the API client.
            api_key: Key used for the upstream call.

        Returns:
            The response returned by the API client.

        Raises:
            Exception: Whatever the API client raised, re-raised after an
                error log entry has been written.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        try:
            response = await self.api_client.generate_content(request, api_key)
        except Exception as e:
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            status_code = self._extract_status_code(error_log_msg)
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-compatiable-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=request,
            )
            raise
        else:
            is_success = True
            status_code = 200
            return response
        finally:
            # Written on success and on failure alike.
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_stream_completion(
        self, model: str, payload: dict, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion, retrying with a replacement key on failure.

        Each attempt forwards every upstream line that starts with ``data:``
        followed by a blank line. A failed attempt is written to the error
        log, and the key manager (if any) is asked for a replacement key. At
        most ``settings.MAX_RETRIES`` attempts are made. If the stream did not
        complete successfully, a final SSE error event and ``[DONE]`` marker
        are emitted so the client never sees a silently truncated stream.

        Args:
            model: Model name, used for logging only.
            payload: Payload passed to the API client.
            api_key: Key used for the first attempt.

        Yields:
            SSE-formatted lines, each terminated by a blank line.
        """
        retries = 0
        max_retries = settings.MAX_RETRIES
        is_success = False
        status_code = None

        while retries < max_retries:
            start_time = time.perf_counter()
            request_datetime = datetime.datetime.now()
            # Remember the key of this attempt; `api_key` may be rebound to a
            # replacement key in the except branch below.
            current_attempt_key = api_key
            try:
                # NOTE(review): if an attempt fails after some chunks were
                # already yielded, the retry restarts the upstream call with
                # the same payload, so the client may receive repeated
                # content.
                async for line in self.api_client.stream_generate_content(
                    payload, current_attempt_key
                ):
                    if line.startswith("data:"):
                        yield line + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                status_code = self._extract_status_code(error_log_msg)

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="openai-compatiable-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload,
                )

                if not self.key_manager:
                    logger.error("KeyManager not available for retry logic.")
                    break

                api_key = await self.key_manager.handle_api_failure(
                    current_attempt_key, retries
                )
                if not api_key:
                    logger.error(
                        f"No valid API key available after {retries} retries."
                    )
                    break
                # NOTE(review): this writes the full replacement key to the
                # log; consider redacting it.
                logger.info(f"Switched to new API key: {api_key}")

                if retries >= max_retries:
                    logger.error(f"Max retries ({max_retries}) reached for streaming.")
                    break
            finally:
                # One request-log entry per attempt, success or failure.
                latency_ms = int((time.perf_counter() - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=current_attempt_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime,
                )

        # Emitted outside `finally` so nothing is yielded while the generator
        # is being closed, and on every failed exit (retries exhausted, no
        # key manager, or no replacement key), not only on exhaustion.
        if not is_success:
            yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
            yield "data: [DONE]\n\n"