import pandas as pd
import numpy as np
from sentence_transformers import util


class Results:
    """Cosine similarities between tweets and a list of narratives.

    Attributes:
        model: Object exposing ``encode(text, convert_to_tensor=True)``.
        df: DataFrame loaded by :meth:`read_media`; must have a ``"Tweet"``
            column.
        n_tweets: Number of leading rows of ``df`` that are scored.
        narratives: Narrative strings the tweets are compared against.
        similarities: Array of shape ``(n_tweets, len(narratives))``;
            entry ``[i, j]`` is the similarity of tweet ``i`` to narrative
            ``j``.
        tweets: DataFrame with columns ``"Tweet"`` and ``"Sim_Index"``, where
            ``"Sim_Index"`` is the tweet's row number in ``similarities``.
    """

    def __init__(self, model, tweet_file, n_tweets, narratives):
        """Load the tweets and immediately compute all similarities.

        Args:
            model: Encoder used for both tweets and narratives.
            tweet_file: Path to a ``.csv`` or ``.txt`` file of tweets.
            n_tweets: How many tweets to score; capped at the number of rows
                available in the file.
            narratives: Sequence of narrative strings, e.g.
                ``["Russia is an ally", "the 2020 election was stolen"]``.

        Raises:
            ValueError: If ``n_tweets`` is negative or ``tweet_file`` has an
                unsupported extension.
        """
        self.model = model
        self.read_media(tweet_file)

        if n_tweets < 0:
            raise ValueError(
                "n_tweets must be non-negative, got {num}".format(num=n_tweets)
            )
        if n_tweets > len(self.df):
            n_tweets = len(self.df)
            print("Number of tweets selected is greater than total."
                  "\nContinuing with {num} tweets total.".format(num=n_tweets))
        self.n_tweets = n_tweets

        self.narratives = narratives
        self.similarities = np.empty((self.n_tweets, len(self.narratives)))
        self.tweets = pd.DataFrame(columns=["Tweet", "Sim_Index"])

        # Ideally the caller would trigger this, but a Results object with no
        # results is not useful, so the scores are computed eagerly.
        self.get_results()

    def read_media(self, file):
        """Load ``file`` into ``self.df``.

        A ``.csv`` file is parsed with pandas. A ``.txt`` file is read whole
        and stored as a single-row frame whose ``"Tweet"`` column holds the
        entire file content.

        Args:
            file: Path to the input file.

        Raises:
            ValueError: If the extension is neither ``.csv`` nor ``.txt``.
        """
        name = str(file).lower()
        if name.endswith(".csv"):
            self.df = pd.read_csv(file, encoding='utf-8', encoding_errors='ignore')
        elif name.endswith(".txt"):
            # Decode the same way as the CSV branch instead of depending on
            # the platform's default encoding.
            with open(file, "r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
            self.df = pd.DataFrame({"Tweet": [content]})
        else:
            # Previously self.df was silently left unset, which only failed
            # later with an unrelated AttributeError.
            raise ValueError(
                "Unsupported file type for {file!r}: expected a .csv or .txt "
                "file.".format(file=file)
            )

    def embed_narratives(self, narratives):
        """Return a list with one embedding per narrative, in input order."""
        return [
            self.model.encode(narrative, convert_to_tensor=True)
            for narrative in narratives
        ]

    def get_results(self):
        """Fill ``self.similarities`` and ``self.tweets``.

        Each of the first ``n_tweets`` tweets is encoded and compared against
        every narrative embedding using cosine similarity.
        """
        nar_embeds = self.embed_narratives(self.narratives)
        tweets = self.df["Tweet"].iloc[:self.n_tweets].tolist()

        for i, tweet in enumerate(tweets):
            embedding = self.model.encode(tweet, convert_to_tensor=True)
            for j, nar_embed in enumerate(nar_embeds):
                sim = util.pytorch_cos_sim(embedding, nar_embed)
                # .item() extracts the Python float explicitly rather than
                # relying on an implicit tensor-to-scalar conversion, which
                # is not possible for tensors that live off the CPU.
                self.similarities[i, j] = sim.item()

        # Build the frame in one go; adding rows one at a time with .loc
        # re-allocates the frame on every insertion.
        self.tweets = pd.DataFrame(
            {"Tweet": tweets, "Sim_Index": range(len(tweets))},
            columns=["Tweet", "Sim_Index"],
        )

    def _resolve_narrative_ind(self, narrative_ind):
        """Return ``narrative_ind`` if it is a usable index, otherwise 0.

        Negative indices that count back from the end are accepted, as with
        ordinary sequence indexing. Anything out of range prints a notice and
        falls back to the first narrative.
        """
        n_narratives = len(self.narratives)
        if -n_narratives <= narrative_ind < n_narratives:
            return narrative_ind
        print("Invalid narrative index. Continuing with narrative_ind=0...")
        return 0

    def sort_by_narrative(self, narrative_ind):
        """Order the tweets by ascending similarity to one narrative.

        Args:
            narrative_ind: Index into ``self.narratives``; an out-of-range
                value falls back to 0.

        Returns:
            A ``(sorted_tweets, sorted_sims)`` tuple: the rows of
            ``self.tweets`` and the matching similarity values, least similar
            first.
        """
        narrative_ind = self._resolve_narrative_ind(narrative_ind)

        # One column of the (n_tweets, n_narratives) matrix: every tweet's
        # similarity to the chosen narrative.
        narr_sims = self.similarities[:, narrative_ind]
        sorted_args = np.argsort(narr_sims)

        sorted_sims = narr_sims[sorted_args]
        sorted_tweets = self.tweets.iloc[sorted_args]
        return sorted_tweets, sorted_sims

    def print_top_k(self, k, narrative_ind):
        """Print and return the ``k`` tweets most similar to a narrative.

        Args:
            k: Number of tweets to show; clamped to ``[0, n_tweets]``.
            narrative_ind: Index into ``self.narratives``; an out-of-range
                value falls back to 0.

        Returns:
            DataFrame of the top ``k`` tweets in ascending order of
            similarity, with an added ``"Sims"`` column holding the scores.
        """
        narrative_ind = self._resolve_narrative_ind(narrative_ind)
        # Clamp at both ends: with k == 0 the old slice [-0:] returned every
        # row, and a negative k returned an unrelated slice.
        k = max(0, min(k, self.n_tweets))

        sorted_tweets, sorted_sims = self.sort_by_narrative(narrative_ind)
        # assign() returns a new frame, so no column is written onto the
        # selection returned by sort_by_narrative.
        ranked = sorted_tweets.assign(Sims=sorted_sims)
        top_k = ranked.iloc[len(ranked) - k:]

        # Show full tweet text for this print only; the caller's global
        # pandas display settings are restored afterwards.
        with pd.option_context('display.max_colwidth', None):
            print("{k} Most similar tweets to narrative \n\"{narrative}\": \n".format(
                k=k, narrative=self.narratives[narrative_ind]), top_k)
        return top_k

    def __repr__(self):
        return f"First 10 Results: \n {self.tweets[:10]}"