# api-smollm135m / app.py
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List
import os
from huggingface_hub import InferenceClient
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Message(BaseModel):
role: str = Field(..., description="Role of the message sender (system/user/assistant)")
content: str = Field(..., description="Content of the message")
class ChatInput(BaseModel):
messages: List[Message] = Field(..., description="List of conversation messages")
max_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
temperature: float = Field(default=0.5, gt=0, le=2.0, description="Temperature for sampling")
top_p: float = Field(default=0.7, gt=0, le=1.0, description="Top-p sampling parameter")
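# Example request body for /generate (illustrative only; the field names follow the
# ChatInput model above, and max_tokens/temperature/top_p fall back to the defaults
# declared there when omitted):
#
#   {
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Summarize FastAPI in one sentence."}
#       ],
#       "max_tokens": 512,
#       "temperature": 0.5,
#       "top_p": 0.7
#   }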
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize the Hugging Face Inference client; the same model id is reused by the
# root and health endpoints so they report the model actually being served
MODEL_ID = os.getenv("MODEL_ID", "mistralai/Mistral-Nemo-Instruct-2407")

hf_client = InferenceClient(
    model=MODEL_ID,
    token=os.getenv("HF_TOKEN"),
    timeout=30,
)
async def generate_stream(messages: List[Message], max_tokens: int, temperature: float, top_p: float):
    """Generate a streaming response using the Hugging Face Inference API."""
    try:
        # Convert messages to the list-of-dicts format expected by the API
        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        # chat_completion accepts a message list directly (text_generation expects a
        # single prompt string), so use it for chat-style streaming
        for chunk in hf_client.chat_completion(
            messages=formatted_messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            # Each streamed chunk carries an incremental text delta
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta
    except Exception as e:
        logger.error(f"Error in generate_stream: {e}", exc_info=True)
        raise ValueError(f"Error generating response: {e}")
@app.post("/generate")
async def chat_stream(input: ChatInput, request: Request):
"""Stream chat completions based on the input messages."""
try:
if not os.getenv("HF_TOKEN"):
raise HTTPException(
status_code=500,
detail="HF_TOKEN environment variable not set"
)
logger.info(f"Received chat request from {request.client.host}")
logger.info(f"Number of messages: {len(input.messages)}")
return StreamingResponse(
generate_stream(
messages=input.messages,
max_tokens=input.max_tokens,
temperature=input.temperature,
top_p=input.top_p
),
media_type="text/event-stream"
)
    except HTTPException:
        # Re-raise HTTPExceptions (e.g. the missing-token error above) unchanged
        raise
    except Exception as e:
        logger.error(f"Error in chat_stream endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
"""Root endpoint that returns a welcome message."""
return {
"message": "Welcome to the Hugging Face Inference API Streaming Chat!",
"status": "running",
"model": MODEL_ID
}
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"model": MODEL_ID,
"hf_token_set": bool(os.getenv("HF_TOKEN"))
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
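# Example client (illustrative sketch, not part of the app): consume the token stream
# from /generate with `requests`. The URL assumes the default host/port configured in
# the __main__ block above.
#
#   import requests
#
#   payload = {"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 128}
#   with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
#       resp.raise_for_status()
#       for text in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(text, end="", flush=True)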