# SS.GPT3M / app.py
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import torch

app = FastAPI()
# Path to the locally quantized model weights
quantized_model_path = "gpt3_mini_quantized_2x_16bits.pth"

# Load the config to initialize the model architecture, and the tokenizer
# (assumed here to come from the same "Deniskin/gpt3_medium" checkpoint)
config = GPT2Config.from_pretrained("Deniskin/gpt3_medium")
tokenizer = GPT2Tokenizer.from_pretrained("Deniskin/gpt3_medium")

# Initialize the model and load the quantized weights
# (assumes the .pth file stores a state_dict compatible with this architecture)
model = GPT2LMHeadModel(config=config)
model.load_state_dict(torch.load(quantized_model_path, map_location="cpu"))

# Set the model to evaluation mode
model.eval()
# Function to generate text using the model
def generate_text(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(input_ids, max_length=50, num_return_sequences=1)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text
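# Illustrative sanity check (not part of the app): generate_text can be called
# directly and returns the decoded continuation as a plain string, e.g.
#   generate_text("Once upon a time")
#   -> 'Once upon a time ...'
# The prompt shown is only an example input, not something the app requires.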
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>GPT Text Generation</title>
    </head>
    <body>
        <h1>GPT Text Generation</h1>
        <form id="text-form">
            <label for="user-input">Enter your input:</label><br>
            <textarea id="user-input" name="user-input" rows="4" cols="50"></textarea><br>
            <button type="submit">Generate Text</button>
        </form>
        <div id="output"></div>
        <script>
            document.getElementById("text-form").addEventListener("submit", function(event) {
                event.preventDefault();
                var userInput = document.getElementById("user-input").value;
                // Send form-encoded data so the body matches the Form(...) parameter on /generate
                fetch("/generate", {
                    method: "POST",
                    headers: {
                        "Content-Type": "application/x-www-form-urlencoded"
                    },
                    body: new URLSearchParams({ input_text: userInput })
                })
                .then(response => response.json())
                .then(data => {
                    document.getElementById("output").innerText = data.generated_text;
                });
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)
@app.post("/generate")
async def generate(request: Request, input_text: str = Form(...)):
    generated_text = generate_text(input_text)
    return {"generated_text": generated_text}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
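# Example request (illustrative), assuming the server is running locally on port 8000.
# The field name "input_text" matches the Form parameter of the /generate endpoint:
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/x-www-form-urlencoded" \
#        -d "input_text=Once upon a time"
#
# Expected response shape: {"generated_text": "..."}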