from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List
import os
from huggingface_hub import InferenceClient
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Message(BaseModel):
    role: str = Field(..., description="Role of the message sender (system/user/assistant)")
    content: str = Field(..., description="Content of the message")


class ChatInput(BaseModel):
    messages: List[Message] = Field(..., description="List of conversation messages")
    max_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
    temperature: float = Field(default=0.5, gt=0, le=2.0, description="Temperature for sampling")
    top_p: float = Field(default=0.7, gt=0, le=1.0, description="Top-p sampling parameter")


app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model is configurable via the MODEL_ID environment variable
MODEL_ID = os.getenv("MODEL_ID", "mistralai/Mistral-Nemo-Instruct-2407")

# Initialize Hugging Face client
hf_client = InferenceClient(
    model=MODEL_ID,
    token=os.getenv("HF_TOKEN"),
    timeout=30,
)


async def generate_stream(messages: List[Message], max_tokens: int, temperature: float, top_p: float):
    """Generate a streaming response using the Hugging Face Inference API."""
    try:
        # Convert messages to the plain-dict format expected by the API
        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        # Stream the response chunks. chat_completion accepts role/content messages
        # directly, whereas text_generation expects a single prompt string.
        for chunk in hf_client.chat_completion(
            messages=formatted_messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            delta = chunk.choices[0].delta.content
            if delta is not None:
                yield delta
    except Exception as e:
        logger.error(f"Error in generate_stream: {e}", exc_info=True)
        raise ValueError(f"Error generating response: {e}")


@app.post("/generate")
async def chat_stream(input: ChatInput, request: Request):
    """Stream chat completions based on the input messages."""
    # Fail fast with a clear message if the token is missing, before entering the
    # generic error handler below.
    if not os.getenv("HF_TOKEN"):
        raise HTTPException(
            status_code=500,
            detail="HF_TOKEN environment variable not set"
        )

    try:
        logger.info(f"Received chat request from {request.client.host}")
        logger.info(f"Number of messages: {len(input.messages)}")

        return StreamingResponse(
            generate_stream(
                messages=input.messages,
                max_tokens=input.max_tokens,
                temperature=input.temperature,
                top_p=input.top_p
            ),
            media_type="text/event-stream"
        )
    except Exception as e:
        logger.error(f"Error in chat_stream endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {
        "message": "Welcome to the Hugging Face Inference API Streaming Chat!",
        "status": "running",
        "model": MODEL_ID
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model": MODEL_ID,
        "hf_token_set": bool(os.getenv("HF_TOKEN"))
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
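
# Example client usage (illustrative sketch, not part of the server): it assumes the
# server is running locally on http://localhost:8000, that HF_TOKEN is set in the
# server's environment, and that httpx is installed on the client side. It streams
# the plain-text chunks from /generate as they arrive.
#
#   import httpx
#
#   payload = {
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "max_tokens": 256,
#       "temperature": 0.5,
#       "top_p": 0.7,
#   }
#   with httpx.stream("POST", "http://localhost:8000/generate", json=payload, timeout=60) as resp:
#       for text in resp.iter_text():
#           print(text, end="", flush=True)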