File size: 1,248 Bytes
b238a3c
 
 
 
 
 
 
 
2267429
b238a3c
2267429
b238a3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# ASGI application object; the route decorators further down register on it
app = FastAPI()

# Hub checkpoint used for both the tokenizer and the model (DistilGPT-2, a smaller GPT-2)
model_name = "distilgpt2"

# Tokenizer that matches the checkpoint above
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Causal language model; from_tf=True requests loading from the TensorFlow weights.
# NOTE(review): this presumably requires TensorFlow to be installed at runtime -
# confirm whether the checkpoint's native PyTorch weights could be used instead.
model = AutoModelForCausalLM.from_pretrained(model_name, from_tf=True)

# Pydantic model for the JSON request body of the /generate/ route.
# Declared as the handler's parameter type, so FastAPI parses the body into it.
class TextRequest(BaseModel):
    # Prompt text; the handler tokenizes it and passes it to the model.
    text: str

# Upper bound on the total sequence length; in transformers `max_length`
# counts the prompt tokens plus the newly generated tokens.
_MAX_LENGTH = 100


def _generate_ids(encoded):
    """Run the blocking generation step and return the generated token ids.

    Intended to be executed in a worker thread so the event loop stays free.

    Args:
        encoded: Tokenizer output providing ``input_ids`` and
            ``attention_mask`` tensors for a single prompt.

    Returns:
        Whatever ``model.generate`` returns for the prompt (one sequence).
    """
    # PyTorch's grad mode is thread-local, so no_grad must be entered in the
    # thread that actually executes generate(), not in the calling coroutine.
    with torch.no_grad():
        return model.generate(
            encoded["input_ids"],
            # Passing the mask explicitly avoids transformers having to guess it.
            attention_mask=encoded["attention_mask"],
            max_length=_MAX_LENGTH,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # top_p / top_k are only honoured when sampling is enabled;
            # without do_sample=True decoding is greedy and both are ignored.
            do_sample=True,
            top_p=0.9,
            top_k=50,
            # GPT-2 has no dedicated pad token; reuse EOS so generate() does
            # not have to fall back to it with a warning on every request.
            pad_token_id=tokenizer.eos_token_id,
        )


# Route to generate text
@app.post("/generate/")
async def generate_text(request: TextRequest):
    """Continue the submitted prompt with the language model.

    The prompt is tokenized, extended by sampling (top-k / nucleus) up to a
    total of ``_MAX_LENGTH`` tokens, and decoded back to text. Because
    sampling is used, repeated calls with the same prompt may differ.

    Args:
        request: Request body carrying the prompt in ``text``.

    Returns:
        A dict ``{"generated_text": <decoded sequence>}``.

    Raises:
        HTTPException: 422 when the prompt tokenizes to nothing, or when it
            already fills the whole ``_MAX_LENGTH`` token budget.
    """
    # Imported locally so this handler does not depend on changes to the
    # file-level import block.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    # Calling the tokenizer (rather than .encode) also yields the attention mask.
    encoded = tokenizer(request.text, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[-1]

    # An empty prompt produces a zero-length tensor that generate() cannot extend.
    if prompt_len == 0:
        raise HTTPException(
            status_code=422,
            detail="`text` must contain at least one token.",
        )

    # The prompt counts against max_length, so there must be room left to generate.
    if prompt_len >= _MAX_LENGTH:
        raise HTTPException(
            status_code=422,
            detail=(
                f"`text` is {prompt_len} tokens long; it must be shorter than "
                f"{_MAX_LENGTH} tokens."
            ),
        )

    # Generation is CPU-bound and blocking; run it off the event loop so other
    # requests (e.g. the health check) are still served meanwhile.
    outputs = await run_in_threadpool(_generate_ids, encoded)

    # Decode the generated response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": response}

# Root endpoint usable as a simple server health check:
# GET / always answers with the same fixed greeting.
@app.get("/")
async def read_root():
    greeting = {"message": "Welcome to the GPT-2 FastAPI server!"}
    return greeting