Spaces:
Sleeping
Sleeping
File size: 3,164 Bytes
26532db 6f13b9a c2c3e4f 7062268 6f13b9a 386fde5 6f13b9a 25f67ed 6f13b9a c685a0d 6f13b9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import re
# Load the fine-tuned model and tokenizer once, at import time.
# The same identifier is handed to both from_pretrained() calls.
MODEL_ID = "Manasa1/finetuned_distillGPT2"

try:
    model = GPT2LMHeadModel.from_pretrained(MODEL_ID)
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
    # Reuse the end-of-sequence token as the padding token so that the
    # tokenizer can be called with padding enabled.
    tokenizer.pad_token = tokenizer.eos_token
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # The app cannot work without the model. Terminate with a non-zero
    # status so the failure is visible to whatever launched the process;
    # the bare `exit()` helper would have reported success (status 0).
    raise SystemExit(1) from e
# Function to generate an answer to a question
def generate_answer(question):
    """Generate the model's answer to *question*.

    The question is wrapped in a ``Q: ... A:`` prompt, the model samples a
    continuation, and only the newly generated text is returned.

    Args:
        question: The user's question or topic.

    Returns:
        The generated answer with surrounding whitespace removed, or a
        message starting with ``"Error"`` when the question is empty, the
        prompt leaves no room for an answer, nothing was generated, or
        generation raised an exception.
    """
    context_window = 1024  # token budget shared by the prompt and the answer

    if not question or not question.strip():
        return "Error: Question cannot be empty."
    try:
        prompt = f"Q: {question} A:"
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=context_window,
        )
        prompt_length = inputs["input_ids"].shape[1]
        max_new_tokens = context_window - prompt_length
        if max_new_tokens <= 0:
            # The (truncated) prompt already fills the whole budget.
            return "Error: Question is too long to generate a response."
        output = model.generate(
            inputs["input_ids"],
            # The pad token is the EOS token, so pass the mask explicitly
            # instead of letting generate() guess which tokens are padding.
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            top_p=0.9,
            top_k=50,
            temperature=0.6,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the tokens that follow the prompt. Slicing the decoded
        # text by len(prompt) breaks whenever decoding does not reproduce
        # the prompt character-for-character.
        new_tokens = output[0][prompt_length:]
        answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return answer or "Error: Could not generate a meaningful response."
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error during generation: {e}"
# Function to add relevant hashtags and emojis
def add_hashtags_and_emojis(tweet):
    """Append hashtags and emojis for the topics mentioned in *tweet*.

    Matching is case-insensitive. Upper-case keywords (acronyms such as
    "AI") must appear as a whole word; every other keyword matches as a
    substring, so "startups" still triggers "startup".

    Args:
        tweet: The tweet text to decorate.

    Returns:
        The stripped tweet followed by the matching hashtags and emojis,
        space-separated and without duplicates. When nothing matches, the
        stripped tweet is returned on its own.
    """
    # Emojis are written as \N{...} escapes so the source stays ASCII and
    # cannot be corrupted by a wrong file encoding.
    # NOTE(review): the glyphs for "machine learning", "data" and "startup"
    # were reconstructed from mis-encoded text -- confirm the intended ones.
    hashtags_and_emojis = {
        "AI": ["#AI", "\N{ROBOT FACE}"],
        "machine learning": ["#MachineLearning", "\N{BAR CHART}"],
        "data": ["#DataScience", "\N{CHART WITH UPWARDS TREND}"],
        "technology": ["#Tech", "\N{PERSONAL COMPUTER}"],
        "innovation": ["#Innovation", "\N{SPARKLES}"],
        "coding": [
            "#Coding",
            "\N{MAN}\N{ZERO WIDTH JOINER}\N{PERSONAL COMPUTER}",
        ],
        "future": ["#Future", "\N{CRYSTAL BALL}"],
        "startup": ["#Startup", "\N{ROCKET}"],
        "sustainability": ["#Sustainability", "\N{SEEDLING}"],
    }
    tweet_lower = tweet.lower()
    added_items = []
    for keyword, items in hashtags_and_emojis.items():
        needle = keyword.lower()
        if keyword.isupper():
            # Whole-word match: a short acronym like "ai" would otherwise
            # fire inside unrelated words ("said", "sustainability").
            pattern = rf"\b{re.escape(needle)}\b"
            matched = re.search(pattern, tweet_lower) is not None
        else:
            matched = needle in tweet_lower
        if matched:
            added_items.extend(items)
    # dict.fromkeys drops duplicates while keeping first-seen order.
    added_items = list(dict.fromkeys(added_items))
    # Joining one list avoids a trailing space when nothing was added.
    return " ".join([tweet.strip(), *added_items])
# Function to handle Gradio input and output
def generate_tweet_with_hashtags(question):
    """Generate a tweet for *question* and decorate it with hashtags and emojis.

    Args:
        question: The question or topic entered in the UI.

    Returns:
        The generated tweet with hashtags and emojis appended, or the
        unmodified error message when generation did not produce a tweet.
    """
    generated_tweet = generate_answer(question)
    # generate_answer() reports failures as text with these prefixes. Show
    # them as-is: an error message is not a tweet and must not be decorated.
    if generated_tweet.startswith(("Error: ", "Error during generation: ")):
        return generated_tweet
    return add_hashtags_and_emojis(generated_tweet)
# Gradio app: page layout and event wiring.
with gr.Blocks() as app:
    # Title followed by a one-line description of what the page does.
    for markdown_text in (
        "# AI Tweet Generator with Hashtags and Emojis",
        "Enter a question or topic, and the app will generate a tweet and enhance it with relevant hashtags and emojis!",
    ):
        gr.Markdown(markdown_text)

    # Components are created in the order they appear on the page:
    # input box, read-only output box, then the trigger button.
    question_input = gr.Textbox(label="Enter your question or topic:")
    output_tweet = gr.Textbox(
        label="Generated Tweet with Hashtags and Emojis:",
        interactive=False,
    )
    generate_button = gr.Button("Generate Tweet")

    # Clicking the button sends the question through the tweet pipeline
    # and writes the result into the output box.
    generate_button.click(
        generate_tweet_with_hashtags,
        inputs=[question_input],
        outputs=[output_tweet],
    )

# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()
|