Sa-m's picture
Update app.py
9f93f0e verified
raw
history blame
20.6 kB
import random
import matplotlib.pyplot as plt
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.text import Text
from nltk.probability import FreqDist
from cleantext import clean
import textract
import urllib.request
from io import BytesIO
import sys
import pandas as pd
import cv2
import re
from wordcloud import WordCloud, ImageColorGenerator
from textblob import TextBlob
from PIL import Image
import os
import gradio as gr
from dotenv import load_dotenv
import groq
import json
import traceback
import numpy as np
import unidecode
import contractions
from sklearn.feature_extraction.text import TfidfVectorizer
load_dotenv()
import nltk
import ssl
def ensure_nltk_resources():
    """Download any NLTK data packages this app needs that are not installed yet.

    Each resource is probed individually, so a partially populated NLTK data
    directory (for example ``punkt`` present but ``punkt_tab`` missing) is
    still completed. Nothing is downloaded when everything is already present.
    """
    # nltk.data lookup path -> package id expected by nltk.download
    required = {
        'tokenizers/punkt': 'punkt',
        'tokenizers/punkt_tab': 'punkt_tab',
        'corpora/stopwords': 'stopwords',
        'corpora/wordnet': 'wordnet',
        'corpora/words': 'words',
    }
    missing = []
    for resource_path, package in required.items():
        try:
            nltk.data.find(resource_path)
        except LookupError:
            missing.append(package)
    if not missing:
        return

    print("NLTK resources not found. Downloading...")
    # Handling potential SSL issues (common on some systems): allow unverified
    # HTTPS only for the duration of the download, then restore the previous
    # default so later HTTPS requests in this process are verified again.
    original_context = ssl._create_default_https_context
    unverified_context = getattr(ssl, '_create_unverified_context', None)
    if unverified_context is not None:
        ssl._create_default_https_context = unverified_context
    failed = []
    try:
        # One call per package: on NLTK versions that do not know an id
        # (e.g. 'punkt_tab'), a list download aborts at the unknown id.
        for package in missing:
            if not nltk.download(package):
                failed.append(package)
    finally:
        ssl._create_default_https_context = original_context

    if failed:
        print(f"Warning: could not download NLTK packages: {', '.join(failed)}")
    else:
        print("NLTK resources downloaded successfully.")
ensure_nltk_resources()
# Initialize Groq client. It stays None when GROQ_API_KEY is unset; the
# summary and contextual-search functions check for that and return a message.
groq_api_key = os.getenv("GROQ_API_KEY")
groq_client = groq.Groq(api_key=groq_api_key) if groq_api_key else None
# Stopwords customization: NLTK's English list plus project-specific noise
# tokens (short fragments such as 'ed', 'ing', 'er' and stray digits --
# presumably text-extraction artefacts; confirm against real manifestos).
stop_words = set(stopwords.words('english'))
stop_words.update({'ask', 'much', 'thank', 'etc.', 'e', 'We', 'In', 'ed', 'pa', 'This', 'also', 'A', 'fu', 'To', '5', 'ing', 'er', '2'})
# --- Parsing & Preprocessing Functions ---
def Parsing(parsed_text):
    """Extract text from an uploaded document and run it through ``clean``.

    Args:
        parsed_text: Either an object exposing a ``name`` attribute holding the
            file path (e.g. an upload wrapper) or the path itself.

    Returns:
        The cleaned document text. On any failure, the message
        ``"Error parsing PDF: <exception>"`` is printed and returned instead;
        callers detect failure by that ``"Error"`` prefix.
    """
    try:
        # Upload wrappers carry the path in .name; plain strings are the path.
        source_path = getattr(parsed_text, 'name', parsed_text)
        # errors='ignore' drops undecodable bytes rather than failing.
        extracted = textract.process(source_path).decode('utf-8', errors='ignore')
        return clean(extracted)
    except Exception as exc:
        failure_message = f"Error parsing PDF: {exc}"
        print(failure_message)
        return failure_message
