FastAPIExample / main.py
Akbartus's picture
Update main.py
5e9db5b verified
raw
history blame
448 Bytes
import uvicorn
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import pipeline, set_seed
# The ASGI application object; HTTP routes are attached to it with
# decorators such as @app.post.
app: FastAPI = FastAPI()
# Request body model for the POST /generate/ route; FastAPI validates the
# incoming JSON against these fields before calling the handler.
# NOTE(review): a class docstring here would be exported by pydantic as the
# model's description in the OpenAPI schema, so prefer `#` comments unless
# that is intended.
class Item(BaseModel):
    # Prompt text passed to the text-generation pipeline.
    prompt: str
# Build the GPT-2 text-generation pipeline once, at import time, so every
# request reuses the same loaded model instead of reloading it per call.
generator = pipeline(task="text-generation", model="gpt2")
# Seed the random number generators once for the whole process (not per
# request): sampled outputs are repeatable across restarts only if requests
# arrive in the same order. Kept after model loading to preserve the original
# RNG state sequence.
set_seed(42)
def generate(item: Item, *, max_length: int = 30, num_return_sequences: int = 5):
    """Run the module-level text-generation pipeline on ``item.prompt``.

    Args:
        item: Validated request body; only its ``prompt`` field is used.
        max_length: Forwarded to the pipeline call. Defaults to 30, the
            previously hard-coded value. Per the transformers generation docs
            this presumably bounds prompt + generated tokens together, so long
            prompts leave little room for new text -- confirm for the
            installed version.
        num_return_sequences: Number of continuations to request from the
            pipeline. Defaults to 5, the previously hard-coded value.

    Returns:
        The pipeline's output, unchanged. Per the transformers docs this is
        presumably a list of dicts with a ``"generated_text"`` key -- verify
        against the installed version.
    """
    # The result must be returned: without the ``return`` the function yields
    # None and the HTTP response body becomes {"response": null}.
    return generator(
        item.prompt,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
    )
@app.post("/generate/")
async def generate_text(item: Item):
    """Handle POST /generate/ by generating text from the request's prompt.

    Args:
        item: JSON request body, validated by FastAPI against ``Item``.

    Returns:
        A dict ``{"response": ...}`` wrapping whatever ``generate`` returns;
        FastAPI serializes it to JSON.
    """
    # ``generate`` performs synchronous model inference. Calling it directly
    # inside an ``async def`` would block the event loop for the full
    # duration of generation, stalling every other in-flight request. Run it
    # in the worker thread pool instead and await the result.
    result = await run_in_threadpool(generate, item)
    return {"response": result}