Manasa1 commited on
Commit
7062268
·
verified ·
1 Parent(s): e90cbf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -23
app.py CHANGED
@@ -1,13 +1,12 @@
1
  import gradio as gr
2
- from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer
3
- import re
4
 
5
- # Load the fine-tuned GPT-2 model and tokenizer
6
- model_dir = "Manasa1/finetuned_GPT23"
7
  fine_tuned_model = GPT2LMHeadModel.from_pretrained(model_dir)
8
  fine_tuned_tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
9
 
10
- # Create a text-generation pipeline
11
  generator = pipeline('text-generation', model=fine_tuned_model, tokenizer=fine_tuned_tokenizer)
12
 
13
  # Function to intelligently add relevant hashtags and emojis
@@ -38,18 +37,26 @@ def add_relevant_tags(tweet, input_question):
38
  hashtags = " ".join(topic_to_hashtags[topic][:2]) # Take up to 2 hashtags
39
  emoji = topic_to_emojis[topic]
40
  tweet = f"{tweet} {emoji} {hashtags}"
 
 
 
 
41
  return tweet.strip()
42
 
 
43
  def generate_tweet(input_question):
44
- # Format the input without "Question:" and "Answer:"
45
- prompt = input_question.strip()
46
-
47
- # Generate the output with a max length of around 200 to 280 characters
48
- output = generator(prompt, max_length=300, num_return_sequences=1, temperature=0.7, top_p=0.9)
49
-
50
  # Extract the generated text
51
  tweet = output[0]['generated_text']
52
-
 
 
 
53
  # Ensure the tweet is between 200 and 280 characters
54
  tweet_length = len(tweet)
55
  if tweet_length > 280:
@@ -58,21 +65,26 @@ def generate_tweet(input_question):
58
  if last_period != -1:
59
  tweet = tweet[:last_period + 1]
60
  elif tweet_length < 200:
61
- tweet = tweet.ljust(200)
62
 
63
  # Add relevant hashtags and emojis
64
- tweet = add_relevant_tags(tweet, input_question)
 
 
65
 
 
 
 
66
  return tweet
67
 
68
- # Create the Gradio interface
69
- interface = gr.Interface(
70
- fn=generate_tweet,
71
- inputs=gr.Textbox(label="Enter a prompt/question", placeholder="Write a tweet about AI."),
72
- outputs=gr.Textbox(label="Generated Tweet"),
73
- title="Tweet Generator",
74
- description="Generate concise, relevant tweets enriched with appropriate emojis and hashtags using a fine-tuned GPT-2 model."
75
  )
76
 
77
- # Launch the interface
78
- interface.launch()
 
1
  import gradio as gr
2
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
 
3
 
4
+ # Load the fine-tuned model and tokenizer
5
+ model_dir = "Manasa1/finetuned_GPT2w"
6
  fine_tuned_model = GPT2LMHeadModel.from_pretrained(model_dir)
7
  fine_tuned_tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
8
 
9
+ # Define the generator pipeline
10
  generator = pipeline('text-generation', model=fine_tuned_model, tokenizer=fine_tuned_tokenizer)
11
 
12
  # Function to intelligently add relevant hashtags and emojis
 
37
  hashtags = " ".join(topic_to_hashtags[topic][:2]) # Take up to 2 hashtags
38
  emoji = topic_to_emojis[topic]
39
  tweet = f"{tweet} {emoji} {hashtags}"
40
+ else:
41
+ # If no topic is detected, don't add emojis/hashtags
42
+ tweet = f"{tweet} #NoTopic"
43
+
44
  return tweet.strip()
45
 
46
+ # Function to generate tweet
47
  def generate_tweet(input_question):
48
+ # Formulate the prompt with clear guidance for tweet generation
49
+ input_text = f"Write a very short, engaging tweet with emojis and relevant hashtags about {input_question}. Keep it between 200 and 280 characters. Provide only the tweet."
50
+
51
+ # Generate the output using the pipeline
52
+ output = generator(input_text, max_length=280, num_return_sequences=1, temperature=0.7, top_p=0.9)
53
+
54
  # Extract the generated text
55
  tweet = output[0]['generated_text']
56
+
57
+ # Extract the tweet part by splitting based on the prompt
58
+ tweet = tweet.split(f"Write a very short, engaging tweet with emojis and relevant hashtags about {input_question}")[-1].strip()
59
+
60
  # Ensure the tweet is between 200 and 280 characters
61
  tweet_length = len(tweet)
62
  if tweet_length > 280:
 
65
  if last_period != -1:
66
  tweet = tweet[:last_period + 1]
67
  elif tweet_length < 200:
68
+ tweet = tweet.ljust(200) # Ensure a minimum length of 200 characters
69
 
70
  # Add relevant hashtags and emojis
71
+ tweet = add_relevant_tags(tweet, input_question)
72
+
73
+ return tweet
74
 
75
+ # Gradio interface
76
+ def gradio_interface(input_question):
77
+ tweet = generate_tweet(input_question)
78
  return tweet
79
 
80
+ # Create the Gradio app
81
+ iface = gr.Interface(
82
+ fn=gradio_interface,
83
+ inputs="text",
84
+ outputs="text",
85
+ title="AI Tweet Generator",
86
+ description="Enter a topic, and the model will generate a tweet with relevant hashtags and emojis."
87
  )
88
 
89
+ # Launch the app
90
+ iface.launch()