def clean_text(text):
    """Normalise manifesto text and drop stopwords.

    Steps:
    1. Transliterate to ASCII.
    2. Expand English contractions.
    3. Turn newlines, tabs and ``"/ "`` into spaces, and collapse whitespace.
    4. Remove every token that appears in the module-level ``stop_words`` set
       (exact, case-sensitive match).

    Args:
        text: Raw document text.

    Returns:
        The remaining tokens joined by single spaces.
    """
    # Transliterate *before* forcing ASCII so accented letters are mapped to
    # their ASCII equivalents ("café" -> "cafe") instead of being deleted.
    text = unidecode.unidecode(text)
    text = text.encode("ascii", errors="ignore").decode("ascii")
    text = contractions.fix(text)
    # Same substitutions in the same order as before; str.split() below also
    # strips the ends and collapses runs of whitespace.
    text = text.replace("\n", " ").replace("\t", " ").replace("/ ", " ")
    return ' '.join(word for word in text.split() if word not in stop_words)
def Preprocess(textParty):
    """Reduce *textParty* to alphanumeric words with NLTK English stopwords removed.

    Every run of characters outside ``[A-Za-z0-9]`` becomes one space. Then
    each whole-word, case-sensitive match of an NLTK English stopword is
    deleted together with the whitespace that follows it.
    """
    alnum_text = re.sub('[^A-Za-z0-9]+', ' ', textParty)
    alternatives = r'|'.join(stopwords.words('english'))
    stopword_regex = re.compile(r'\b(' + alternatives + r')\b\s*')
    return stopword_regex.sub('', alnum_text)
# --- Core Analysis Functions ---
def generate_summary(text, max_chars=10000, model="llama3-8b-8192"):
    """Summarise manifesto text with the module-level Groq client.

    Args:
        text: Manifesto text to summarise.
        max_chars: Only the first ``max_chars`` characters are sent, to keep the
            request small.
        model: Groq chat model identifier.

    Returns:
        The summary text, or a human-readable message when:
        - the client is not configured,
        - there is no text,
        - the API call fails, or
        - the model returns no content.
    """
    if not groq_client:
        return "Summarization is not available. Please set up your GROQ_API_KEY in the .env file."
    if not text or not text.strip():
        # e.g. an image-only PDF with no extractable text; sending an empty
        # "manifesto" would only invite a fabricated summary.
        return "No text available to summarize."
    excerpt = text[:max_chars]
    try:
        completion = groq_client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that summarizes political manifestos. Provide a concise, objective summary that captures the key policy proposals, themes, and promises in the manifesto."},
                {"role": "user", "content": f"Please summarize the following political manifesto text in about 300-500 words, focusing on the main policy areas, promises, and themes:\n{excerpt}"}
            ],
            temperature=0.3,
            max_tokens=800,
        )
        content = completion.choices[0].message.content
    except Exception as e:
        return f"Error generating summary: {str(e)}"
    # content can be None; never hand None to the UI textbox.
    return content if content else "The model returned an empty summary."
def _mean_tfidf_scores(documents, max_features=15):
    """Return ``{term: mean TF-IDF weight across documents}`` for the top terms.

    Raises:
        ValueError: from scikit-learn when no vocabulary can be built (no
            documents, or only stop words).
    """
    vectorizer = TfidfVectorizer(max_features=max_features, stop_words='english')
    matrix = vectorizer.fit_transform(documents)
    # Column mean over every row (zeros included) -- the same per-term average
    # the previous cell-by-cell loop computed, without per-cell sparse indexing.
    column_means = np.asarray(matrix.mean(axis=0)).ravel()
    return dict(zip(vectorizer.get_feature_names_out(), column_means.tolist()))


