# helperAi/app/main.py
from fastapi import FastAPI, Request
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from .config import settings
from pydantic import BaseModel
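# `settings` above comes from app/config.py, which is not part of this file.
# A minimal sketch of what it plausibly looks like, assuming pydantic-settings;
# the field names MODEL_NAME and HUGGINGFACE_TOKEN are taken from their use
# below, but everything else here is an assumption:
#
#     from pydantic_settings import BaseSettings
#
#     class Settings(BaseSettings):
#         MODEL_NAME: str  # e.g. a DeepSeek checkpoint on the Hugging Face Hub
#         HUGGINGFACE_TOKEN: str = ""
#
#     settings = Settings()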
app = FastAPI(
    title="Deepseek Chat API",
    description="A simple chat API using the DeepSeek model",
    version="1.0.0",
)
# Mount static files and templates
app.mount("/static", StaticFiles(directory="app/static"), name="static")
templates = Jinja2Templates(directory="app/templates")
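# Note: app/static must exist on disk (StaticFiles checks the directory at
# startup and raises if it is missing), and the home route below expects
# app/templates/chat.html to be present.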
# Initialize model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME, token=settings.HUGGINGFACE_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    settings.MODEL_NAME,
    token=settings.HUGGINGFACE_TOKEN,
    torch_dtype=torch.float16,
    device_map="auto",
)
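# Note: device_map="auto" requires the `accelerate` package, and
# torch_dtype=torch.float16 assumes GPU hardware; on a CPU-only machine you
# would typically load in float32 and drop device_map.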
class ChatMessage(BaseModel):
    message: str
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse("chat.html", {"request": request})
@app.post("/chat")
def chat(message: ChatMessage):
    # A plain `def` (not `async def`) lets FastAPI run this handler in a
    # threadpool: model.generate() is blocking and would otherwise stall the
    # event loop for the duration of generation.
    # Prepare the instruction-style prompt
    prompt = f"### Instruction: {message.message}\n\n### Response:"

    # Generate a response
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text includes the prompt; keep only what follows the marker
    response = response.split("### Response:")[-1].strip()
    return {"response": response}
if __name__ == "__main__":
    # Run as a module (`python -m app.main`) or via `uvicorn app.main:app`;
    # the relative `.config` import fails if this file is executed directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
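# Example request (a sketch, assuming the server is running locally on the
# default port 7860):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Explain FastAPI in one sentence."}'
#
# which returns JSON of the form {"response": "..."}.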