File size: 2,542 Bytes
be89f53
 
4627671
 
ce297d6
be89f53
ce297d6
4627671
 
 
ce297d6
 
 
 
 
 
 
 
 
 
 
be89f53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce297d6
be89f53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce297d6
 
be89f53
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from fastapi import FastAPI, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

app = FastAPI()

# Hugging Face repo that supplies the architecture config and the tokenizer.
BASE_MODEL_NAME = "Deniskin/gpt3_medium"
quantized_model_path = "gpt3_mini_quantized_2x_16bits.pth"

# The config only describes the architecture; the model built from it has
# randomly initialised weights until the checkpoint below is loaded.
config = GPT2Config.from_pretrained(BASE_MODEL_NAME)
quantized_model = GPT2LMHeadModel(config=config)

# NOTE(review): assumes the .pth file is a plain state_dict saved with
# torch.save(model.state_dict(), path) whose keys match GPT2LMHeadModel —
# confirm against the script that produced it. Half-precision tensors are
# copied into the float32 parameters by load_state_dict.
quantized_model.load_state_dict(torch.load(quantized_model_path, map_location="cpu"))

# The rest of this module (generate_text) refers to the model as `model`.
model = quantized_model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

# Set the model to evaluation mode (disables dropout for inference).
model.eval()

# Function to generate text using the model
def generate_text(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(input_ids, max_length=50, num_return_sequences=1)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the single-page UI for submitting a prompt to ``/generate``.

    The page's script posts the textarea contents as the form field
    ``input_text``. It must be form-encoded, because ``/generate`` declares
    that parameter with ``Form(...)``; the previous JSON body was rejected
    with HTTP 422. Non-2xx responses and network errors are shown in the
    output area instead of rendering ``undefined``.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>LSTM Text Generation</title>
    </head>
    <body>
        <h1>LSTM Text Generation</h1>
        <form id="text-form">
            <label for="user-input">Enter your input:</label><br>
            <textarea id="user-input" name="user-input" rows="4" cols="50"></textarea><br>
            <button type="submit">Generate Text</button>
        </form>
        <div id="output"></div>

        <script>
            document.getElementById("text-form").addEventListener("submit", function(event) {
                event.preventDefault();
                var userInput = document.getElementById("user-input").value;
                var output = document.getElementById("output");

                // /generate reads `input_text` as a form field, so send it
                // form-encoded (URLSearchParams sets the Content-Type).
                fetch("/generate", {
                    method: "POST",
                    body: new URLSearchParams({ input_text: userInput })
                })
                .then(response => {
                    if (!response.ok) {
                        throw new Error("Request failed with status " + response.status);
                    }
                    return response.json();
                })
                .then(data => {
                    output.innerText = data.generated_text;
                })
                .catch(error => {
                    output.innerText = "Error: " + error.message;
                });
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)

@app.post("/generate")
async def generate(request: Request, input_text: str = Form(...)):
    """Generate a continuation of the form field ``input_text``.

    Returns a JSON object ``{"generated_text": ...}``. Model inference is
    blocking and compute-heavy. Calling it directly inside this coroutine
    would freeze the event loop, and every other request, until it finished.
    It is therefore run in FastAPI's worker thread pool.

    Note: ``Form`` parameters require the ``python-multipart`` package.
    """
    generated_text = await run_in_threadpool(generate_text, input_text)
    return {"generated_text": generated_text}

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface,
    # port 8000. The import is deferred so plain imports of this module
    # do not require uvicorn.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")