# bevelapi / main.py
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/blenderbot-1B-distill"
# https://huggingface.co/models?sort=trending&search=facebook%2Fblenderbo
# facebook/blenderbot-3B
# facebook/blenderbot-1B-distill
# facebook/blenderbot-400M-distill
# facebook/blenderbot-90M
# facebook/blenderbot_small-90M
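# (General guidance, not specific to this repo) The larger checkpoints usually
# give better replies but need more memory; pick the largest one that fits the
# host's RAM/VRAM.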

app = FastAPI()

# Load the model and tokenizer once at import time; the first run downloads the weights.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
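
# Optional sketch (assumes torch is installed, which transformers requires for
# this PyTorch model): run generation on a GPU when one is present.
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)  # input_ids below would then also need .to(device)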

class Req(BaseModel):
    prompt: str

@app.get("/")
def read_root():
    # Serve the chat UI
    return FileResponse(path="templates/index.html", media_type="text/html")

@app.post("/api")
def generate_answer(data: Req):
    print("Prompt:", data.prompt)
    input_text = data.prompt

    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors="pt")

    # Generate a reply with beam search; no_repeat_ngram_size=2 prevents any bigram from repeating
    output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    answer_data = {"answer": generated_text}
    print("Answer:", generated_text)
    return answer_data
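
# Usage sketch (assumptions: uvicorn is installed and this file is main.py;
# 7860 is the conventional Hugging Face Spaces port, not confirmed by the file):
#   uvicorn main:app --host 0.0.0.0 --port 7860
#
# Example request against the /api endpoint:
#   curl -X POST http://localhost:7860/api \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello there!"}'
# -> {"answer": "..."}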