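"""FastAPI app that streams chat completions from the Hugging Face Inference API."""
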
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List
import os
from huggingface_hub import InferenceClient
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Message(BaseModel):
    role: str = Field(..., description="Role of the message sender (system/user/assistant)")
    content: str = Field(..., description="Content of the message")

class ChatInput(BaseModel):
    messages: List[Message] = Field(..., description="List of conversation messages")
    max_tokens: int = Field(default=2048, gt=0, le=4096, description="Maximum number of tokens to generate")
    temperature: float = Field(default=0.5, gt=0, le=2.0, description="Temperature for sampling")
    top_p: float = Field(default=0.7, gt=0, le=1.0, description="Top-p sampling parameter")
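
# Illustrative request body for /generate (a sketch; the values are arbitrary examples):
# {
#     "messages": [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Explain streaming responses in one sentence."}
#     ],
#     "max_tokens": 512,
#     "temperature": 0.5,
#     "top_p": 0.7
# }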

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model can be overridden via the MODEL_ID environment variable
MODEL_ID = os.getenv("MODEL_ID", "mistralai/Mistral-Nemo-Instruct-2407")

# Initialize the Hugging Face Inference client
hf_client = InferenceClient(
    model=MODEL_ID,
    token=os.getenv("HF_TOKEN"),
    timeout=30
)

async def generate_stream(messages: List[Message], max_tokens: int, temperature: float, top_p: float):
    """Generate streaming response using Hugging Face Inference API."""
    try:
        # text_generation expects a single prompt string, so flatten the
        # conversation into a plain "role: content" transcript
        prompt = "\n".join(f"{msg.role}: {msg.content}" for msg in messages)

        # Stream generated tokens as they arrive
        for chunk in hf_client.text_generation(
            prompt=prompt,
            details=True,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            stream=True,
        ):
            if chunk.token.text is not None:
                yield chunk.token.text

    except Exception as e:
        logger.error(f"Error in generate_stream: {e}", exc_info=True)
        raise ValueError(f"Error generating response: {e}")

@app.post("/generate")
async def chat_stream(input: ChatInput, request: Request):
    """Stream chat completions based on the input messages."""
    try:
        if not os.getenv("HF_TOKEN"):
            raise HTTPException(
                status_code=500,
                detail="HF_TOKEN environment variable not set"
            )

        logger.info(f"Received chat request from {request.client.host}")
        logger.info(f"Number of messages: {len(input.messages)}")

        return StreamingResponse(
            generate_stream(
                messages=input.messages,
                max_tokens=input.max_tokens,
                temperature=input.temperature,
                top_p=input.top_p
            ),
            media_type="text/event-stream"
        )
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. missing HF_TOKEN) unchanged
        raise
    except Exception as e:
        logger.error(f"Error in chat_stream endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    """Root endpoint that returns a welcome message."""
    return {
        "message": "Welcome to the Hugging Face Inference API Streaming Chat!",
        "status": "running",
        "model": MODEL_ID
    }

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model": MODEL_ID,
        "hf_token_set": bool(os.getenv("HF_TOKEN"))
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
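
# Example client (a sketch, assuming the server runs on http://localhost:8000 and
# the `requests` package is installed; chunk boundaries depend on the upstream API):
#
#     import requests
#
#     payload = {
#         "messages": [{"role": "user", "content": "Tell me a short joke."}],
#         "max_tokens": 256,
#     }
#     with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
#         resp.raise_for_status()
#         for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#             print(chunk, end="", flush=True)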