Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import numpy as np | |
| from sentence_transformers import util | |
class Results:
    """Rank tweets by semantic similarity to a set of narratives.

    On construction the tweets are loaded from ``tweet_file``, embedded with
    ``model`` and compared (cosine similarity) against every narrative, so the
    object can be queried immediately.

    Attributes:
        model: Embedding model; must provide
            ``encode(text_or_texts, convert_to_tensor=True)`` (presumably a
            sentence-transformers ``SentenceTransformer`` -- confirm with callers).
        df: Media loaded by :meth:`read_media`; has a ``"Tweet"`` column.
        n_tweets: Number of leading tweets of ``df`` that were scored.
        narratives: List of narrative strings tweets are compared against.
        similarities: Array of shape ``(n_tweets, len(narratives))``; entry
            ``[i, j]`` is the cosine similarity of tweet ``i`` to narrative ``j``.
        tweets: DataFrame with columns ``"Tweet"`` and ``"Sim_Index"`` (the
            tweet's row in ``similarities``).
    """

    def __init__(self, model, tweet_file, n_tweets, narratives):
        """Load ``tweet_file`` and score its first ``n_tweets`` tweets.

        Args:
            model: Embedding model used for both tweets and narratives.
            tweet_file: Path to a ``.csv`` file with a ``"Tweet"`` column, or a
                ``.txt`` file whose whole content is treated as one tweet.
            n_tweets: How many leading tweets to score; clamped to the number
                of tweets available.
            narratives: Sequence of narrative strings, e.g.
                ``["Russia is an ally", "the 2020 election was stolen"]``.

        Raises:
            ValueError: If ``n_tweets`` is negative or the file type is
                unsupported.
            KeyError: If a CSV file has no ``"Tweet"`` column.
        """
        self.model = model
        self.read_media(tweet_file)
        if n_tweets < 0:
            raise ValueError(f"n_tweets must be non-negative, got {n_tweets}")
        if n_tweets > len(self.df):
            n_tweets = len(self.df)
            print(
                "Number of tweets selected is greater than total.\n"
                f"Continuing with {n_tweets} tweets total."
            )
        self.n_tweets = n_tweets
        # Materialize once so generators are not exhausted by len()/iteration.
        self.narratives = list(narratives)
        self.similarities = np.empty((self.n_tweets, len(self.narratives)))
        self.tweets = pd.DataFrame(columns=["Tweet", "Sim_Index"])
        # Score eagerly: a Results object without results has no use.
        self.get_results()

    def read_media(self, file):
        """Load tweets from ``file`` into ``self.df``.

        ``.csv`` files are read as UTF-8 (undecodable bytes dropped) and must
        contain a ``"Tweet"`` column. A ``.txt`` file becomes a single tweet
        holding the whole file content. Extensions match case-insensitively.

        Raises:
            ValueError: If the extension is neither ``.csv`` nor ``.txt``.
            KeyError: If a CSV file has no ``"Tweet"`` column.
        """
        path = str(file)
        lowered = path.lower()
        if lowered.endswith(".csv"):
            df = pd.read_csv(path, encoding="utf-8", encoding_errors="ignore")
            if "Tweet" not in df.columns:
                raise KeyError(f"{path!r} has no 'Tweet' column")
        elif lowered.endswith(".txt"):
            # Same decoding policy as the CSV branch instead of the locale default.
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                df = pd.DataFrame({"Tweet": [f.read()]})
        else:
            raise ValueError(
                f"Unsupported file type {path!r}; expected a .csv or .txt file"
            )
        self.df = df

    def embed_narratives(self, narratives):
        """Return a list with one embedding tensor per narrative, in input order."""
        return [
            self.model.encode(narrative, convert_to_tensor=True)
            for narrative in narratives
        ]

    def get_results(self):
        """Compute ``self.similarities`` and ``self.tweets`` for the scored tweets.

        All tweets are embedded in one batched ``encode`` call, then compared
        with each narrative embedding.
        """
        nar_embeds = self.embed_narratives(self.narratives)
        raw_tweets = self.df["Tweet"].iloc[: self.n_tweets]
        # The encoder is fed strings; empty CSV cells (NaN) become "".
        texts = raw_tweets.fillna("").astype(str).tolist()

        similarities = np.empty((len(texts), len(nar_embeds)))
        if texts and nar_embeds:
            tweet_embeds = self.model.encode(texts, convert_to_tensor=True)
            for j, nar_embed in enumerate(nar_embeds):
                # (n_tweets, 1) tensor of scores. Copy to CPU first: numpy
                # cannot read tensors that live on a GPU.
                sims = util.pytorch_cos_sim(tweet_embeds, nar_embed)
                similarities[:, j] = sims.detach().cpu().numpy().reshape(-1)
        self.similarities = similarities
        self.tweets = pd.DataFrame(
            {"Tweet": raw_tweets.tolist(), "Sim_Index": range(len(texts))}
        )

    def _resolve_narrative_ind(self, narrative_ind):
        """Return ``narrative_ind`` if it indexes ``self.narratives``, else 0 (with a notice)."""
        count = len(self.narratives)
        if -count <= narrative_ind < count:
            return narrative_ind
        print("Invalid narrative index. Continuing with narrative_ind=0...")
        return 0

    def sort_by_narrative(self, narrative_ind):
        """Return tweets and their similarity to one narrative, least similar first.

        Args:
            narrative_ind: Index into ``self.narratives``; negative values count
                from the end. Out-of-range values fall back to 0.

        Returns:
            ``(sorted_tweets, sorted_sims)``: rows of ``self.tweets`` and the
            matching similarity values, in ascending similarity order.
        """
        narrative_ind = self._resolve_narrative_ind(narrative_ind)
        # Column j of the (n_tweets, n_narratives) matrix holds narrative j's scores.
        narr_sims = self.similarities[:, narrative_ind]
        # Stable sort keeps tied tweets in their original order.
        order = np.argsort(narr_sims, kind="stable")
        return self.tweets.iloc[order], narr_sims[order]

    def print_top_k(self, k, narrative_ind):
        """Print and return the ``k`` tweets most similar to one narrative.

        Args:
            k: Number of tweets to show; clamped to ``[0, n_tweets]``.
            narrative_ind: Narrative index; out-of-range values fall back to 0.

        Returns:
            DataFrame of the top ``k`` rows of ``self.tweets`` plus a ``"Sims"``
            column, ordered least to most similar (best match last).
        """
        k = max(0, min(k, self.n_tweets))
        narrative_ind = self._resolve_narrative_ind(narrative_ind)
        sorted_tweets, sorted_sims = self.sort_by_narrative(narrative_ind)
        # assign() builds a new frame, so no SettingWithCopyWarning on the iloc
        # result; tail(0) is empty, unlike the slice [-0:].
        top_k = sorted_tweets.assign(Sims=sorted_sims).tail(k)
        # Show full tweet text for this print only; global options stay untouched.
        with pd.option_context("display.max_colwidth", None):
            print(
                "{k} Most similar tweets to narrative \n\"{narrative}\": \n".format(
                    k=k, narrative=self.narratives[narrative_ind]
                ),
                top_k,
            )
        return top_k

    def __repr__(self):
        return f"First 10 Results: \n {self.tweets.head(10)}"