aamos committed on
Commit
be89f53
·
verified ·
1 Parent(s): 363647d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -16
app.py CHANGED
@@ -1,13 +1,13 @@
1
- from flask import Flask, render_template, request, jsonify
2
- import torch
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
 
5
- app = Flask(__name__)
6
 
7
  # Load the LSTM-based language model
8
- model_path = "gpt3_mini_quantized_2x_16bits.pth"
9
- tokenizer = GPT2Tokenizer.from_pretrained("Deniskin/gpt3_medium")
10
- model = GPT2LMHeadModel.from_pretrained("Deniskin/gpt3_medium")
11
  model.load_state_dict(torch.load(model_path))
12
 
13
  # Set the model to evaluation mode
@@ -20,16 +20,54 @@ def generate_text(prompt):
20
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
21
  return generated_text
22
 
23
- @app.route("/")
24
- def home():
25
- return render_template("index.html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- @app.route("/generate", methods=["POST"])
28
- def generate():
29
- data = request.json
30
- user_input = data["input_text"]
31
- generated_text = generate_text(user_input)
32
- return jsonify({"generated_text": generated_text})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  if __name__ == "__main__":
35
- app.run(debug=True)
 
 
 
1
import torch  # required below for torch.load; the diff dropped this import while keeping its use
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from transformers import GPT2LMHeadModel, GPT2Tokenizer

app = FastAPI()

# Load the language model weights.
# NOTE(review): earlier comments called this "LSTM-based", but the code loads a
# GPT-2 (transformer) model — the comment was stale, not the code.
model_path = "your_model.pth"  # Replace with your model path
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Overwrite the pretrained weights with the fine-tuned checkpoint from disk.
model.load_state_dict(torch.load(model_path))

# Set the model to evaluation mode
 
20
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
21
  return generated_text
22
 
23
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the single-page UI: a textarea whose contents are POSTed
    as JSON to /generate, with the generated text rendered in-place."""
    page = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>LSTM Text Generation</title>
</head>
<body>
<h1>LSTM Text Generation</h1>
<form id="text-form">
<label for="user-input">Enter your input:</label><br>
<textarea id="user-input" name="user-input" rows="4" cols="50"></textarea><br>
<button type="submit">Generate Text</button>
</form>
<div id="output"></div>

<script>
document.getElementById("text-form").addEventListener("submit", function(event) {
event.preventDefault();
var userInput = document.getElementById("user-input").value;

fetch("/generate", {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify({ input_text: userInput })
})
.then(response => response.json())
.then(data => {
document.getElementById("output").innerText = data.generated_text;
});
});
</script>
</body>
</html>
"""
    # HTMLResponse defaults to status 200; passing it explicitly for clarity.
    return HTMLResponse(page, status_code=200)
64
+
65
@app.post("/generate")
async def generate(request: Request):
    """Generate text from a JSON body of the form {"input_text": "..."}.

    Bug fix: the previous signature declared `input_text: str = Form(...)`,
    which makes FastAPI expect form-encoded data — but the page served by
    `home` posts `Content-Type: application/json` with
    `JSON.stringify({ input_text: ... })`, so every request from the UI was
    rejected with a 422. Reading the JSON body directly matches the client.

    Returns:
        dict with key "generated_text" holding the model output.
    """
    data = await request.json()
    generated_text = generate_text(data["input_text"])
    return {"generated_text": generated_text}
69
 
70
if __name__ == "__main__":
    # Launched as a script: start the ASGI server directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)