Manasa1 commited on
Commit
d62ca5c
·
verified ·
1 Parent(s): a78a40d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
  # Load pre-trained model (or fine-tuned model)
5
- model_name = "Manasa1/GPT_Finetuned_tweets" # Replace with the fine-tuned model name
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
@@ -13,8 +13,17 @@ def generate_tweet(input_text):
13
  "Ensure the response is concise, engaging, and suitable for a diverse audience on social media. "
14
  "Incorporate elements of thought leadership, futuristic perspectives, and practical wisdom where appropriate.").format(input_text)
15
 
16
- inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
17
- outputs = model.generate(inputs['input_ids'], max_length=280, num_return_sequences=1, top_p=0.95, top_k=50)
 
 
 
 
 
 
 
 
 
18
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
19
 
20
  # Extract the tweet text (exclude prompt if included)
@@ -41,4 +50,4 @@ def main():
41
  # Run Gradio app
42
  if __name__ == "__main__":
43
  app = main()
44
- app.launch()
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
  # Load pre-trained model (or fine-tuned model)
5
+ model_name = "/kaggle/working/gpt-finetuned-qa" # Replace with the fine-tuned model name
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
 
13
  "Ensure the response is concise, engaging, and suitable for a diverse audience on social media. "
14
  "Incorporate elements of thought leadership, futuristic perspectives, and practical wisdom where appropriate.").format(input_text)
15
 
16
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True, padding=True)
17
+ outputs = model.generate(
18
+ inputs['input_ids'],
19
+ attention_mask=inputs['attention_mask'],
20
+ max_length=280,
21
+ num_return_sequences=1,
22
+ top_p=0.95,
23
+ top_k=50,
24
+ do_sample=True,
25
+ pad_token_id=tokenizer.pad_token_id
26
+ )
27
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
 
29
  # Extract the tweet text (exclude prompt if included)
 
50
  # Run Gradio app
51
  if __name__ == "__main__":
52
  app = main()
53
+ app.launch(share=True)