# (Extraction residue: hosting-page status banner "Spaces: Sleeping" - not part of the program.)
import re
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline, set_seed
# -------------------------------
# 1. Load and clean dataset
# -------------------------------
def load_and_prepare_data():
    """Load the ``sentiment140`` training split and return it cleaned.

    Steps, in order:

    1. Load the ``"train"`` split and convert it to a DataFrame.
    2. Drop rows whose ``text`` or ``sentiment`` value is missing.
    3. Add a ``text_length`` column and keep only rows whose text is
       between 5 and 280 characters long (both bounds inclusive).
    4. Add a ``clean_text`` column produced by ``clean_text``.

    Returns:
        The filtered DataFrame, including the added ``text_length`` and
        ``clean_text`` columns.
    """
    dataset = load_dataset("sentiment140")
    df = dataset["train"].to_pandas()
    df = df.dropna(subset=["text", "sentiment"])
    df["text_length"] = df["text"].apply(len)

    # Take an explicit copy of the filtered rows: assigning a new column to
    # a boolean-indexed slice otherwise triggers SettingWithCopyWarning.
    df = df[df["text_length"].between(5, 280)].copy()
    df["clean_text"] = df["text"].apply(clean_text)
    return df
# (pattern, replacement) pairs applied in order by ``clean_text``.
# Compiled once at import time because ``clean_text`` runs once per row.
_CLEAN_STEPS = (
    (re.compile(r"http\S+"), ""),    # URLs
    (re.compile(r"@\w+"), ""),       # @mentions
    (re.compile(r"#\w+"), ""),       # hashtags
    (re.compile(r"[^\w\s]"), ""),    # remaining punctuation / symbols
    (re.compile(r"\s+"), " "),       # collapse whitespace runs
)


def clean_text(text):
    """Normalize a tweet for embedding and similarity search.

    The text is lower-cased, then URLs, @mentions, hashtags and any
    character that is neither a word character nor whitespace are removed.
    Finally, runs of whitespace are collapsed to a single space and the
    result is stripped.

    Args:
        text: The raw tweet text (a ``str``).

    Returns:
        The cleaned text; may be the empty string.
    """
    text = text.lower()
    for pattern, replacement in _CLEAN_STEPS:
        text = pattern.sub(replacement, text)
    return text.strip()
# Load data once at import time; the sample below is shared by all requests.
df = load_and_prepare_data()
# Cap the sample size at the number of available rows: DataFrame.sample
# raises ValueError when asked for more rows than exist (without replacement).
sample_df = df.sample(min(5000, len(df)), random_state=42).reset_index(drop=True)
texts = sample_df["clean_text"].tolist()
# -------------------------------
# 2. Load models
# -------------------------------
# Sentence-embedding model used to compare the user's input with the corpus.
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Text-generation pipeline used to produce candidate tweets.
generator = pipeline("text-generation", model="distilgpt2")
# Fix the random seed once at start-up.
set_seed(42)
# -------------------------------
# 3. Helper functions
# -------------------------------
@lru_cache(maxsize=1)
def _corpus_embeddings():
    """Encode the module-level ``texts`` corpus once and reuse the result.

    The corpus is built once at import time, so re-encoding it on every
    query is wasted work. The cache assumes ``texts`` is not modified after
    the first call.
    """
    return embedding_model.encode(texts, show_progress_bar=False)


def get_top3_similarities(text_input):
    """Return the three corpus texts most similar to ``text_input``.

    Similarity is the cosine similarity between the embedding of
    ``text_input`` and the embedding of each entry in ``texts``.

    Args:
        text_input: The query text to embed and compare.

    Returns:
        A list of up to three entries from ``texts``, ordered from most to
        least similar.
    """
    input_embedding = embedding_model.encode([text_input])
    similarities = cosine_similarity(input_embedding, _corpus_embeddings())[0]
    # argsort is ascending: take the last three indices and reverse them so
    # the most similar text comes first.
    top_indices = similarities.argsort()[-3:][::-1]
    return [texts[i] for i in top_indices]
def generate_best_tweet(text_input): | |
synthetic_outputs = generator( | |
text_input, | |
max_length=50, | |
num_return_sequences=10, | |
do_samp_ | |