from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import torch

app = FastAPI()

# Rebuild the architecture from the base config and load the quantized weights.
# Assumptions: the checkpoint is a plain state_dict, and the tokenizer lives in
# the same hub repo as the config; adjust both if your setup differs.
quantized_model_path = "gpt3_mini_quantized_2x_16bits.pth"
config = GPT2Config.from_pretrained("Deniskin/gpt3_medium")
tokenizer = GPT2Tokenizer.from_pretrained("Deniskin/gpt3_medium")
model = GPT2LMHeadModel(config=config)
model.load_state_dict(torch.load(quantized_model_path, map_location="cpu"))
model.eval()


def generate_text(prompt):
    # Encode the prompt, generate up to 50 tokens, and decode the result.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(input_ids, max_length=50, num_return_sequences=1)
    return tokenizer.decode(output[0], skip_special_tokens=True)

@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>GPT Text Generation</title>
    </head>
    <body>
        <h1>GPT Text Generation</h1>
        <form id="text-form">
            <label for="user-input">Enter your input:</label><br>
            <textarea id="user-input" name="user-input" rows="4" cols="50"></textarea><br>
            <button type="submit">Generate Text</button>
        </form>
        <div id="output"></div>

        <script>
            document.getElementById("text-form").addEventListener("submit", function(event) {
                event.preventDefault();
                var userInput = document.getElementById("user-input").value;

                // Send the input as URL-encoded form data so it matches the
                // Form(...) parameter expected by the /generate endpoint.
                fetch("/generate", {
                    method: "POST",
                    headers: {
                        "Content-Type": "application/x-www-form-urlencoded"
                    },
                    body: new URLSearchParams({ input_text: userInput })
                })
                .then(response => response.json())
                .then(data => {
                    document.getElementById("output").innerText = data.generated_text;
                });
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)

@app.post("/generate")
async def generate(input_text: str = Form(...)):
    # Parsing form fields requires the python-multipart package to be installed.
    generated_text = generate_text(input_text)
    return {"generated_text": generated_text}

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
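
# Quick usage sketch (assumptions: this file is saved as app.py and the
# quantized checkpoint sits next to it; package names are the usual PyPI ones,
# adjust to your environment):
#
#   pip install fastapi uvicorn torch transformers python-multipart
#   python app.py
#
# Then open http://localhost:8000, type a prompt, and submit the form; the
# /generate endpoint responds with JSON of the form {"generated_text": "..."}.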