def fDistance(text2Party):
    """Score the most relevant terms of *text2Party* for the "Key Topics" label.

    The 10 most frequent tokens and the TF-IDF top terms are each scaled to
    [0, 1] by their own maximum. They are combined as
    ``0.3 * frequency + 0.7 * tfidf``, the best 10 are kept, and the result is
    scaled with ``normalize`` so the scores sum to 1. Ties are broken
    alphabetically so the output is deterministic.

    If TF-IDF cannot be computed (``ValueError``, e.g. empty text), the
    frequency counts alone are normalised the same way. Empty text yields
    ``{}``.

    NOTE(review): the caller passes ``Preprocess`` output, which has had all
    punctuation replaced by spaces, so ``sent_tokenize`` likely yields a single
    "sentence" and the IDF part is uniform. Consider chunking the text into
    pseudo-documents if TF-IDF should actually discriminate.
    """
    tokens = word_tokenize(text2Party)
    mem = dict(FreqDist(tokens).most_common(10))
    try:
        tfidf_scores = _mean_tfidf_scores(sent_tokenize(text2Party))
    except ValueError as ve:  # e.g. empty vocabulary after stop-word removal
        print(f"Warning: TF-IDF failed, using only frequency: {ve}")
        return normalize(mem)

    max_freq = max(mem.values(), default=1)
    max_tfidf = max(tfidf_scores.values(), default=1)
    combined_scores = {
        word: (mem.get(word, 0) / max_freq) * 0.3 + (tfidf_scores.get(word, 0) / max_tfidf) * 0.7
        for word in mem.keys() | tfidf_scores.keys()
    }
    # Highest score first; alphabetical tie-break instead of set iteration order.
    ranked = sorted(combined_scores.items(), key=lambda item: (-item[1], item[0]))
    return normalize(dict(ranked[:10]))
def normalize(d, target=1.0):
    """Return a copy of *d* whose values are scaled to sum to *target*.

    Key order is preserved. When the values sum to zero, every value is
    multiplied by 0 rather than dividing by zero.
    """
    total = sum(d.values())
    if total == 0:
        scale = 0
    else:
        scale = target / total
    return {name: weight * scale for name, weight in d.items()}
# --- Visualization Functions with Error Handling ---
def safe_plot(func, *args, **kwargs):
    """Run a pyplot drawing callback and capture the result as a PIL image.

    The current figure is cleared, ``func(*args, **kwargs)`` draws on it, and
    the figure is rendered to an in-memory PNG.

    Args:
        func: Callable that draws onto the current pyplot figure; its return
            value is ignored.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        A ``PIL.Image.Image`` of the rendered figure, or ``None`` if drawing or
        saving raised (the error and traceback are printed).
    """
    try:
        plt.clf()
        func(*args, **kwargs)
        buf = BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')  # trim surrounding whitespace
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        print(f"Plotting error in safe_plot: {e}")
        traceback.print_exc()
        return None
    finally:
        # Close on success and failure alike, so a failing callback does not
        # leave an open figure behind in this long-running process.
        plt.close()
def fDistancePlot(text2Party):
    """Render a frequency-distribution chart of the 15 most common tokens.

    Returns whatever ``safe_plot`` produces: a PIL image, or ``None`` on error.
    When there are no tokens, the image contains a "No data to plot" notice.
    """
    def _draw():
        token_list = word_tokenize(text2Party)
        if token_list:
            FreqDist(token_list).plot(15, title='Frequency Distribution')
            plt.xticks(rotation=45, ha='right')  # keep long labels readable
            plt.tight_layout()
        else:
            plt.text(0.5, 0.5, "No data to plot", ha='center', va='center')

    return safe_plot(_draw)
def DispersionPlot(textParty):
    """Draw an NLTK lexical dispersion plot for the 5 most frequent tokens.

    Args:
        textParty: Text to tokenize and plot.

    Returns:
        A PIL image of the plot. Returns ``None`` when the text has no tokens
        or when plotting fails (the error and traceback are printed).
    """
    # Remember which figures already exist so that every figure opened during
    # this call is closed afterwards, on success and on error. NOTE(review):
    # depending on the NLTK version, dispersion_plot may open its own figure
    # rather than drawing into the one created below; this covers both cases.
    figures_before = set(plt.get_fignums())
    try:
        tokens = word_tokenize(textParty)
        # most_common(5) already yields fewer entries when there are fewer
        # than 5 distinct tokens, and none for empty text.
        top_words = [word for word, _ in FreqDist(tokens).most_common(5)]
        if not top_words:
            return None
        plt.figure(figsize=(10, 5))
        plt.title('Dispersion Plot')
        Text(tokens).dispersion_plot(top_words)
        plt.tight_layout()
        buf = BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        print(f"Dispersion plot error: {e}")
        traceback.print_exc()
        return None
    finally:
        for fig_num in set(plt.get_fignums()) - figures_before:
            plt.close(fig_num)
