doyouknowmarc commited on
Commit
ae1c4cf
·
verified ·
1 Parent(s): b562876

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -0
app.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import json
4
+ import os
5
+ import csv
6
+ from openpyxl import Workbook
7
+ from openpyxl.styles import PatternFill
8
+ import pandas as pd
9
+ from datetime import datetime
10
+ import time
11
+
12
+ # Function to parse the Ollama API response to JSON
13
+ def parse_response_to_json(response):
14
+ parsed_response = response.json()
15
+ json_response_string = parsed_response['response']
16
+ json_response = json.loads(json_response_string)
17
+ return json_response
18
+
19
+ # Function to process uploaded files using Ollama
20
+ def process_files_with_ollama(uploaded_files, url, model, prompt_template, schema, output_file_path):
21
+ # Create the CSV file and write the header
22
+ with open(output_file_path, mode="w", newline="", encoding="utf-8") as file:
23
+ writer = csv.writer(file)
24
+ writer.writerow(["Input", "Sentiment", "Reasoning"])
25
+
26
+ # Initialize progress bar
27
+ progress_bar = st.progress(0)
28
+ total_files = len(uploaded_files)
29
+
30
+ # Start the stopwatch
31
+ start_time = time.time()
32
+
33
+ for i, uploaded_file in enumerate(uploaded_files):
34
+ # Display which file is being processed
35
+ st.write(f"Processing file {i + 1}/{total_files}: {uploaded_file.name}")
36
+
37
+ # Read the file content
38
+ content = uploaded_file.read().decode("utf-8")
39
+
40
+ # Prepare the payload for the Ollama API
41
+ payload = {
42
+ "model": model,
43
+ "prompt": prompt_template.format(input=content),
44
+ "stream": False,
45
+ "format": schema
46
+ }
47
+
48
+ # Send the request to Ollama
49
+ response = requests.post(url, json=payload, headers={"Content-Type": "application/json"})
50
+
51
+ if response.status_code == 200:
52
+ # Parse the response and extract sentiment and reasoning
53
+ json_response = parse_response_to_json(response)
54
+ sentiment = json_response['sentiment']
55
+ reasoning = json_response['reasoning']
56
+
57
+ # Append the result to the CSV file
58
+ with open(output_file_path, mode="a", newline="", encoding="utf-8") as file:
59
+ writer = csv.writer(file)
60
+ writer.writerow([content, sentiment, reasoning])
61
+ else:
62
+ st.error(f"Error processing file {uploaded_file.name}: {response.status_code}")
63
+
64
+ # Update progress bar
65
+ progress_bar.progress((i + 1) / total_files)
66
+
67
+ # Stop the stopwatch and calculate elapsed time
68
+ elapsed_time = time.time() - start_time
69
+ st.sidebar.write(f"Processing time: {elapsed_time:.2f} seconds")
70
+
71
+ st.success("All files processed successfully!")
72
+
73
+ # Streamlit app title and description
74
+ st.title("Text File Sentiment Analysis with Ollama")
75
+ st.write("""
76
+ This app allows you to analyze the sentiment of text files using Ollama. Follow these steps:
77
+ 1. **Upload text files**: Drag and drop your text files.
78
+ 2. **Configure Ollama**: Set the API URL, model, prompt template, and schema in the sidebar.
79
+ 3. **Analyze Sentiment**: Click the "Analyze Sentiment" button to process the files.
80
+ 4. **Download Results**: Download the results as a CSV file.
81
+ 5. **Highlight Mismatches**: Use the options in the sidebar to highlight mismatches in the results.
82
+ """)
83
+
84
+ # User inputs for Ollama configuration
85
+ st.sidebar.header("Ollama Configuration")
86
+ url = st.sidebar.text_input("Ollama API URL", value="http://localhost:11434/api/generate")
87
+ model = st.sidebar.text_input("Model", value="llama3.2:latest")
88
+ prompt_template = st.sidebar.text_area(
89
+ "Prompt Template",
90
+ value="Do a sentiment analysis for the following text and return POSITIVE or NEGATIVE and your reasoning: {input}",
91
+ height=100
92
+ )
93
+
94
+ # Input field for the schema
95
+ default_schema = {
96
+ "type": "object",
97
+ "properties": {
98
+ "sentiment": {"enum": ["POSITIVE", "NEUTRAL", "NEGATIVE"]},
99
+ "reasoning": {"type": "string"}
100
+ },
101
+ "required": ["sentiment", "reasoning"]
102
+ }
103
+ schema_input = st.sidebar.text_area(
104
+ "Schema (JSON format)",
105
+ value=json.dumps(default_schema, indent=2),
106
+ height=400 # Increased height
107
+ )
108
+
109
+ # Parse the schema input
110
+ try:
111
+ schema = json.loads(schema_input)
112
+ except json.JSONDecodeError:
113
+ st.error("Invalid JSON schema. Please check your input.")
114
+ schema = default_schema
115
+
116
+ # Highlighting configuration in the sidebar
117
+ st.sidebar.header("Highlighting Configuration")
118
+ highlight_whole_row = st.sidebar.checkbox("Highlight the whole row", value=True)
119
+ highlight_color = st.sidebar.color_picker("Choose a highlight color", "#FF0000")
120
+
121
+ # File uploader for text files
122
+ uploaded_files = st.file_uploader("Upload text files", type=["txt"], accept_multiple_files=True)
123
+
124
+ # Initialize session state for results_df and output_file_name
125
+ if "results_df" not in st.session_state:
126
+ st.session_state.results_df = None
127
+ if "output_file_name" not in st.session_state:
128
+ st.session_state.output_file_name = None
129
+ if "uploaded_csv_df" not in st.session_state:
130
+ st.session_state.uploaded_csv_df = None
131
+
132
+ # Create tabs for the app
133
+ tab1, tab2 = st.tabs(["Analyze", "Highlight Mismatches"])
134
+
135
+ with tab1:
136
+ if uploaded_files:
137
+ # Generate a unique output file name with timestamp
138
+ if st.session_state.output_file_name is None:
139
+ st.session_state.output_file_name = f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
140
+
141
+ # Process the uploaded files with Ollama
142
+ if st.button("Analyze Sentiment"):
143
+ with st.spinner("Processing files..."):
144
+ process_files_with_ollama(
145
+ uploaded_files, url, model, prompt_template, schema, st.session_state.output_file_name
146
+ )
147
+
148
+ # Load the results into a DataFrame and store it in session state
149
+ st.session_state.results_df = pd.read_csv(st.session_state.output_file_name)
150
+
151
+ # Display the results from the output CSV file if available
152
+ if st.session_state.results_df is not None:
153
+ st.write("Sentiment Analysis Results:")
154
+ st.dataframe(st.session_state.results_df)
155
+
156
+ # Provide a download link for the results CSV file
157
+ if st.session_state.output_file_name is not None:
158
+ with open(st.session_state.output_file_name, "rb") as file:
159
+ st.download_button(
160
+ label="Download Results CSV",
161
+ data=file,
162
+ file_name=st.session_state.output_file_name,
163
+ mime="text/csv",
164
+ )
165
+
166
+ with tab2:
167
+ # Allow users to upload their own CSV file
168
+ uploaded_csv = st.file_uploader("Upload your own CSV file (optional)", type=["csv"])
169
+
170
+ # Use the uploaded CSV file if provided
171
+ if uploaded_csv is not None:
172
+ st.session_state.uploaded_csv_df = pd.read_csv(uploaded_csv)
173
+ st.write("Using the uploaded CSV file for highlighting mismatches.")
174
+ elif st.session_state.uploaded_csv_df is not None:
175
+ st.write("Using the previously uploaded CSV file for highlighting mismatches.")
176
+ else:
177
+ st.warning("No CSV file available. Please analyze text files in Tab 1 or upload a CSV file.")
178
+
179
+ # Display the results from the CSV file if available
180
+ if st.session_state.uploaded_csv_df is not None:
181
+ st.write("### Sentiment Analysis Results")
182
+ st.dataframe(st.session_state.uploaded_csv_df)
183
+
184
+ st.write("### Highlight Mismatches in Results")
185
+ column_to_check = st.selectbox(
186
+ "Select the column to check (e.g., Sentiment)",
187
+ options=st.session_state.uploaded_csv_df.columns,
188
+ index=1, # Default to the "Sentiment" column
189
+ )
190
+ constant_value = st.text_input(
191
+ "Enter the constant value to compare against (e.g., POSITIVE)",
192
+ value="POSITIVE", # Default value
193
+ )
194
+
195
+ if st.button("Highlight Mismatches"):
196
+ # Create a new Excel workbook and select the active worksheet
197
+ wb = Workbook()
198
+ ws = wb.active
199
+
200
+ # Write the header row to the Excel worksheet
201
+ for col_idx, header in enumerate(st.session_state.uploaded_csv_df.columns, start=1):
202
+ ws.cell(row=1, column=col_idx, value=header)
203
+
204
+ # Define the fill style using the selected color
205
+ highlight_fill = PatternFill(start_color=highlight_color.lstrip("#"), end_color=highlight_color.lstrip("#"), fill_type="solid")
206
+
207
+ # Initialize a list to store mismatched row numbers
208
+ mismatched_rows = []
209
+
210
+ # Write the results to the Excel worksheet
211
+ for row_idx, row in st.session_state.uploaded_csv_df.iterrows():
212
+ for col_idx, value in enumerate(row, start=1):
213
+ ws.cell(row=row_idx + 2, column=col_idx, value=value)
214
+
215
+ # Check for mismatches in the selected column
216
+ if row[column_to_check] != constant_value:
217
+ # Add the row number to the mismatched_rows list
218
+ mismatched_rows.append(row_idx + 2) # +2 because header is row 1
219
+
220
+ # Highlight the cell or the entire row based on user choice
221
+ if highlight_whole_row:
222
+ for col_idx in range(1, len(row) + 1):
223
+ ws.cell(row=row_idx + 2, column=col_idx).fill = highlight_fill
224
+ else:
225
+ col_index = st.session_state.uploaded_csv_df.columns.get_loc(column_to_check) + 1
226
+ ws.cell(row=row_idx + 2, column=col_index).fill = highlight_fill
227
+
228
+ # Save the modified workbook
229
+ highlighted_output_file = f"highlighted_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
230
+ wb.save(highlighted_output_file)
231
+
232
+ # Create a new Excel file containing only the mismatched rows
233
+ mismatched_df = st.session_state.uploaded_csv_df[st.session_state.uploaded_csv_df[column_to_check] != constant_value]
234
+ mismatched_output_file = f"mismatched_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
235
+ mismatched_df.to_excel(mismatched_output_file, index=False)
236
+
237
+ # Display the number of mismatches and their row numbers
238
+ st.write(f"**Total mismatches found:** {len(mismatched_rows)}")
239
+ if mismatched_rows:
240
+ st.write(f"**Mismatches found in rows (referring to the Excel file):** {', '.join(map(str, mismatched_rows))}")
241
+
242
+ # Provide a download link for the modified Excel file
243
+ with open(highlighted_output_file, "rb") as file:
244
+ st.download_button(
245
+ label="Download Highlighted Results",
246
+ data=file,
247
+ file_name=highlighted_output_file,
248
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
249
+ )
250
+
251
+ # Provide a download link for the mismatched rows Excel file
252
+ with open(mismatched_output_file, "rb") as file:
253
+ st.download_button(
254
+ label="Download Mismatched Rows",
255
+ data=file,
256
+ file_name=mismatched_output_file,
257
+ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
258
+ )
259
+
260
+ st.success("Mismatches highlighted! Click the buttons above to download.")