import json
import re

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from mlx_lm import generate
from pydantic import BaseModel, Field
from sklearn.cluster import KMeans

# Currently unused imports, kept for reference:
# from sentence_transformers import SentenceTransformer, util
# from transformers import pipeline
# from langchain_huggingface import HuggingFacePipeline
# from langchain_core.exceptions import OutputParserException
# from langchain.output_parsers import OutputFixingParser
# from langchain_openai import ChatOpenAI
# from langchain_ollama import ChatOllama

load_dotenv()
# Simple prompt
SYS_PROMPT = (
    "You should find the top two dominant narratives in the following batch "
    "of tweets. Do not cite which tweets correspond to the narratives, just "
    "supply the narrative summaries. You must always return valid JSON fenced "
    "by a markdown code block. Do not return any additional text."
)
# OUTPUT_PARSE_PROMPT = "Structure your response as a JSON object with {'narrative_1': value, 'narrative_2': value}"

smallest_batch_size = 10

class Narrative_Generator:
    def __init__(self, summary_model, tokenizer, embedding_model, file, num_narratives):
        self.summary_model = summary_model
        self.tokenizer = tokenizer
        self.embedding_model = embedding_model
        self.num_narratives = num_narratives
        
        # Try loading the file, handle errors gracefully
        try:
            self.df = pd.read_csv(file, encoding="utf-8", encoding_errors="ignore")
        except FileNotFoundError:
            raise ValueError(f"Error: The file '{file}' was not found.")
        except pd.errors.EmptyDataError:
            raise ValueError(f"Error: The file '{file}' is empty or corrupted.")
        except Exception as e:
            raise ValueError(f"Error loading file '{file}': {e}")
        # You could use preprocess_context_window here instead if data is too big...

    def preprocess_context_window(self, file):
        # TODO: smarter batching -- append previously found narratives to each
        # new batch? Largest batch possible? Small, frequent chunks aggregated
        # up multiple times? Enough to get a general sense of the discourse?
        if file.endswith(".txt"):
            with open(file, "r", encoding="utf-8") as f:
                lines = [line.strip() for line in f if line.strip()]
            return self.chunk_it(np.array(lines), smallest_batch_size)
        elif file.endswith(".csv"):
            df = pd.read_csv(file, encoding="utf-8", encoding_errors="ignore")
            # Create ceil(n_rows / smallest_batch_size) chunks of data.
            chunks = self.chunk_it(df, smallest_batch_size)
            return chunks

    def chunk_it(self, array, k):
        # Split into ceil(len(array) / k) roughly equal chunks of at most k rows each.
        return np.array_split(array, int(np.ceil(len(array) / k)))

    def create_prompt(self, tokenizer, user_prompt):
        messages = [
            {"role": "system", "content": SYS_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt

    def create_format_prompt(self, parser):
        prompt = PromptTemplate(
            template=SYS_PROMPT + "\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )
        return prompt

    def cluster_embedded_tweets(self, tweets):
        # Encode each tweet into a sentence embedding. For sentence-transformers
        # models, encode() returns embedding vectors (not input ids), so
        # clustering these is semantically meaningful.
        # TODO: visualize the spread of the clusters and the tweets within them
        # to sanity-check cluster quality (see the sketch below).
        embeddings = []
        for tweet in tweets:
            embeddings.append(self.embedding_model.encode(tweet, convert_to_numpy=True))
        clusters = KMeans(n_clusters=self.num_narratives).fit(embeddings)
        unique_labels = np.unique(clusters.labels_)
        # Group tweets by cluster label.
        clustered_tweets = [tweets[clusters.labels_ == label] for label in unique_labels]
        return clustered_tweets
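
    # A minimal sketch for the visualization TODO above, assuming matplotlib is
    # available; the method name and plotting choices are illustrative, not
    # part of the existing API.
    def plot_cluster_spread(self, tweets):
        import matplotlib.pyplot as plt
        from sklearn.decomposition import PCA

        # Re-embed and re-cluster, then project to 2D for plotting.
        embeddings = self.embedding_model.encode(list(tweets), convert_to_numpy=True)
        labels = KMeans(n_clusters=self.num_narratives).fit_predict(embeddings)
        points = PCA(n_components=2).fit_transform(embeddings)
        plt.scatter(points[:, 0], points[:, 1], c=labels)
        plt.title("Tweet embedding clusters (2D PCA projection)")
        plt.show()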

    # TODO yaml config
    # TODO private helper functions
    def process_chunk(self, chunk, model, tokenizer):
        # create_prompt already injects SYS_PROMPT as the system message, so
        # only the user prompt (the batch of tweets) is passed here.
        prompt = self.create_prompt(tokenizer, user_prompt="\n".join(map(str, chunk)))
        response = generate(
            model, tokenizer,
            # temperature=0.9, top_p=0.8,
            max_tokens=512, prompt=prompt,
            verbose=True)
        return response, prompt

    def generate_narratives(self, progress=None):
        # Set up a parser and inject its format instructions into the prompt template.
        parser = JsonOutputParser(pydantic_object=self.NarrativeSummary)
        llm = MLXPipeline(
            model=self.summary_model,
            tokenizer=self.tokenizer,
            pipeline_kwargs={"response_format": {"type": "json_object"}},
        )
        prompt = self.create_format_prompt(parser)
        chain = prompt | llm | self.parse_json_objects

        # TODO: what happens with far more than 300 tweets? Can we still cluster
        # 50,000, or do we need to chunk by time and regenerate narratives?
        clustered_tweets = self.cluster_embedded_tweets(self.df["Tweet"])
        responses = []
        iterator = progress.tqdm(clustered_tweets) if progress is not None else clustered_tweets
        for chunk in iterator:
            # resp, prompt = self.process_chunk(chunk, self.summary_model, self.tokenizer)
            resp = chain.invoke({"query": chunk})
            if not resp:
                continue
            responses.append(resp[0])

        return responses, prompt, clustered_tweets

    def format(self, raw_narratives):
        # Formats to markdown for direct display as a string.
        formatted_output = ""
        # Loop through the outer list (each narrative set)
        for idx, narrative_set in enumerate(raw_narratives, 1):
            if not narrative_set:  # Skip empty lists
                continue

            formatted_output += f"### Narrative Set {idx}\n"

            # Loop through each narrative (which is a dictionary)
            for key, value in narrative_set.items():
                formatted_output += f"- **{key.replace('_', ' ').capitalize()}**: {value}\n"

            formatted_output += "\n---\n\n"

        return formatted_output.strip()


    def get_html_formatted_outputs(self, raw_narratives):
        outputs = []
        for idx, narrative_set in enumerate(raw_narratives, 1):
            if not narrative_set:  # Skip empty lists
                continue
            
            # Start a container div for each narrative pair
            formatted_output = f"<div class='narrative-block' data-narrative-id='{idx}'>"
            formatted_output += "<hr class='narrative-separator'>" 

            for key, value in narrative_set.items():
                formatted_output += f"""
                <div class='narrative-item'>
                    <strong>{key.replace('_', ' ').capitalize()}:</strong>
                    <p>{value}</p>
                </div>
                """

            formatted_output += "</div>"  # Close block div
            outputs.append(formatted_output)

        return outputs


    def save_json_narratives(self, json_list):
        # Save JSON for download
        json_file = "generated_narratives.json"
        with open(json_file, "w") as f:
            json.dump(json_list, f, indent=4)
        return json_file


    def parse_json_objects(self, text):
        """
        Finds candidate JSON objects between curly braces and attempts to
        parse each one. Returns a list of successfully parsed JSON objects;
        objects that fail to parse are skipped so the others can still be
        processed. Note the regex only matches non-nested objects.
        """
        # Regex to find (non-nested) JSON objects between curly braces {}
        json_objects = re.findall(r'\{.*?\}', text, re.DOTALL)
        
        parsed_json_list = []  # To store successfully parsed JSON objects

        # Loop through each potential JSON object
        for json_str in json_objects:
            try:
                # Try parsing the JSON string
                parsed_json = json.loads(json_str)
                parsed_json_list.append(parsed_json)
            except json.JSONDecodeError as e:
                print(f"Error parsing JSON: {e}. Skipping invalid JSON object.")
                continue  # Skip the invalid JSON and continue processing other objects

        return parsed_json_list
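
    # Illustrative example (hypothetical input): given the fenced output the
    # system prompt asks for, e.g.
    #   text = '```json\n{"narrative_1": "A", "narrative_2": "B"}\n```'
    # parse_json_objects returns [{"narrative_1": "A", "narrative_2": "B"}],
    # ignoring the markdown fence and any surrounding text.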


    # def parse_narratives(self, raw_narratives):
    #     """
    #     Parses a list of JSON strings into Python dictionaries
    #     and returns a formatted markdown string + a downloadable JSON file.
    #     """
    #     formatted_output = ""
    #     json_list = []

    #     for i, json_str in enumerate(raw_narratives, 1):
    #         try:
    #             narrative_obj = json.loads(json_str)  # Convert JSON string to dict
    #             json_list.append(narrative_obj)

    #             formatted_output += f"### Narrative Set {i}\n"
    #             for key, value in narrative_obj.items():
    #                 formatted_output += f"- **{key.replace('_', ' ').capitalize()}**: {value}\n"
    #             formatted_output += "\n---\n\n"
    #         except json.JSONDecodeError:
    #             formatted_output += f"⚠️ Error parsing narrative {i}: {json_str}\n\n"

    #     # Save JSON for download
    #     json_file = "generated_narratives.json"
    #     with open(json_file, "w") as f:
    #         json.dump(json_list, f, indent=4)

    #     return formatted_output.strip(), json_file
    
    # Define your desired data structure.
    class NarrativeSummary(BaseModel):
        narrative_1: str = Field(description="Most dominant narrative")
        narrative_2: str = Field(description="Second dominant narrative")
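        # JSON shape this schema requests from the model (illustrative):
        #   {"narrative_1": "...", "narrative_2": "..."}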

# For each tweet in the timeline, compute a similarity score to each of the
# generated narratives and append it to a per-narrative list; then plot each
# list of similarities over time (see the sketch below).

# TODO LLM validation:
#   Have people randomly sample tweets from the batches and agree or disagree
#   with the top two narratives presented (writing in an alternative if desired).
# TODO Similarity validation:
#   Given valid narratives, have people code tweets as belonging or not
#   belonging to them.
# TODO Filter out retweets / reply tweets, or adjust the prompt to account for
#   them? For example, replies to "fake news CNN" tweets criticizing his
#   policies are confusing the LLM.
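
# A minimal sketch of the similarity-over-time idea above, assuming
# `embedding_model` is a sentence-transformers model and `narratives` is a
# list of narrative summary strings. The function name and signature are
# illustrative, not part of the class API.
def narrative_similarity_over_time(embedding_model, tweets, narratives):
    """Return an (n_tweets, n_narratives) array of cosine similarities."""
    from sentence_transformers import util

    tweet_emb = embedding_model.encode(list(tweets), convert_to_numpy=True)
    narrative_emb = embedding_model.encode(list(narratives), convert_to_numpy=True)
    # Row i holds tweet i's similarity to every narrative; plotting each column
    # in timeline order gives that narrative's trend over time.
    return util.cos_sim(tweet_emb, narrative_emb).numpy()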