|
from flask import Flask, render_template, request, jsonify |
|
import torch |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
|
app = Flask(__name__)

# Path to the locally fine-tuned / quantized checkpoint layered on top of
# the pretrained base weights below.
model_path = "gpt3_mini_quantized_2x_16bits.pth"

tokenizer = GPT2Tokenizer.from_pretrained("Deniskin/gpt3_medium")
model = GPT2LMHeadModel.from_pretrained("Deniskin/gpt3_medium")

# map_location="cpu" so the checkpoint loads even if it was saved on a GPU
# and this host has no CUDA device; move the model afterwards if needed.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Inference only: disables dropout and other training-mode behavior.
model.eval()
|
|
|
|
|
def generate_text(prompt):
    """Generate a text continuation of *prompt* with the loaded model.

    Args:
        prompt: Seed text to condition the model on.

    Returns:
        The decoded generation as a string. Note this includes the prompt
        itself, since ``model.generate`` returns the full token sequence.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=50,
            num_return_sequences=1,
            # GPT-2 has no pad token; using EOS silences the runtime warning
            # and keeps generation behavior unchanged.
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
@app.route("/")
def home():
    """Render the front-end page for the text-generation demo."""
    index_template = "index.html"
    return render_template(index_template)
|
|
|
@app.route("/generate", methods=["POST"])
def generate():
    """POST endpoint: ``{"input_text": ...}`` -> ``{"generated_text": ...}``.

    Returns a 400 JSON error (instead of an unhandled 500) when the request
    body is not valid JSON or the "input_text" key is missing.
    """
    # silent=True yields None on a malformed/absent JSON body rather than
    # raising, so we can answer with a clean 400.
    data = request.get_json(silent=True)
    if not data or "input_text" not in data:
        return jsonify({"error": "Missing 'input_text' in request body"}), 400
    generated_text = generate_text(data["input_text"])
    return jsonify({"generated_text": generated_text})
|
|
|
if __name__ == "__main__":
    # Development server only. debug=True enables the interactive Werkzeug
    # debugger, which allows arbitrary code execution — never expose this
    # configuration to untrusted networks / production.
    app.run(debug=True)
|
|