# frantics-bot / server.py
# Last change: "Update server.py" by TeePoat (commit f7a548c, verified).
import os
import secrets

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from models.transformer.text_generator import TextGenerator
# FastAPI application instance; the routes below are registered on it.
app = FastAPI()

# One generator instance shared by every request. It is constructed at import
# time (presumably loading the model weights here — confirm in TextGenerator).
generator = TextGenerator(model_name="frantics-bot-model")

# Security scheme that parses an "Authorization: Bearer <token>" header into
# HTTPAuthorizationCredentials for the auth dependency.
security = HTTPBearer()

# Shared secret that clients must present; None when the variable is unset.
BEARER_TOKEN = os.getenv("BEARER_TOKEN")
async def validate_token(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> str:
    """Reject the request unless it presents the configured bearer token.

    FastAPI dependency for protected endpoints. ``security`` (an
    ``HTTPBearer`` scheme) extracts the token from the ``Authorization``
    header, and this function compares it with ``BEARER_TOKEN``.

    Args:
        credentials: Parsed bearer credentials supplied by ``security``.

    Returns:
        The validated token string.

    Raises:
        HTTPException: 401 when ``BEARER_TOKEN`` is unset or empty, or when
            the presented token does not match it.
    """
    presented = credentials.credentials
    # Fail closed: with no configured secret, no token can be valid.
    # compare_digest does not reveal the secret's contents through response
    # timing. The inputs are encoded to bytes because compare_digest raises
    # TypeError on non-ASCII str input.
    if not BEARER_TOKEN or not secrets.compare_digest(
        presented.encode("utf-8"), BEARER_TOKEN.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return presented
class Message(BaseModel):
    """Request body accepted by ``POST /generate``."""

    # Author of the message; passed to the generator as ``author``.
    author: str
    # Message text to reply to; passed to the generator as ``input_str``.
    content: str
@app.post("/generate")
def generate_response(
    message: Message,
    token: str = Depends(validate_token),
):
    """Generate a bot reply to ``message``.

    Authentication is delegated to the ``validate_token`` dependency, which
    FastAPI runs before this body. That dependency may abort the request
    with an ``HTTPException``.

    Args:
        message: Request body carrying the message author and content.
        token: Result of the auth dependency. It is unused here; the
            parameter exists only to attach ``validate_token`` to the route.

    Returns:
        ``{"response": <text>}``, where the text is the first generated
        sequence cut at the first ``"</s>"`` marker, if one is present.
    """
    # Sampling settings are fixed for this endpoint.
    result = generator.generate_text(
        author=message.author,
        input_str=message.content,
        max_length=100,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.8,
        top_k=100,
        top_p=0.95,
    )
    # Assumes generate_text returns a dict whose "generated_texts" list holds
    # one string per returned sequence. TODO: confirm against TextGenerator.
    response = result["generated_texts"][0]
    # Keep only the text before the first "</s>" (presumably the model's
    # end-of-sequence marker). partition() returns the whole string when the
    # marker is absent. The old slice up to str.find() got -1 in that case
    # and silently dropped the last character.
    response = response.partition("</s>")[0]
    return {"response": response}