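"""Gradio app for generating narratives from a tweet dataset and tracing which
tweets best match a given narrative.

Two tabs are exposed: "Narrative Generator" clusters the selected dataset and
summarizes each cluster into a narrative, and "Trace Narrative" ranks tweets by
embedding similarity to a free-text narrative.
"""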
import gradio as gr
from generate_narratives import Narrative_Generator
from sentence_transformers import SentenceTransformer
import pandas as pd
from mlx_lm import load 
from sim_scores import Results

# Global variables to store results
formatted_narratives = None
clustered_tweets = None

dataset_options = {
    "Trump Tweets": "trumptweets1205-127.csv",
    "Dataset 2": "path/to/dataset2.csv",
    "Upload Your Own": None
}

# Example narratives (not currently wired to a UI component; kept for reference).
narrative_options = {
    "0": "Russia is an ally",
    "1": "The 2020 election was stolen"
}

# Load the 4-bit MLX summarization model and the sentence-embedding model once at startup.
summary_model, tokenizer = load("mlx-community/Mistral-Nemo-Instruct-2407-4bit")
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

def run_narrative_generation(selected_dataset, uploaded_file, num_narratives, progress=gr.Progress()):
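    """Cluster the selected dataset, summarize each cluster into a narrative, and
    return the narratives as an HTML block plus the clustered tweets."""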
    # Store results in module-level globals so get_narratives() can serve them later.
    global formatted_narratives, clustered_tweets

    progress(0, desc="Generating narratives...")
    file_path = dataset_options.get(selected_dataset, None)
    if selected_dataset == "Upload Your Own":
        if uploaded_file is None:
            return "Please upload a file.", None
        file_path = uploaded_file.name

    generator = Narrative_Generator(summary_model, tokenizer, embedding_model, file_path, num_narratives)
    try:
        json_narratives, _, clustered_tweets = generator.generate_narratives(progress=progress)
        html_narratives = generator.get_html_formatted_outputs(json_narratives)

        html_output = "<div id='narratives-container' style='display: flex; flex-direction: column; gap: 10px;'>"
        for narrative_html in html_narratives:
            html_output += narrative_html
        html_output += "</div>"

        # Keep the JSON narratives; the clustered tweets were assigned above.
        formatted_narratives = json_narratives

        return html_output, clustered_tweets
    except ValueError as e:
        # Surface the error message in the UI instead of crashing.
        return str(e), None


def trace_narrative(selected_dataset, narrative, max_tweets):
    """
    Runs the ranking function and returns the top k similar tweets to the selected narrative.
    """
    # Resolve the dataset path from the dropdown selection.
    file_path = dataset_options.get(selected_dataset, None)

    # Load the Results object and rank the narrative
    results = Results(embedding_model, file_path, max_tweets, [narrative])
    top_k_results = results.print_top_k(k=10, narrative_ind=0)

    # Format the result as a readable string or table
    formatted_output = f"**Top 10 Most Similar Tweets to Narrative**: `{narrative}`\n\n"
    
    if isinstance(top_k_results, pd.DataFrame):
        formatted_output += top_k_results.to_markdown(index=False)
    else:
        formatted_output += str(top_k_results)
    return formatted_output
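
# Example (run outside the Gradio UI, assuming the dataset CSV is present):
#   trace_narrative("Trump Tweets", "Russia is an ally", 100)
# should return a markdown table of the 10 tweets most similar to that narrative.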

# Gradio UI Setup
with gr.Blocks() as iface:
    with gr.Tab("Narrative Generator"):
        narrative_generator_interface = gr.Interface(
            fn=run_narrative_generation,
            inputs=[
                gr.Dropdown(list(dataset_options.keys()), value="Trump Tweets", label="Select Dataset"),
                gr.File(label="Upload Text File (Optional)"),
                gr.Slider(1, 10, step=1, value=5, label="Number of Narratives"),
            ],
            outputs=[
                gr.HTML(label="Generated Narratives"),
                gr.JSON(label="Clustered Tweets"),
            ],
            title="Trump Narrative Generator",
            description="Choose a dataset or upload a file. Select the number of narratives, then click 'Run'.",
            theme="huggingface"
        )

    with gr.Tab("Trace Narrative"):
        trace_narrative_interface = gr.Interface(
            fn=trace_narrative,
            inputs=[
                # gr.Dropdown([["Russia is an ally", "The 2020 election was stolen"]], value="Narrative", label="Select Narrative"),
                gr.Dropdown(list(dataset_options.keys()), value="Trump Tweets", label="Select Dataset"),
                gr.Textbox(label="Narrative", placeholder="e.g. Russia is an ally"),
                gr.Slider(1, 1000, step=1, value=10, label="Max Tweets to Process"),
            ],
            outputs=[
                gr.Markdown(label="Top 10 Similar Tweets"),
            ],
            title="Rank Narrative with Tweets",
            description="Select a narrative and trace the most similar tweets to that narrative.",
            theme="huggingface"
        )

    def get_narratives():
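        """Return the most recently generated narratives and their clustered tweets,
        or an error payload if nothing has been generated yet."""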
        # This helper is not registered as a Flask route, so plain dicts are returned.
        if formatted_narratives is not None and clustered_tweets is not None:
            return {
                "generated_narratives": formatted_narratives,  # JSON narratives from run_narrative_generation
                "clustered_tweets": clustered_tweets           # the corresponding tweet clusters
            }
        return {"error": "No narratives generated yet"}

    gr.HTML("""
        <link rel="stylesheet" type="text/css" href="assets/style.css">
        <script src="assets/script.js"></script>
        """)

# iface.launch(share=False, inbrowser=True, server_name="localhost", server_port=7860, prevent_thread_lock=True)
if __name__ == "__main__":
    iface.launch()