sudhakar272 commited on
Commit
7992b22
·
verified ·
1 Parent(s): 37be743

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -8,11 +8,9 @@ from transformer import GPT, GPTConfig # Import your model class
8
  # Load the model from Hugging Face Hub
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
  def load_model_from_hf():
11
- # Hugging Face model ID (username/model-name)
12
- model_id = "sudhakar272/shakespheretextgenerator"
13
- checkpoint_path = hf_hub_download(repo_id=model_id, filename="transformer_model.pt")
14
 
15
- checkpoint = torch.load(checkpoint_path, map_location=device)
16
  config = checkpoint['config']
17
  model = GPT(config)
18
  model.load_state_dict(checkpoint['model_state_dict'])
@@ -73,8 +71,8 @@ iface = gr.Interface(
73
  gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
74
  ],
75
  outputs=gr.Textbox(label="Generated Text"),
76
- title="Shakespeare-style Text Generator",
77
- description="Enter text for Shakespear way of text continuation",
78
  examples=[
79
  ["To be, or not to be: that is the question.", 100, 1],
80
  ["Love all, trust a few, do wrong to none.", 60, 2],
 
8
  # Load the model from Hugging Face Hub
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
  def load_model_from_hf():
11
+ checkpoint_path ="./transformer_model.pt"
 
 
12
 
13
+ checkpoint = torch.load(checkpoint_path)
14
  config = checkpoint['config']
15
  model = GPT(config)
16
  model.load_state_dict(checkpoint['model_state_dict'])
 
71
  gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Samples"),
72
  ],
73
  outputs=gr.Textbox(label="Generated Text"),
74
+ title="Shakesphere Text Generator",
75
+ description="Enter text for Shakesphere way of text and continue the same",
76
  examples=[
77
  ["To be, or not to be: that is the question.", 100, 1],
78
  ["Love all, trust a few, do wrong to none.", 60, 2],