File size: 1,119 Bytes
ce297d6
 
 
 
 
 
 
8af4ec4
 
 
ce297d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from flask import Flask, render_template, request, jsonify
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

app = Flask(__name__)

# Local checkpoint whose state dict is loaded over the pretrained weights below.
# NOTE(review): the filename mentions "quantized"/"16bits" — confirm its keys
# and shapes actually match GPT2LMHeadModel, otherwise load_state_dict (strict
# by default) raises on missing/unexpected keys or shape mismatches.
model_path = "gpt3_mini_quantized_2x_16bits.pth"

# Load the GPT-2-architecture language model (a transformer decoder, not an
# LSTM) and its matching tokenizer by repo id.
tokenizer = GPT2Tokenizer.from_pretrained("Deniskin/gpt3_medium")
model = GPT2LMHeadModel.from_pretrained("Deniskin/gpt3_medium")

# map_location="cpu": the model above is on the CPU, so materialize the
# checkpoint tensors there too; without it a checkpoint saved from a GPU
# cannot be loaded on a CPU-only machine.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Switch to evaluation mode (e.g. disables dropout) before serving requests.
model.eval()

# Function to generate text using the model
def generate_text(prompt, max_length=50):
    """Generate a continuation of ``prompt`` with the loaded language model.

    Args:
        prompt: Text to continue. Must be a non-empty string.
        max_length: Upper bound on the total number of tokens in the output
            sequence, prompt tokens included (passed straight through to
            ``model.generate``). Defaults to 50, the previously fixed value.

    Returns:
        The decoded first output sequence (the prompt followed by the
        generated continuation) with special tokens removed.

    Raises:
        ValueError: If ``prompt`` is empty; the model cannot be run on a
            zero-length input.
    """
    if not prompt:
        # An empty prompt encodes to a (1, 0) tensor, which fails deep inside
        # the model forward pass with an unhelpful reshape error.
        raise ValueError("prompt must be a non-empty string")

    # Calling the tokenizer (instead of .encode) also returns an explicit
    # attention mask, so generate() does not have to infer one.
    encoded = tokenizer(prompt, return_tensors="pt")
    output = model.generate(
        **encoded,
        max_length=max_length,
        num_return_sequences=1,
        # GPT-2 defines no pad token; generate() already falls back to EOS,
        # but logs a warning on every call unless it is passed explicitly.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

@app.route("/")
def home():
    """Render the ``index.html`` template for the site root."""
    template_name = "index.html"
    return render_template(template_name)

@app.route("/generate", methods=["POST"])
def generate():
    """Generate text from the ``input_text`` field of a JSON POST body.

    Expects a JSON object such as ``{"input_text": "Once upon a time"}``.

    Returns:
        200 with ``{"generated_text": ...}`` on success, or 400 with
        ``{"error": ...}`` when the body is not a JSON object or when
        ``input_text`` is missing, not a string, or empty.
    """
    # silent=True yields None instead of raising on a missing/unparseable JSON
    # body, so every malformed request gets the same JSON 400 below rather than
    # a 500 (KeyError/TypeError) or Flask's HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    user_input = data.get("input_text")
    if not isinstance(user_input, str) or not user_input:
        return jsonify({"error": "'input_text' must be a non-empty string"}), 400

    generated_text = generate_text(user_input)
    return jsonify({"generated_text": generated_text})

if __name__ == "__main__":
    # debug=True enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution from the browser: never expose this server
    # beyond localhost (app.run binds to 127.0.0.1 unless a host is given).
    # use_reloader=False: with debug on, the reloader re-executes this module
    # in a child process, which would load the model weights a second time.
    app.run(debug=True, use_reloader=False)