from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Optional, Literal
import json
import g4f
from g4f.Provider import RetryProvider

app = FastAPI()
# Model list organized by provider (based on models exposed through g4f)
MODELS = {
    # OpenAI
    "openai": [
        "gpt-4o", "gpt-4o-mini", "gpt-4", "gpt-4-turbo",
        "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
    ],
    # Anthropic
    "anthropic": [
        "claude-3-opus", "claude-3-sonnet", "claude-3-haiku",
        "claude-3.5", "claude-3.7-sonnet", "claude-2.1"
    ],
    # Google
    "google": [
        "gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash",
        "gemini-2.5-pro-exp-03-25"
    ],
    # Meta
    "meta": [
        "llama-3-70b", "llama-3-8b", "llama-3.1-405b",
        "llama-2-70b", "llama-2-13b", "llama-2-7b"
    ],
    # xAI (Grok)
    "xai": [
        "grok-1", "grok-1.5", "grok-2", "grok-3"
    ],
    # Other
    "other": [
        "o1", "o3-mini", "mistral-7b", "mixtral-8x7b",
        "command-r-plus", "deepseek-chat", "code-llama-70b"
    ],
    # Image Models
    "image": [
        "dall-e-3", "stable-diffusion-xl",
        "flux", "flux-pro", "playground-v2.5"
    ]
}
# Flattened list of chat models for endpoint validation
ALL_MODELS = [
    *MODELS["openai"],
    *MODELS["anthropic"],
    *MODELS["google"],
    *MODELS["meta"],
    *MODELS["xai"],
    *MODELS["other"]
]
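# Note: image models are deliberately left out of ALL_MODELS;
# they are validated separately by the image generation endpoint below.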
# Pydantic Models
class Message(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 0.9
    stream: Optional[bool] = True

class ModelListResponse(BaseModel):
    openai: List[str]
    anthropic: List[str]
    google: List[str]
    meta: List[str]
    xai: List[str]
    other: List[str]
    image: List[str]
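# Example ChatRequest body (illustration only):
# {
#   "model": "gpt-4o",
#   "messages": [{"role": "user", "content": "Hello"}],
#   "temperature": 0.7,
#   "stream": true
# }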
# API Endpoints
@app.get("/models", response_model=ModelListResponse)  # route path assumed; the original lacked a decorator
async def get_models():
    """Get all available models categorized by provider."""
    return ModelListResponse(**MODELS)
@app.post("/chat")  # route path assumed; the original lacked a decorator
async def chat_completion(request: ChatRequest):
    """Handle chat completion requests."""
    if request.model not in ALL_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model. Available: {ALL_MODELS}"
        )
    messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
    try:
        if request.stream:
            async def stream_generator():
                # stream=True is assumed to make create_async yield chunks
                # instead of returning a complete string; without it, the
                # async-for below would fail. Support varies by g4f version.
                response = await g4f.ChatCompletion.create_async(
                    model=request.model,
                    messages=messages,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    max_tokens=request.max_tokens,
                    stream=True,
                    provider=RetryProvider([g4f.Provider.BackendApi])
                )
                async for chunk in response:
                    yield f"data: {json.dumps({'content': str(chunk)})}\n\n"
                yield "data: [DONE]\n\n"
            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            response = await g4f.ChatCompletion.create_async(
                model=request.model,
                messages=messages,
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_tokens,  # kept consistent with the streaming branch
                provider=RetryProvider([g4f.Provider.BackendApi])
            )
            return {"content": str(response)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
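# Example chat request (a sketch; uses the /chat path assumed above):
#   curl -N -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}]}'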
# Image Generation Endpoint
@app.post("/images/generate")  # route path assumed; the original lacked a decorator
async def generate_image(prompt: str, model: str = "dall-e-3"):
    if model not in MODELS["image"]:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image model. Available: {MODELS['image']}"
        )
    try:
        if model in ["flux", "flux-pro"]:
            # g4f.ImageGeneration availability varies across g4f versions
            image_data = g4f.ImageGeneration.create(
                prompt=prompt,
                model=model,
                provider=g4f.Provider.BackendApi
            )
            # image_data is expected to be base64-encoded bytes
            return {"url": f"data:image/png;base64,{image_data.decode('utf-8')}"}
        else:
            # Implementation for other image providers is still pending
            raise HTTPException(
                status_code=501,
                detail=f"{model} implementation pending"
            )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
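# Example image request (a sketch; uses the /images/generate path assumed above):
#   curl -X POST "http://localhost:8000/images/generate?prompt=a+red+fox&model=flux"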
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
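# To run (a sketch; package names assumed to match the imports above):
#   pip install fastapi uvicorn g4f
#   python app.py
# Then list the available models:
#   curl http://localhost:8000/models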