from fastapi import FastAPI, Request
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from .config import settings
from pydantic import BaseModel
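
# The .config module is assumed to expose a `settings` object with
# MODEL_NAME and HUGGINGFACE_TOKEN attributes. A minimal sketch of what
# app/config.py might look like (hypothetical; the default model name is
# only an example):
#
#     from pydantic_settings import BaseSettings
#
#     class Settings(BaseSettings):
#         MODEL_NAME: str = "deepseek-ai/deepseek-llm-7b-chat"
#         HUGGINGFACE_TOKEN: str = ""
#
#     settings = Settings()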

app = FastAPI(
    title="Deepseek Chat API",
    description="A simple chat API using DeepSeek model",
    version="1.0.0"
)

# Mount static files and templates
app.mount("/static", StaticFiles(directory="app/static"), name="static")
templates = Jinja2Templates(directory="app/templates")
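# (Both directories are resolved relative to the working directory, so the
# server is assumed to be launched from the project root.)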

# Initialize model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME, token=settings.HUGGINGFACE_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    settings.MODEL_NAME,
    token=settings.HUGGINGFACE_TOKEN,
    torch_dtype=torch.float16,
    device_map="auto"
)
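# Note: device_map="auto" requires the `accelerate` package, and together
# with torch.float16 it assumes a GPU is available; on a CPU-only machine,
# dropping both arguments and loading in default precision is safer.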

class ChatMessage(BaseModel):
    message: str

@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse("chat.html", {"request": request})

@app.post("/chat")
async def chat(message: ChatMessage):
    # Prepare the prompt
    prompt = f"### Instruction: {message.message}\n\n### Response:"
    
    # Generate response
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract only the response part
    response = response.split("### Response:")[-1].strip()
    
    return {"response": response}

if __name__ == "__main__":
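    # Note: the relative `.config` import means this file should be run as
    # a module (e.g. `python -m app.main`, assuming this file is
    # app/main.py) or via `uvicorn app.main:app`, not executed directly.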
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)