api-smollm135m / app.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List
import os
from huggingface_hub import InferenceClient
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Message(BaseModel):
role: str = Field(..., description="Role of the message sender (system/user/assistant)")
    content: str = Field(..., description="Content of the message")


class ChatInput(BaseModel):
messages: List[Message] = Field(..., description="List of conversation messages")
max_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
temperature: float = Field(default=0.5, gt=0, le=2.0, description="Temperature for sampling")
top_p: float = Field(default=0.7, gt=0, le=1.0, description="Top-p sampling parameter")
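

# Example request body accepted by /generate (illustrative values; the field names
# and defaults come from the ChatInput model above):
#
#   {
#     "messages": [{"role": "user", "content": "Hello!"}],
#     "max_tokens": 512,
#     "temperature": 0.5,
#     "top_p": 0.7
#   }
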
app = FastAPI()

# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Initialize the Hugging Face Inference client (expects HF_TOKEN in the environment)
hf_client = InferenceClient(
    api_key=os.getenv("HF_TOKEN"),
    timeout=30
)

MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"


def generate_stream(messages: List[Message], max_tokens: int, temperature: float, top_p: float):
    """Generate a streaming response using the Hugging Face Inference API.

    Implemented as a plain (synchronous) generator: StreamingResponse iterates
    sync generators in a threadpool, so the blocking InferenceClient call does
    not stall the event loop.
    """
try:
# Convert messages to the format expected by the API
formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
# Create the streaming completion
stream = hf_client.chat.completions.create(
model=MODEL_ID,
messages=formatted_messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stream=True
)
# Stream the response chunks
for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except Exception as e:
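        # Once streaming has started, errors can no longer be turned into a clean
        # HTTP error response; log them and re-raise so the stream terminates.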
logger.error(f"Error in generate_stream: {e}", exc_info=True)
raise ValueError(f"Error generating response: {e}")


@app.post("/generate")
async def chat_stream(input: ChatInput, request: Request):
"""Stream chat completions based on the input messages."""
try:
if not os.getenv("HF_TOKEN"):
raise HTTPException(
status_code=500,
detail="HF_TOKEN environment variable not set"
)
logger.info(f"Received chat request from {request.client.host}")
logger.info(f"Number of messages: {len(input.messages)}")
return StreamingResponse(
generate_stream(
messages=input.messages,
max_tokens=input.max_tokens,
temperature=input.temperature,
top_p=input.top_p
),
media_type="text/event-stream"
)
    except HTTPException:
        # Re-raise HTTPExceptions (e.g. missing HF_TOKEN) as-is instead of wrapping them in a 500
        raise
    except Exception as e:
        logger.error(f"Error in chat_stream endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
"""Root endpoint that returns a welcome message."""
return {
"message": "Welcome to the Hugging Face Inference API Streaming Chat!",
"status": "running",
"model": MODEL_ID
}


@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"model": MODEL_ID,
"hf_token_set": bool(os.getenv("HF_TOKEN"))
}


if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
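
# Example client call (illustrative sketch, not part of the original app; assumes the
# server is running locally on port 8000 and that `requests` is installed):
#
#   import requests
#
#   payload = {"messages": [{"role": "user", "content": "Hello!"}]}
#   with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
#       resp.raise_for_status()
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)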