def _wordcloud_mask_path(parsed_text_name):
    """Pick a party mask image from the uploaded file's name, or return None."""
    filename_lower = ""
    if hasattr(parsed_text_name, 'name') and parsed_text_name.name:
        filename_lower = parsed_text_name.name.lower()
    elif isinstance(parsed_text_name, str):
        filename_lower = parsed_text_name.lower()
    # Checked in this order; the first keyword found in the filename wins.
    for keyword, mask_path in (('bjp', 'bjpImg2.jpeg'),
                               ('congress', 'congress3.jpeg'),
                               ('aap', 'aapMain2.jpg')):
        if keyword in filename_lower:
            return mask_path
    return None


def word_cloud_generator(parsed_text_name, text_Party):
    """Render a word cloud of *text_Party* as a PIL image.

    If the uploaded file's name mentions a known party and the matching mask
    image exists (relative to the working directory), the cloud is shaped by
    that mask with up to 3000 words. Otherwise a plain cloud of up to 2000
    words is drawn.

    Args:
        parsed_text_name: Upload object with a ``name`` path, or the path string.
        text_Party: Text to build the cloud from.

    Returns:
        The rendered image, or ``None`` when the text is empty or rendering
        fails (the error and traceback are printed).
    """
    fig = None
    try:
        mask_path = _wordcloud_mask_path(parsed_text_name)
        if text_Party.strip() == "":
            raise ValueError("Text for word cloud is empty")
        if mask_path and os.path.exists(mask_path):
            # 'with' closes the mask file handle once its pixels are copied.
            with Image.open(mask_path) as mask_image:
                mask = np.array(mask_image.convert('RGB'))
            wordcloud = WordCloud(max_words=3000, mask=mask, background_color='white').generate(text_Party)
        else:
            wordcloud = WordCloud(max_words=2000, background_color='white').generate(text_Party)
        fig = plt.figure(figsize=(8, 6))
        plt.imshow(wordcloud, interpolation='bilinear')
        plt.axis("off")
        plt.tight_layout()
        buf = BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        print(f"Word cloud error: {e}")
        traceback.print_exc()
        return None
    finally:
        # Close the figure even when rendering fails part-way.
        if fig is not None:
            plt.close(fig)
# Initial design for concordance based search
def get_all_phases_containing_tar_wrd(target_word, tar_passage, left_margin=10, right_margin=10, numLins=4):
    """Return keyword-in-context snippets for *target_word* in *tar_passage*.

    Matching is case-insensitive and token-based (NLTK ``word_tokenize``), so a
    multi-word search term will not match.

    Args:
        target_word: Term to look up; surrounding whitespace is ignored.
        tar_passage: Text to search.
        left_margin: Number of tokens shown before each hit.
        right_margin: Number of tokens shown after each hit.
        numLins: Maximum number of hits returned.

    Returns:
        One snippet per line, or a message when the term is blank or not found.
    """
    if not target_word or target_word.strip() == "":
        return "Please enter a search term"
    target = target_word.strip()
    tokens = nltk.word_tokenize(tar_passage)
    index = nltk.ConcordanceIndex(tokens, key=lambda s: s.lower())
    offsets = index.offsets(target)
    if not offsets:
        return f"Word '{target}' not found."
    # +1 so the window holds right_margin tokens *after* the hit itself;
    # previously the hit consumed one slot of the right context.
    snippets = [
        ' '.join(tokens[max(0, offset - left_margin):offset + right_margin + 1])
        for offset in offsets[:numLins]
    ]
    return '\n'.join(snippets)  # newline-separated for the textbox
