File size: 1,119 Bytes
ce297d6 8af4ec4 ce297d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from flask import Flask, render_template, request, jsonify
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
app = Flask(__name__)

# Load the GPT-2-style causal language model. The architecture and tokenizer
# are fetched by name via from_pretrained; the weights are then overwritten
# with the local checkpoint file below.
model_path = "gpt3_mini_quantized_2x_16bits.pth"
tokenizer = GPT2Tokenizer.from_pretrained("Deniskin/gpt3_medium")
model = GPT2LMHeadModel.from_pretrained("Deniskin/gpt3_medium")
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host instead of failing while deserializing CUDA tensors.
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Set the model to evaluation mode
model.eval()
# Function to generate text using the model
def generate_text(prompt, *, max_length=50):
    """Generate a continuation of *prompt* with the module-level model.

    Args:
        prompt: Text used to seed generation.
        max_length: Upper bound on the total sequence length in tokens
            (prompt tokens included). Defaults to 50, the value that was
            previously hard-coded.

    Returns:
        The first generated sequence decoded to text, with special tokens
        removed.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(
        input_ids,
        # A single, unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        # GPT-2 tokenizers define no pad token; naming EOS explicitly avoids
        # generate() having to guess one (and warning about it) on each call.
        pad_token_id=tokenizer.eos_token_id,
        max_length=max_length,
        num_return_sequences=1,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
@app.route("/")
def home():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/generate", methods=["POST"])
def generate():
    """Generate text for the prompt supplied in a JSON POST body.

    Expects a JSON object with a string ``input_text`` field and responds
    with ``{"generated_text": ...}``. A missing or malformed body, or a
    missing/non-string ``input_text``, yields a 400 response carrying an
    ``error`` message instead of an unhandled exception.
    """
    # silent=True returns None rather than raising when the body is absent
    # or is not valid JSON, so all bad-input cases are handled below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    user_input = data.get("input_text")
    if not isinstance(user_input, str):
        return jsonify({"error": "'input_text' must be a string."}), 400

    generated_text = generate_text(user_input)
    return jsonify({"generated_text": generated_text})
if __name__ == "__main__":
    import os

    # Debug mode enables the Werkzeug interactive debugger, which can execute
    # arbitrary Python from the browser. Keep it off unless the developer
    # opts in by setting FLASK_DEBUG to a truthy value.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(debug=debug_enabled)
|