# tweets_clone / app.py
# Source: Hugging Face Space "Manasa1" — commit 9d629bd (verified), "Update app.py"
# (file-viewer metadata: raw · history · blame · 3.12 kB)
import gradio as gr
from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer
import re
# Load the fine-tuned GPT-2 model and tokenizer
model_dir = "Manasa1/finetuned_GPT23"  # Hugging Face Hub repo id (downloaded on first run)
fine_tuned_model = GPT2LMHeadModel.from_pretrained(model_dir)
fine_tuned_tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
# Create a text-generation pipeline
generator = pipeline('text-generation', model=fine_tuned_model, tokenizer=fine_tuned_tokenizer)
# Function to dynamically truncate output while keeping it meaningful
def truncate_tweet(tweet, max_length=280):
# Ensure the tweet is concise, ending on a complete sentence
if len(tweet) > max_length:
tweet = tweet[:max_length]
last_period = tweet.rfind(".")
if last_period != -1:
tweet = tweet[:last_period + 1]
return tweet.strip()
# Function to intelligently add relevant hashtags and emojis
def add_relevant_tags(tweet: str, input_question: str) -> str:
    """Append a topic emoji and up to two hashtags inferred from the question.

    The topic is detected by matching known keywords as whole words
    (case-insensitive) in *input_question*. The original substring check
    (`"ai" in question.lower()`) false-positived on words like
    "training" or "said"; a word-boundary regex fixes that.

    Args:
        tweet: The generated tweet text.
        input_question: The user's prompt, used for topic detection.

    Returns:
        The tweet, possibly suffixed with an emoji and hashtags, stripped.
    """
    # Pre-defined mappings of topics to hashtags and emojis
    topic_to_hashtags = {
        "startup": ["#Startups", "#Innovation", "#Entrepreneurship"],
        "AI": ["#AI", "#ArtificialIntelligence", "#Tech"],
        "technology": ["#Technology", "#Future", "#Tech"],
        "future": ["#Future", "#Vision", "#Tech"],
    }
    topic_to_emojis = {
        "startup": "🚀",
        "AI": "🤖",
        "technology": "💻",
        "future": "🌟",
    }
    # Determine topic from the question: first keyword that appears as a
    # whole word wins (dict order preserved, matching original priority).
    topic = next(
        (
            key
            for key in topic_to_hashtags
            if re.search(rf"\b{re.escape(key)}\b", input_question, re.IGNORECASE)
        ),
        None,
    )
    # Add relevant hashtags and emoji if a topic is detected
    if topic:
        hashtags = " ".join(topic_to_hashtags[topic][:2])  # Take up to 2 hashtags
        emoji = topic_to_emojis[topic]
        tweet = f"{tweet} {emoji} {hashtags}"
    return tweet.strip()
def generate_tweet(input_question: str) -> str:
    """Generate a concise tweet for *input_question* with the GPT-2 pipeline.

    Runs the fine-tuned text-generation pipeline, strips prompt scaffolding
    and any echo of the question, truncates to tweet length, and appends
    topic hashtags/emoji.

    Args:
        input_question: The user's prompt (e.g. "Write a tweet about AI.").

    Returns:
        The cleaned, truncated, tag-enriched tweet text.
    """
    # Format the input without "Question:" and "Answer:"
    prompt = input_question.strip()
    # Generate the output with a higher max_length for longer responses
    output = generator(prompt, max_length=300, num_return_sequences=1,
                       temperature=0.7, top_p=0.9)
    # Extract the generated text and clean it
    tweet = output[0]['generated_text']
    # Remove scaffolding tokens the model may echo. (The original pattern
    # had a trailing empty alternative — r"(...|A:|)" — which matched the
    # empty string at every position, a wasteful no-op; removed.)
    tweet = re.sub(r"Question:|Answer:|A:", "", tweet).strip()
    # Remove any part of the tweet that starts with the input question
    tweet = tweet.replace(input_question, "").strip()
    # Truncate and add relevant hashtags and emojis
    tweet = truncate_tweet(tweet)
    tweet = add_relevant_tags(tweet, input_question)
    return tweet
# Build the Gradio UI: one prompt box in, one tweet box out.
prompt_box = gr.Textbox(
    label="Enter a prompt/question",
    placeholder="Write a tweet about AI.",
)
result_box = gr.Textbox(label="Generated Tweet")
interface = gr.Interface(
    fn=generate_tweet,
    inputs=prompt_box,
    outputs=result_box,
    title="Tweet Generator",
    description="Generate concise, relevant tweets enriched with appropriate emojis and hashtags using a fine-tuned GPT-2 model."
)
# Start serving the app.
interface.launch()