FastAPIExample / main.py
Akbartus's picture
Update main.py
a95c82c verified
raw
history blame
515 Bytes
import transformers
import torch
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
# FastAPI application instance; routes are registered on it below.
app = FastAPI()
# Hugging Face model identifier for the text-generation backend.
# NOTE(review): this is a gated repo — requires an accepted license and an
# HF auth token at download time; confirm the deployment provides one.
model_id = "meta-llama/Meta-Llama-3-8B"
# Build the generation pipeline once at import time (the model download/load
# is expensive, so it must not happen per-request). bfloat16 halves memory
# versus fp32; device_map="auto" lets accelerate place weights on available
# devices (GPU if present, otherwise CPU).
pipeline = transformers.pipeline(
"text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
)
class Item(BaseModel):
    """Request body for the ``/generate/`` endpoint."""

    # Free-form text prompt passed verbatim to the generation pipeline.
    prompt: str
def generate(item: Item):
    """Run the text-generation pipeline on the request prompt.

    Parameters:
        item: Validated request body carrying the text prompt.

    Returns:
        The generated text string.

    Bug fix: the original body called ``pipeline(item.prompt)`` but never
    returned the result, so every response was ``{"response": null}``.
    """
    outputs = pipeline(item.prompt)
    # A transformers "text-generation" pipeline returns a list of dicts,
    # each with a "generated_text" key; surface the first completion.
    return outputs[0]["generated_text"]
@app.post("/generate/")
async def generate_text(item: Item):
    """POST /generate/ — delegate to ``generate`` and wrap its result in JSON."""
    generated = generate(item)
    return {"response": generated}