Update main.py
Browse files
main.py
CHANGED
|
@@ -9,10 +9,8 @@ import httpx
|
|
| 9 |
import uvicorn
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
from fastapi import FastAPI, HTTPException, Depends
|
| 12 |
-
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 13 |
from pydantic import BaseModel
|
| 14 |
-
from starlette.
|
| 15 |
-
from starlette.responses import StreamingResponse, Response
|
| 16 |
|
| 17 |
logging.basicConfig(
|
| 18 |
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
@@ -21,41 +19,17 @@ logger = logging.getLogger(__name__)
|
|
| 21 |
|
| 22 |
load_dotenv()
|
| 23 |
app = FastAPI()
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN","")
|
| 27 |
headers = {
|
| 28 |
-
'accept': '*/*',
|
| 29 |
-
'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
|
| 30 |
'authorization': f'Bearer {ACCESS_TOKEN}',
|
| 31 |
-
'cache-control': 'no-cache',
|
| 32 |
-
'origin': 'chrome-extension://dhoenijjpgpeimemopealfcbiecgceod',
|
| 33 |
-
'pragma': 'no-cache',
|
| 34 |
-
'priority': 'u=1, i',
|
| 35 |
-
'sec-fetch-dest': 'empty',
|
| 36 |
-
'sec-fetch-mode': 'cors',
|
| 37 |
-
'sec-fetch-site': 'none',
|
| 38 |
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0',
|
| 39 |
}
|
| 40 |
|
| 41 |
ALLOWED_MODELS = [
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
{"id": "gemini-1.5-pro", "name": "gemini-1.5-pro"},
|
| 45 |
-
{"id": "gpt-4o", "name": "gpt-4o"},
|
| 46 |
-
{"id": "o1-preview", "name": "o1-preview"},
|
| 47 |
-
{"id": "o1-mini", "name": "o1-mini"},
|
| 48 |
-
{"id": "gpt-4o-mini", "name": "gpt-4o-mini"},
|
| 49 |
]
|
| 50 |
-
# Configure CORS
|
| 51 |
-
app.add_middleware(
|
| 52 |
-
CORSMiddleware,
|
| 53 |
-
allow_origins=["*"], # Allow all sources, you can restrict specific sources if needed
|
| 54 |
-
allow_credentials=True,
|
| 55 |
-
allow_methods=["*"], # All methods allowed
|
| 56 |
-
allow_headers=["*"], # Allow all headers
|
| 57 |
-
)
|
| 58 |
-
security = HTTPBearer()
|
| 59 |
|
| 60 |
|
| 61 |
class Message(BaseModel):
|
|
@@ -69,40 +43,6 @@ class ChatRequest(BaseModel):
|
|
| 69 |
stream: Optional[bool] = False
|
| 70 |
|
| 71 |
|
| 72 |
-
def simulate_data(content, model):
|
| 73 |
-
return {
|
| 74 |
-
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 75 |
-
"object": "chat.completion.chunk",
|
| 76 |
-
"created": int(datetime.now().timestamp()),
|
| 77 |
-
"model": model,
|
| 78 |
-
"choices": [
|
| 79 |
-
{
|
| 80 |
-
"index": 0,
|
| 81 |
-
"delta": {"content": content, "role": "assistant"},
|
| 82 |
-
"finish_reason": None,
|
| 83 |
-
}
|
| 84 |
-
],
|
| 85 |
-
"usage": None,
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def stop_data(content, model):
|
| 90 |
-
return {
|
| 91 |
-
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 92 |
-
"object": "chat.completion.chunk",
|
| 93 |
-
"created": int(datetime.now().timestamp()),
|
| 94 |
-
"model": model,
|
| 95 |
-
"choices": [
|
| 96 |
-
{
|
| 97 |
-
"index": 0,
|
| 98 |
-
"delta": {"content": content, "role": "assistant"},
|
| 99 |
-
"finish_reason": "stop",
|
| 100 |
-
}
|
| 101 |
-
],
|
| 102 |
-
"usage": None,
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
|
| 106 |
def create_chat_completion_data(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
|
| 107 |
return {
|
| 108 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
|
@@ -120,109 +60,36 @@ def create_chat_completion_data(content: str, model: str, finish_reason: Optiona
|
|
| 120 |
}
|
| 121 |
|
| 122 |
|
| 123 |
-
def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 124 |
-
if credentials.credentials != APP_SECRET:
|
| 125 |
-
raise HTTPException(status_code=403, detail="Invalid APP_SECRET")
|
| 126 |
-
return credentials.credentials
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
@app.options("/hf/v1/chat/completions")
|
| 130 |
-
async def chat_completions_options():
|
| 131 |
-
return Response(
|
| 132 |
-
status_code=200,
|
| 133 |
-
headers={
|
| 134 |
-
"Access-Control-Allow-Origin": "*",
|
| 135 |
-
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
| 136 |
-
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
| 137 |
-
},
|
| 138 |
-
)
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def replace_escaped_newlines(input_string: str) -> str:
|
| 142 |
-
return input_string.replace("\\n", "\n")
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
@app.get("/hf/v1/models")
|
| 146 |
-
async def list_models():
|
| 147 |
-
return {"object": "list", "data": ALLOWED_MODELS}
|
| 148 |
-
|
| 149 |
-
|
| 150 |
@app.post("/hf/v1/chat/completions")
|
| 151 |
-
async def chat_completions(
|
| 152 |
-
request: ChatRequest, app_secret: str = Depends(verify_app_secret)
|
| 153 |
-
):
|
| 154 |
-
# Log the model requested by the client
|
| 155 |
logger.info(f"Received chat completion request for model: {request.model}")
|
| 156 |
|
| 157 |
-
|
| 158 |
-
logger.info(f"Allowed models: {[model['id'] for model in ALLOWED_MODELS]}")
|
| 159 |
-
|
| 160 |
-
if request.model not in [model['id'] for model in ALLOWED_MODELS]:
|
| 161 |
logger.error(f"Model {request.model} is not allowed.")
|
| 162 |
raise HTTPException(
|
| 163 |
status_code=400,
|
| 164 |
-
detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(
|
| 165 |
)
|
| 166 |
|
| 167 |
-
# Log the JSON payload for the request to the external API
|
| 168 |
-
logger.info(f"Sending request to external API with model: {request.model}")
|
| 169 |
-
|
| 170 |
-
# Generate a UUID
|
| 171 |
-
original_uuid = uuid.uuid4()
|
| 172 |
-
uuid_str = str(original_uuid).replace("-", "")
|
| 173 |
-
|
| 174 |
-
# Prepare json_data for the external request
|
| 175 |
json_data = {
|
| 176 |
'prompt': "\n".join(
|
| 177 |
-
[
|
| 178 |
-
f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}"
|
| 179 |
-
for msg in request.messages
|
| 180 |
-
]
|
| 181 |
),
|
| 182 |
'stream': True,
|
| 183 |
-
'app_name': 'ChitChat_Edge_Ext',
|
| 184 |
-
'app_version': '4.26.1',
|
| 185 |
-
'tz_name': 'Asia/Karachi',
|
| 186 |
-
'cid': '',
|
| 187 |
'model': request.model,
|
| 188 |
-
'search': False,
|
| 189 |
-
'auto_search': False,
|
| 190 |
-
'filter_search_history': False,
|
| 191 |
-
'from': 'chat',
|
| 192 |
-
'group_id': 'default',
|
| 193 |
-
'chat_models': [],
|
| 194 |
-
'files': [],
|
| 195 |
-
'prompt_template': {
|
| 196 |
-
'key': '',
|
| 197 |
-
'attributes': {
|
| 198 |
-
'lang': 'original',
|
| 199 |
-
},
|
| 200 |
-
},
|
| 201 |
-
'tools': {
|
| 202 |
-
'auto': [
|
| 203 |
-
'search',
|
| 204 |
-
'text_to_image',
|
| 205 |
-
'data_analysis',
|
| 206 |
-
],
|
| 207 |
-
},
|
| 208 |
-
'extra_info': {
|
| 209 |
-
'origin_url': '',
|
| 210 |
-
'origin_title': '',
|
| 211 |
-
},
|
| 212 |
}
|
|
|
|
|
|
|
| 213 |
|
| 214 |
async def generate():
|
| 215 |
async with httpx.AsyncClient() as client:
|
| 216 |
try:
|
| 217 |
-
# Log request details before making the API call
|
| 218 |
-
logger.info(f"External API request json_data: {json.dumps(json_data, indent=2)}")
|
| 219 |
-
|
| 220 |
async with client.stream('POST', 'https://sider.ai/api/v3/completion/text', headers=headers, json=json_data, timeout=120.0) as response:
|
| 221 |
response.raise_for_status()
|
| 222 |
async for line in response.aiter_lines():
|
| 223 |
if line and ("[DONE]" not in line):
|
| 224 |
content = json.loads(line[5:])["data"]
|
| 225 |
-
yield f"data: {json.dumps(create_chat_completion_data(content.get('text',''), request.model))}\n\n"
|
| 226 |
yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
|
| 227 |
yield "data: [DONE]\n\n"
|
| 228 |
except httpx.HTTPStatusError as e:
|
|
@@ -240,14 +107,11 @@ async def chat_completions(
|
|
| 240 |
full_response = ""
|
| 241 |
async for chunk in generate():
|
| 242 |
if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
|
| 243 |
-
# Parse the chunk data and log it for debugging
|
| 244 |
data = json.loads(chunk[6:])
|
| 245 |
-
logger.info(f"Chunk data received: {data}")
|
| 246 |
-
|
| 247 |
if data["choices"][0]["delta"].get("content"):
|
| 248 |
full_response += data["choices"][0]["delta"]["content"]
|
| 249 |
|
| 250 |
-
logger.info(f"Full response generated
|
| 251 |
return {
|
| 252 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 253 |
"object": "chat.completion",
|
|
@@ -262,3 +126,7 @@ async def chat_completions(
|
|
| 262 |
],
|
| 263 |
"usage": None,
|
| 264 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import uvicorn
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
from fastapi import FastAPI, HTTPException, Depends
|
|
|
|
| 12 |
from pydantic import BaseModel
|
| 13 |
+
from starlette.responses import StreamingResponse
|
|
|
|
| 14 |
|
| 15 |
logging.basicConfig(
|
| 16 |
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
|
|
| 19 |
|
| 20 |
load_dotenv()
|
| 21 |
app = FastAPI()
|
| 22 |
+
APP_SECRET = os.getenv("APP_SECRET", "666")
|
| 23 |
+
ACCESS_TOKEN = os.getenv("SD_ACCESS_TOKEN", "")
|
|
|
|
| 24 |
headers = {
|
|
|
|
|
|
|
| 25 |
'authorization': f'Bearer {ACCESS_TOKEN}',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36 Edg/129.0.0.0',
|
| 27 |
}
|
| 28 |
|
| 29 |
ALLOWED_MODELS = [
|
| 30 |
+
"claude-3.5-sonnet", "sider", "gpt-4o-mini", "claude-3-haiku", "claude-3.5-haiku",
|
| 31 |
+
"gemini-1.5-flash", "llama-3", "gpt-4o", "gemini-1.5-pro", "llama-3.1-405b"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
class Message(BaseModel):
|
|
|
|
| 43 |
stream: Optional[bool] = False
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def create_chat_completion_data(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
|
| 47 |
return {
|
| 48 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
|
|
|
| 60 |
}
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
@app.post("/hf/v1/chat/completions")
|
| 64 |
+
async def chat_completions(request: ChatRequest):
|
|
|
|
|
|
|
|
|
|
| 65 |
logger.info(f"Received chat completion request for model: {request.model}")
|
| 66 |
|
| 67 |
+
if request.model not in ALLOWED_MODELS:
|
|
|
|
|
|
|
|
|
|
| 68 |
logger.error(f"Model {request.model} is not allowed.")
|
| 69 |
raise HTTPException(
|
| 70 |
status_code=400,
|
| 71 |
+
detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(ALLOWED_MODELS)}",
|
| 72 |
)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
json_data = {
|
| 75 |
'prompt': "\n".join(
|
| 76 |
+
[f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}" for msg in request.messages]
|
|
|
|
|
|
|
|
|
|
| 77 |
),
|
| 78 |
'stream': True,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
'model': request.model,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
}
|
| 81 |
+
|
| 82 |
+
logger.info(f"Sending request to external API with data: {json_data}")
|
| 83 |
|
| 84 |
async def generate():
|
| 85 |
async with httpx.AsyncClient() as client:
|
| 86 |
try:
|
|
|
|
|
|
|
|
|
|
| 87 |
async with client.stream('POST', 'https://sider.ai/api/v3/completion/text', headers=headers, json=json_data, timeout=120.0) as response:
|
| 88 |
response.raise_for_status()
|
| 89 |
async for line in response.aiter_lines():
|
| 90 |
if line and ("[DONE]" not in line):
|
| 91 |
content = json.loads(line[5:])["data"]
|
| 92 |
+
yield f"data: {json.dumps(create_chat_completion_data(content.get('text', ''), request.model))}\n\n"
|
| 93 |
yield f"data: {json.dumps(create_chat_completion_data('', request.model, 'stop'))}\n\n"
|
| 94 |
yield "data: [DONE]\n\n"
|
| 95 |
except httpx.HTTPStatusError as e:
|
|
|
|
| 107 |
full_response = ""
|
| 108 |
async for chunk in generate():
|
| 109 |
if chunk.startswith("data: ") and not chunk[6:].startswith("[DONE]"):
|
|
|
|
| 110 |
data = json.loads(chunk[6:])
|
|
|
|
|
|
|
| 111 |
if data["choices"][0]["delta"].get("content"):
|
| 112 |
full_response += data["choices"][0]["delta"]["content"]
|
| 113 |
|
| 114 |
+
logger.info(f"Full response generated: {full_response}")
|
| 115 |
return {
|
| 116 |
"id": f"chatcmpl-{uuid.uuid4()}",
|
| 117 |
"object": "chat.completion",
|
|
|
|
| 126 |
],
|
| 127 |
"usage": None,
|
| 128 |
}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|