def get_contextual_search_result(target_word, tar_passage, groq_client_instance, max_context_length=8000, model="llama3-8b-8192"):
    """Ask the LLM to summarise how *target_word* is used in *tar_passage*.

    Args:
        target_word: Search term; surrounding whitespace is ignored.
        tar_passage: Manifesto text to search.
        groq_client_instance: Groq client, or a falsy value when unavailable.
        max_context_length: Maximum number of characters of *tar_passage*
            sent to the model.
        model: Groq chat model identifier.

    Returns:
        The model's answer. Otherwise a human-readable message when:
        - the term is blank,
        - no client is configured,
        - the term does not occur in the text (case-insensitive substring), or
        - the API call fails.
    """
    if not target_word or target_word.strip() == "":
        return "Please enter a search term."
    if not groq_client_instance:
        return "Contextual search requires the LLM API. Please set up your GROQ_API_KEY."
    target_word = target_word.strip()
    # Cheap pre-check; the position of the first hit is also used below.
    first_hit = tar_passage.lower().find(target_word.lower())
    if first_hit == -1:
        return f"The term '{target_word}' was not found in the manifesto text."
    if len(tar_passage) > max_context_length:
        # Take a window that contains the first occurrence (starting a quarter
        # window before it) instead of always the first max_context_length
        # characters. Those may not contain the term at all, even though the
        # check above found it.
        start = max(0, min(first_hit - max_context_length // 4,
                           len(tar_passage) - max_context_length))
        tar_passage = tar_passage[start:start + max_context_length]
        print(f"Warning: Passage truncated for LLM search context to {max_context_length} characters.")
    prompt = f"""
You are given a political manifesto text and a specific search term.
Your task is to find all relevant mentions of the search term in the text and provide a concise, informative summary of the context surrounding each mention.
Focus on the key ideas, policies, or points related to the search term.
If the term is not found or not relevant, state that clearly.
Search Term: {target_word}
Manifesto Text:
{tar_passage}
"""
    try:
        completion = groq_client_instance.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant skilled at analyzing political texts and extracting relevant information based on a search query."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.2,  # low temperature for more factual extraction
            max_tokens=1000,
        )
        # content may be None; treat that like an empty answer.
        result = (completion.choices[0].message.content or "").strip()
        return result if result else f"No specific context for '{target_word}' could be generated."
    except Exception as e:
        error_msg = f"Error during contextual search for '{target_word}': {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return error_msg
def _sentiment_frame(text):
    """Return a one-row DataFrame with polarity and subjectivity labels for *text*."""
    if not text.strip():
        return pd.DataFrame({'Polarity_Label': ['Neutral'], 'Subjectivity_Label': ['Low']})
    # Compute TextBlob sentiment once (it was previously computed twice).
    sentiment = TextBlob(text).sentiment
    polarity_val = sentiment.polarity
    subjectivity_val = sentiment.subjectivity
    polarity_label = 'Positive' if polarity_val > 0 else 'Negative' if polarity_val < 0 else 'Neutral'
    subjectivity_label = 'High' if subjectivity_val > 0.5 else 'Low'
    return pd.DataFrame({'Polarity_Label': [polarity_label], 'Subjectivity_Label': [subjectivity_label]})


