Update main.py
main.py (CHANGED)
--- a/main.py
+++ b/main.py
@@ -9,10 +9,8 @@ import httpx
 import uvicorn
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Depends
-from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from pydantic import BaseModel
-from starlette.middleware.cors import CORSMiddleware
-from starlette.responses import StreamingResponse, Response
+from starlette.responses import StreamingResponse
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -21,41 +19,17 @@ logger = logging.getLogger(__name__)
 
 load_dotenv()
 app = FastAPI()
-
-
-ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN","")
+APP_SECRET = os.getenv("APP_SECRET", "666")
+ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN", "")
 headers = {
-    'accept': '*/*',
-    'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
     'authorization': f'Bearer {ACCESS_TOKEN}',
-    'cache-control': 'no-cache',
-    'origin': 'chrome-extension://dhoenijjpgpeimemopealfcbiecgceod',
-    'pragma': 'no-cache',
-    'priority': 'u=1, i',
-    'sec-fetch-dest': 'empty',
-    'sec-fetch-mode': 'cors',
-    'sec-fetch-site': 'none',
     'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0',
 }
 
 ALLOWED_MODELS = [
-
-
-    {"id": "gemini-1.5-pro", "name": "gemini-1.5-pro"},
-    {"id": "gpt-4o", "name": "gpt-4o"},
-    {"id": "o1-preview", "name": "o1-preview"},
-    {"id": "o1-mini", "name": "o1-mini"},
-    {"id": "gpt-4o-mini", "name": "gpt-4o-mini"},
+    "claude-3.5-sonnet", "sider", "gpt-4o-mini", "claude-3-haiku", "claude-3.5-haiku",
+    "gemini-1.5-flash", "llama-3", "gpt-4o", "gemini-1.5-pro", "llama-3.1-405b"
 ]
-# Configure CORS
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],  # Allow all sources, you can restrict specific sources if needed
-    allow_credentials=True,
-    allow_methods=["*"],  # All methods allowed
-    allow_headers=["*"],  # Allow all headers
-)
-security = HTTPBearer()
 
 
 class Message(BaseModel):
@@ -69,40 +43,6 @@ class ChatRequest(BaseModel):
     stream: Optional[bool] = False
 
 
-def simulate_data(content, model):
-    return {
-        "id": f"chatcmpl-{uuid.uuid4()}",
-        "object": "chat.completion.chunk",
-        "created": int(datetime.now().timestamp()),
-        "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": content, "role": "assistant"},
-                "finish_reason": None,
-            }
-        ],
-        "usage": None,
-    }
-
-
-def stop_data(content, model):
-    return {
-        "id": f"chatcmpl-{uuid.uuid4()}",
-        "object": "chat.completion.chunk",
-        "created": int(datetime.now().timestamp()),
-        "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": content, "role": "assistant"},
-                "finish_reason": "stop",
-            }
-        ],
-        "usage": None,
-    }
-
-
 def create_chat_completion_data(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
     return {
         "id": f"chatcmpl-{uuid.uuid4()}",
@@ -120,109 +60,36 @@ def create_chat_completion_data(content: str, model: str, finish_reason: Optiona
     }
 
 
-def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
-    if credentials.credentials != APP_SECRET:
-        raise HTTPException(status_code=403, detail="Invalid APP_SECRET")
-    return credentials.credentials
-
-
-@app.options("/hf/v1/chat/completions")
-async def chat_completions_options():
-    return Response(
-        status_code=200,
-        headers={
-            "Access-Control-Allow-Origin": "*",
-            "Access-Control-Allow-Methods": "POST, OPTIONS",
-            "Access-Control-Allow-Headers": "Content-Type, Authorization",
-        },
-    )
-
-
-def replace_escaped_newlines(input_string: str) -> str:
-    return input_string.replace("\\n", "\n")
-
-
-@app.get("/hf/v1/models")
-async def list_models():
-    return {"object": "list", "data": ALLOWED_MODELS}
-
-
 @app.post("/hf/v1/chat/completions")
-async def chat_completions(
-    request: ChatRequest, app_secret: str = Depends(verify_app_secret)
-):
-    # Log the model requested by the client
+async def chat_completions(request: ChatRequest):
     logger.info(f"Received chat completion request for model: {request.model}")
 
-
-    logger.info(f"Allowed models: {[model['id'] for model in ALLOWED_MODELS]}")
-
-    if request.model not in [model['id'] for model in ALLOWED_MODELS]:
+    if request.model not in ALLOWED_MODELS:
         logger.error(f"Model {request.model} is not allowed.")
         raise HTTPException(
             status_code=400,
-            detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(model['id'] for model in ALLOWED_MODELS)}",
+            detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(ALLOWED_MODELS)}",
         )
 
-    # Log the JSON payload for the request to the external API
-    logger.info(f"Sending request to external API with model: {request.model}")
-
-    # Generate a UUID
-    original_uuid = uuid.uuid4()
-    uuid_str = str(original_uuid).replace("-", "")
-
-    # Prepare json_data for the external request
     json_data = {
         'prompt': "\n".join(
-            [
-                f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
-                for msg in request.messages
-            ]
+            [f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}" for msg in request.messages]
         ),
         'stream': True,
-        'app_name': 'ChitChat_Edge_Ext',
-        'app_version': '4.26.1',
-        'tz_name': 'Asia/Karachi',
-        'cid': '',
         'model': request.model,
-        'search': False,
-        'auto_search': False,
-        'filter_search_history': False,
-        'from': 'chat',
-        'group_id': 'default',
-        'chat_models': [],
-        'files': [],
-        'prompt_template': {
-            'key': '',
-            'attributes': {
-                'lang': 'original',
-            },
-        },
-        'tools': {
-            'auto': [
-                'search',
-                'text_to_image',
-                'data_analysis',
-            ],
-        },
-        'extra_info': {
-            'origin_url': '',
-            'origin_title': '',
-        },
     }
+
+    logger.info(f"Sending request to external API with data: {json_data}")
 
     async def generate():
         async with httpx.AsyncClient() as client:
             try:
-                # Log request details before making the API call
-                logger.info(f"External API request json_data: {json.dumps(json_data, indent=2)}")
-
                 async with client.stream('POST', 'https://sider.ai/api/v3/completion/text', headers=headers, json=json_data, timeout=120.0) as response:
                     response.raise_for_status()
                     async for line in response.aiter_lines():
                         if line and ("[DONE]" not in line):
                             content = json.loads(line[5:])["data"]
-                            yield f"data: {json.dumps(create_chat_completion_data(content.get('text',''), request.model))}\n\n"
+                            yield f"data: {json.dumps(create_chat_completion_data(content.get('text', ''), request.model))}\n\n"
                     yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
                     yield "data: [DONE]\n\n"
             except httpx.HTTPStatusError as e:
@@ -240,14 +107,11 @@ async def chat_completions(
     full_response = ""
     async for chunk in generate():
         if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
-            # Parse the chunk data and log it for debugging
            data = json.loads(chunk[6:])
-            logger.info(f"Chunk data received: {data}")
-
            if data["choices"][0]["delta"].get("content"):
                full_response += data["choices"][0]["delta"]["content"]
 
-    logger.info(f"Full response generated
+    logger.info(f"Full response generated: {full_response}")
     return {
         "id": f"chatcmpl-{uuid.uuid4()}",
         "object": "chat.completion",
@@ -262,3 +126,7 @@ async def chat_completions(
         ],
         "usage": None,
     }
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=7860)
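
Note on the new surface: this revision drops the HTTPBearer/verify_app_secret dependency, so /hf/v1/chat/completions now accepts unauthenticated requests (APP_SECRET is still read from the environment but no longer checked). A quick local smoke test might look like the sketch below; it assumes the new __main__ block is serving on port 7860 and that the elided choices block of the non-streaming response keeps the usual OpenAI message.content shape. The model name and prompt are placeholders, not part of this commit.

import httpx

# Hypothetical smoke test for the rewritten endpoint.
payload = {
    "model": "gpt-4o",  # must appear in ALLOWED_MODELS, otherwise the server returns 400
    "messages": [{"role": "user", "content": "Say hello"}],
    "stream": False,
}
resp = httpx.post("http://localhost:7860/hf/v1/chat/completions", json=payload, timeout=120.0)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])  # assumes the standard chat.completion shape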