from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained DistilGPT-2 model and tokenizer (PyTorch weights are
# available on the Hub, so from_tf is not needed for a torch-based server)
model_name = "distilgpt2"  # Smaller, distilled variant of GPT-2
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()  # Inference mode: disables dropout

# Pydantic model for the request body
class TextRequest(BaseModel):
    text: str

# Route to generate text
@app.post("/generate/")
async def generate_text(request: TextRequest):
    # Encode the input text
    inputs = tokenizer.encode(request.text, return_tensors="pt")

    # Generate a response from the model. do_sample=True is required for
    # top_p/top_k sampling to take effect, and pad_token_id must be set
    # explicitly because GPT-2 models define no pad token.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_length=100,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode the generated response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": response}

# Optional root endpoint for checking server health
@app.get("/")
async def read_root():
    return {"message": "Welcome to the GPT-2 FastAPI server!"}
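To exercise the endpoint, the app can be served with uvicorn and called over HTTP. A minimal client sketch follows; the filename main.py and the localhost port 8000 are assumptions for illustration, not part of the original code.

# Start the server first (assumes the code above is saved as main.py):
#   uvicorn main:app --host 0.0.0.0 --port 8000
import requests

# Call the /generate/ endpoint with a prompt; 8000 is uvicorn's default port
resp = requests.post(
    "http://localhost:8000/generate/",
    json={"text": "Once upon a time"},
)
resp.raise_for_status()
print(resp.json()["generated_text"])

Because generation runs synchronously inside an async route, a long `max_length` will block the event loop; for production use, offloading the `model.generate` call to a thread pool (for example via `fastapi.concurrency.run_in_threadpool`) is a common mitigation.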