def analysis(Manifesto, Search):
    """Run the full manifesto analysis for the Gradio UI.

    Args:
        Manifesto: Uploaded file (object with ``name``) or path, or ``None``.
        Search: Search term; blank or ``None`` falls back to ``"government"``.

    Returns:
        An 8-tuple in the order of the UI outputs:
        1. contextual search result
        2. topic scores dict
        3. sentiment plot
        4. subjectivity plot
        5. word cloud
        6. frequency plot
        7. dispersion plot
        8. summary

        On failure, the first element carries the error message, the images
        are ``None``, and the summary is a short status string.
    """
    try:
        if Manifesto is None:
            return "No file uploaded", {}, None, None, None, None, None, "No file uploaded"
        if not Search or not Search.strip():  # also tolerate a None search term
            Search = "government"
        raw_party = Parsing(Manifesto)
        # Parsing reports failures as a string starting with "Error".
        if isinstance(raw_party, str) and raw_party.startswith("Error"):
            return raw_party, {}, None, None, None, None, None, "Parsing failed"
        text_Party = clean_text(raw_party)
        text_Party_processed = Preprocess(text_Party)

        # The LLM-based search and summary get the less-processed text for context.
        searChRes = get_contextual_search_result(Search, raw_party, groq_client)
        summary = generate_summary(raw_party)

        # NOTE(review): the processed text has stopwords (including negations
        # such as "not") removed, which may skew polarity; confirm intended.
        df_dummy = _sentiment_frame(text_Party_processed)
        sentiment_plot = safe_plot(lambda: df_dummy['Polarity_Label'].value_counts().plot(kind='bar', color="#FF9F45", title='Sentiment Analysis'))
        subjectivity_plot = safe_plot(lambda: df_dummy['Subjectivity_Label'].value_counts().plot(kind='bar', color="#B667F1", title='Subjectivity Analysis'))
        freq_plot = fDistancePlot(text_Party_processed)
        dispersion_plot = DispersionPlot(text_Party_processed)
        wordcloud = word_cloud_generator(Manifesto, text_Party_processed)  # file object picks the mask
        fdist_Party = fDistance(text_Party_processed)
        return searChRes, fdist_Party, sentiment_plot, subjectivity_plot, wordcloud, freq_plot, dispersion_plot, summary
    except Exception as e:
        error_msg = f"Critical error in analysis function: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return error_msg, {}, None, None, None, None, None, "Analysis failed"
# --- Gradio Interface ---
# Blocks layout: inputs on top, results in tabs below.
with gr.Blocks(title='Manifesto Analysis') as demo:
    gr.Markdown("# Manifesto Analysis")
    # Input Section
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload Manifesto PDF", file_types=[".pdf"])
        with gr.Column(scale=1):
            search_input = gr.Textbox(label="Search Term", placeholder="Enter a term to search in the manifesto")
            submit_btn = gr.Button("Analyze Manifesto", variant='primary')
    # Output Section using Tabs
    with gr.Tabs():
        with gr.TabItem("Summary"):
            summary_output = gr.Textbox(label='AI-Generated Summary', lines=10, interactive=False)
        with gr.TabItem("Search Results"):
            search_output = gr.Textbox(label='Context Based Search Results', lines=15, interactive=False, max_lines=20)
        with gr.TabItem("Key Topics"):
            # Scores come from fDistance (word frequency + TF-IDF), not from the LLM.
            topics_output = gr.Label(label="Most Relevant Topics (Frequency + TF-IDF)", num_top_classes=10)
        with gr.TabItem("Visualizations"):
            with gr.Row():  # Row 1: Sentiment & Subjectivity
                with gr.Column():
                    sentiment_output = gr.Image(label='Sentiment Analysis', interactive=False, height=400)
                with gr.Column():
                    subjectivity_output = gr.Image(label='Subjectivity Analysis', interactive=False, height=400)
            with gr.Row():  # Row 2: Word Cloud & Frequency
                with gr.Column():
                    wordcloud_output = gr.Image(label='Word Cloud', interactive=False, height=400)
                with gr.Column():
                    freq_output = gr.Image(label='Frequency Distribution', interactive=False, height=400)
            with gr.Row():  # Row 3: Dispersion Plot (full width)
                with gr.Column():
                    dispersion_output = gr.Image(label='Dispersion Plot', interactive=False, height=400)

    # Single source of truth for the output order; it must match the tuple
    # returned by analysis(). Shared by the button and the examples.
    analysis_outputs = [
        search_output,
        topics_output,
        sentiment_output,
        subjectivity_output,
        wordcloud_output,
        freq_output,
        dispersion_output,
        summary_output,
    ]
    submit_btn.click(
        fn=analysis,
        inputs=[file_input, search_input],
        outputs=analysis_outputs,
        concurrency_limit=1,  # one analysis at a time
    )
    gr.Examples(
        examples=[
            ["Example/AAP_Manifesto_2019.pdf", "government"],
            ["Example/Bjp_Manifesto_2019.pdf", "environment"],
            ["Example/Congress_Manifesto_2019.pdf", "safety"]
        ],
        inputs=[file_input, search_input],
        outputs=analysis_outputs,
        fn=analysis,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(debug=True, share=False, show_error=True)