# vision / app.py
# Uploaded by abdullahalioo; commit 6e02eb7 ("Update app.py", verified); 2.42 kB
import logging
import os

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from openai import AsyncOpenAI
from pydantic import BaseModel
# Create the FastAPI application; the @app.post / @app.get decorators in this
# file register their handlers on this object.
app = FastAPI()
# Request body schema accepted by POST /generate.
class PromptRequest(BaseModel):
    # Text that is sent as the user message of the chat completion.
    prompt: str
# Configure the OpenAI-compatible client from environment variables.
# GITHUB_TOKEN is mandatory: fail at import time so a misconfigured deployment
# does not start up and then reject every request.
token = os.getenv("GITHUB_TOKEN")
if not token:
    raise ValueError("GITHUB_TOKEN environment variable not set")

# API_ENDPOINT and MODEL_NAME optionally override the defaults below.
# NOTE(review): confirm that the default endpoint is an OpenAI-compatible base
# URL for GitHub Models; the original comment said it may need adjusting.
endpoint = os.getenv("API_ENDPOINT", "https://api.github.com/models")
model = os.getenv("MODEL_NAME", "gpt-4o-mini")

# No custom http_client is passed (None is the constructor default), so the
# SDK creates its own HTTP client. This does not disable proxy handling.
client = AsyncOpenAI(base_url=endpoint, api_key=token)
# Module-level logger used to record upstream failures on the server side.
logger = logging.getLogger(__name__)


# Async generator to stream chunks
async def stream_response(prompt: str):
    """Stream the assistant's reply to *prompt* as plain-text fragments.

    Issues a streaming chat-completion request through the module-level
    ``client`` using ``model`` and yields each non-empty piece of delta
    content as it arrives.

    Args:
        prompt: User message, placed after the fixed system message.

    Yields:
        Non-empty content strings taken from the first choice of each chunk.
        If creating or consuming the stream raises, the exception is logged
        and a final ``"Error: <message>"`` string is yielded instead of the
        exception propagating to the caller.
    """
    try:
        stream = await client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            temperature=1.0,
            top_p=1.0,
            model=model,
            stream=True,
        )
        async for chunk in stream:
            # A chunk may arrive without any choices; there is nothing to emit.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            # Skip None and empty strings so consumers only see real text.
            if content:
                yield content
    except Exception as err:
        # The failure is reported in-band because this generator feeds a
        # streaming response; log it too so it is visible on the server.
        logger.exception("Streaming chat completion failed")
        yield f"Error: {err!s}"
# POST /generate: stream the model's reply to the submitted prompt.
# NOTE(review): the body is raw text fragments, not "data: ..." SSE frames,
# even though the declared media type is text/event-stream.
@app.post("/generate")
async def generate_response(request: PromptRequest):
    try:
        # Only construction of the response happens inside this try; the
        # generator is consumed after this handler has returned.
        chunks = stream_response(request.prompt)
        streaming = StreamingResponse(chunks, media_type="text/event-stream")
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Server error: {str(err)}")
    return streaming
# Health check endpoint for Hugging Face Spaces
@app.get("/")
async def health_check():
    # Fixed payload: nothing is probed, it only shows the app is responding.
    status_report = {"status": "healthy"}
    return status_report