Codingxx committed on
Commit
b238a3c
·
verified ·
1 Parent(s): 6e65396

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+
6
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Checkpoint identifier shared by the model and its tokenizer.
# "distilgpt2" is the distilled (smaller) GPT-2 variant.
model_name = "distilgpt2"

# Both objects are loaded once at import time and reused by every request.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+
14
# Request body schema for the text-generation endpoint.
class TextRequest(BaseModel):
    # Prompt that the model is asked to continue.
    text: str
17
+
18
# Route to generate text
@app.post("/generate/")
async def generate_text(request: TextRequest):
    """Generate a continuation of the submitted text.

    The prompt in ``request.text`` is tokenized, extended by the model using
    top-k / nucleus sampling, and decoded back into a string. The decoded
    output sequence (with special tokens stripped) is returned as
    ``{"generated_text": "<text>"}``.
    """
    # Calling the tokenizer (instead of ``encode``) returns the attention
    # mask alongside the input ids, so ``generate`` does not have to infer it.
    encoded = tokenizer(request.text, return_tensors="pt")

    # NOTE(review): ``generate`` is a synchronous call inside an async
    # handler, so the event loop is blocked while it runs. Consider a plain
    # ``def`` handler or offloading to a worker thread.
    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            max_length=100,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # top_k / top_p only take effect when sampling is enabled;
            # without do_sample=True they are ignored and decoding is greedy.
            do_sample=True,
            top_k=50,
            top_p=0.9,
            # GPT-2 checkpoints define no pad token; reuse EOS explicitly.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode the single returned sequence back to text.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": response}
31
+
32
# Root endpoint: a lightweight way to confirm the server is up.
@app.get("/")
async def read_root():
    # Static greeting; no model call is involved.
    payload = {"message": "Welcome to the GPT-2 FastAPI server!"}
    return payload