from fastapi import FastAPI, Request
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pydantic import BaseModel

from .config import settings

app = FastAPI(
    title="DeepSeek Chat API",
    description="A simple chat API using a DeepSeek model",
    version="1.0.0",
)

# Mount static files and templates
app.mount("/static", StaticFiles(directory="app/static"), name="static")
templates = Jinja2Templates(directory="app/templates")

# Initialize model and tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained(
    settings.MODEL_NAME,
    token=settings.HUGGINGFACE_TOKEN,
)
model = AutoModelForCausalLM.from_pretrained(
    settings.MODEL_NAME,
    token=settings.HUGGINGFACE_TOKEN,
    torch_dtype=torch.float16,
    device_map="auto",
)


class ChatMessage(BaseModel):
    message: str


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    # Serve the chat UI
    return templates.TemplateResponse("chat.html", {"request": request})


@app.post("/chat")
async def chat(message: ChatMessage):
    # Prepare an instruction-style prompt from the user's message
    prompt = f"### Instruction: {message.message}\n\n### Response:"

    # Tokenize and generate on the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the response part (text after the "### Response:" marker)
    response = response.split("### Response:")[-1].strip()

    return {"response": response}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
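
# ---------------------------------------------------------------------------
# Assumed supporting module: the `from .config import settings` import above
# implies an app/config.py roughly like the sketch below. Only the field names
# MODEL_NAME and HUGGINGFACE_TOKEN come from their usage in this file; the
# pydantic-settings layout and the default model id are illustrative.
#
#     from pydantic_settings import BaseSettings, SettingsConfigDict
#
#     class Settings(BaseSettings):
#         model_config = SettingsConfigDict(env_file=".env")
#
#         MODEL_NAME: str = "deepseek-ai/deepseek-llm-7b-chat"  # illustrative default
#         HUGGINGFACE_TOKEN: str = ""  # supplied via environment or .env
#
#     settings = Settings()
#
# Example request once the server is running (port 7860, as configured above):
#
#     curl -X POST http://localhost:7860/chat \
#          -H "Content-Type: application/json" \
#          -d '{"message": "Explain what this API does."}'
# ---------------------------------------------------------------------------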