Upload app.py
Browse files
app.py
CHANGED
@@ -8,11 +8,9 @@ from transformer import GPT, GPTConfig # Import your model class
|
|
8 |
# Load the model from Hugging Face Hub
|
9 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
10 |
def load_model_from_hf():
|
11 |
-
|
12 |
-
model_id = "sudhakar272/shakespheretextgenerator"
|
13 |
-
checkpoint_path = hf_hub_download(repo_id=model_id, filename="transformer_model.pt")
|
14 |
|
15 |
-
checkpoint = torch.load(checkpoint_path
|
16 |
config = checkpoint['config']
|
17 |
model = GPT(config)
|
18 |
model.load_state_dict(checkpoint['model_state_dict'])
|
@@ -73,8 +71,8 @@ iface = gr.Interface(
|
|
73 |
gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
|
74 |
],
|
75 |
outputs=gr.Textbox(label="Generated Text"),
|
76 |
-
title="
|
77 |
-
description="Enter text for
|
78 |
examples=[
|
79 |
["To be, or not to be: that is the question.", 100, 1],
|
80 |
["Love all, trust a few, do wrong to none.", 60, 2],
|
|
|
8 |
# Load the model from Hugging Face Hub
|
9 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
10 |
def load_model_from_hf():
|
11 |
+
checkpoint_path ="./transformer_model.pt"
|
|
|
|
|
12 |
|
13 |
+
checkpoint = torch.load(checkpoint_path)
|
14 |
config = checkpoint['config']
|
15 |
model = GPT(config)
|
16 |
model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
71 |
gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
|
72 |
],
|
73 |
outputs=gr.Textbox(label="Generated Text"),
|
74 |
+
title="Shakesphere Text Generator",
|
75 |
+
description="Enter text for Shakesphere way of text and continue the same",
|
76 |
examples=[
|
77 |
["To be, or not to be: that is the question.", 100, 1],
|
78 |
["Love all, trust a few, do wrong to none.", 60, 2],
|