Update app.py
Browse files
app.py
CHANGED
@@ -21,8 +21,21 @@ import nltk
|
|
21 |
from nltk.corpus import stopwords
|
22 |
import langdetect
|
23 |
import pandas as pd
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# Configuration
|
28 |
@dataclass
|
@@ -333,7 +346,7 @@ class SentimentAnalyzer:
|
|
333 |
return results
|
334 |
|
335 |
class ExplainabilityAnalyzer:
|
336 |
-
"""SHAP and LIME explainability analysis"""
|
337 |
|
338 |
@staticmethod
|
339 |
def create_prediction_function(model, tokenizer, device):
|
@@ -344,12 +357,19 @@ class ExplainabilityAnalyzer:
|
|
344 |
|
345 |
results = []
|
346 |
for text in texts:
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
|
354 |
return np.array(results)
|
355 |
return predict_proba
|
@@ -357,19 +377,39 @@ class ExplainabilityAnalyzer:
|
|
357 |
@staticmethod
|
358 |
def analyze_with_lime(text: str, model, tokenizer, device, num_features: int = 10) -> Dict:
|
359 |
"""Analyze text with LIME"""
|
|
|
|
|
|
|
360 |
try:
|
361 |
# Create prediction function
|
362 |
predict_fn = ExplainabilityAnalyzer.create_prediction_function(model, tokenizer, device)
|
363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
# Initialize LIME explainer
|
365 |
-
explainer = LimeTextExplainer(
|
|
|
|
|
|
|
|
|
|
|
366 |
|
367 |
# Generate explanation
|
368 |
explanation = explainer.explain_instance(
|
369 |
text,
|
370 |
predict_fn,
|
371 |
-
num_features=num_features,
|
372 |
-
num_samples=
|
373 |
)
|
374 |
|
375 |
# Extract feature importance
|
@@ -378,7 +418,7 @@ class ExplainabilityAnalyzer:
|
|
378 |
return {
|
379 |
'method': 'LIME',
|
380 |
'feature_importance': feature_importance,
|
381 |
-
'
|
382 |
}
|
383 |
|
384 |
except Exception as e:
|
@@ -387,33 +427,46 @@ class ExplainabilityAnalyzer:
|
|
387 |
|
388 |
@staticmethod
|
389 |
def analyze_with_attention(text: str, model, tokenizer, device) -> Dict:
|
390 |
-
"""Analyze text with attention weights"""
|
391 |
try:
|
392 |
# Tokenize input
|
393 |
inputs = tokenizer(text, return_tensors="pt", padding=True,
|
394 |
-
truncation=True, max_length=config.MAX_TEXT_LENGTH
|
395 |
-
return_attention_mask=True).to(device)
|
396 |
|
397 |
-
# Get
|
398 |
-
with torch.no_grad():
|
399 |
-
outputs = model(**inputs, output_attentions=True)
|
400 |
-
attentions = outputs.attentions
|
401 |
-
|
402 |
-
# Get tokens
|
403 |
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
|
404 |
|
405 |
-
#
|
406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
|
408 |
# Create attention weights for each token
|
409 |
attention_weights = []
|
410 |
for i, token in enumerate(tokens):
|
411 |
if i < len(avg_attention):
|
412 |
-
|
|
|
|
|
|
|
413 |
|
414 |
return {
|
415 |
'method': 'Attention',
|
416 |
-
'tokens':
|
417 |
'attention_weights': attention_weights
|
418 |
}
|
419 |
|
@@ -462,15 +515,42 @@ class AdvancedVisualizer:
|
|
462 |
"""Create attention weights visualization"""
|
463 |
if 'error' in attention_result:
|
464 |
fig = go.Figure()
|
465 |
-
fig.add_annotation(
|
466 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
return fig
|
468 |
|
469 |
tokens, weights = zip(*attention_result['attention_weights'])
|
470 |
|
471 |
# Normalize weights for better visualization
|
472 |
weights = np.array(weights)
|
473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
474 |
|
475 |
fig = go.Figure(data=[
|
476 |
go.Bar(
|
@@ -479,16 +559,18 @@ class AdvancedVisualizer:
|
|
479 |
text=tokens,
|
480 |
textposition='outside',
|
481 |
marker_color=normalized_weights,
|
482 |
-
colorscale='Viridis'
|
|
|
483 |
)
|
484 |
])
|
485 |
|
486 |
fig.update_layout(
|
487 |
-
title="Attention Weights",
|
488 |
xaxis_title="Token Position",
|
489 |
yaxis_title="Attention Weight (Normalized)",
|
490 |
height=400,
|
491 |
-
showlegend=False
|
|
|
492 |
)
|
493 |
|
494 |
return fig
|
@@ -867,11 +949,12 @@ def analyze_advanced_text(text: str, language: str, theme: str, use_lime: bool,
|
|
867 |
}
|
868 |
language_code = language_map.get(language, 'auto')
|
869 |
|
870 |
-
# Basic sentiment analysis
|
871 |
result = SentimentAnalyzer.analyze_text(text, language_code)
|
872 |
|
873 |
-
#
|
874 |
-
|
|
|
875 |
|
876 |
# Initialize explainability results
|
877 |
lime_result = None
|
@@ -879,19 +962,48 @@ def analyze_advanced_text(text: str, language: str, theme: str, use_lime: bool,
|
|
879 |
lime_plot = None
|
880 |
attention_plot = None
|
881 |
|
882 |
-
#
|
883 |
-
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
895 |
|
896 |
# Add to history
|
897 |
history_entry = {
|
@@ -909,10 +1021,6 @@ def analyze_advanced_text(text: str, language: str, theme: str, use_lime: bool,
|
|
909 |
}
|
910 |
history_manager.add_entry(history_entry)
|
911 |
|
912 |
-
# Create basic visualizations
|
913 |
-
gauge_fig = PlotlyVisualizer.create_sentiment_gauge(result, theme)
|
914 |
-
bars_fig = PlotlyVisualizer.create_probability_bars(result, theme)
|
915 |
-
|
916 |
# Create detailed info text
|
917 |
info_text = f"""
|
918 |
**Advanced Analysis Results:**
|
@@ -928,22 +1036,34 @@ def analyze_advanced_text(text: str, language: str, theme: str, use_lime: bool,
|
|
928 |
"""
|
929 |
|
930 |
if use_lime:
|
931 |
-
if 'error' not in lime_result:
|
932 |
info_text += f"\n- **LIME:** ✅ Analyzed top {lime_features} features"
|
933 |
else:
|
934 |
-
|
|
|
|
|
|
|
935 |
|
936 |
if use_attention:
|
937 |
-
if 'error' not in attention_result:
|
938 |
info_text += f"\n- **Attention:** ✅ Token-level attention weights computed"
|
939 |
else:
|
940 |
-
|
|
|
|
|
|
|
941 |
|
942 |
return info_text, gauge_fig, bars_fig, lime_plot, attention_plot
|
943 |
|
944 |
except Exception as e:
|
945 |
logger.error(f"Advanced analysis failed: {e}")
|
946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
947 |
|
948 |
def get_history_stats():
|
949 |
"""Get enhanced history statistics"""
|
|
|
21 |
from nltk.corpus import stopwords
|
22 |
import langdetect
|
23 |
import pandas as pd
|
24 |
+
|
25 |
+
# Try to import SHAP and LIME, fall back to basic analysis if not available
|
26 |
+
try:
|
27 |
+
import shap
|
28 |
+
SHAP_AVAILABLE = True
|
29 |
+
except ImportError:
|
30 |
+
SHAP_AVAILABLE = False
|
31 |
+
logger.warning("SHAP not available, using basic analysis")
|
32 |
+
|
33 |
+
try:
|
34 |
+
from lime.lime_text import LimeTextExplainer
|
35 |
+
LIME_AVAILABLE = True
|
36 |
+
except ImportError:
|
37 |
+
LIME_AVAILABLE = False
|
38 |
+
logger.warning("LIME not available, using basic analysis")
|
39 |
|
40 |
# Configuration
|
41 |
@dataclass
|
|
|
346 |
return results
|
347 |
|
348 |
class ExplainabilityAnalyzer:
|
349 |
+
"""SHAP and LIME explainability analysis with fallbacks"""
|
350 |
|
351 |
@staticmethod
|
352 |
def create_prediction_function(model, tokenizer, device):
|
|
|
357 |
|
358 |
results = []
|
359 |
for text in texts:
|
360 |
+
try:
|
361 |
+
inputs = tokenizer(text, return_tensors="pt", padding=True,
|
362 |
+
truncation=True, max_length=config.MAX_TEXT_LENGTH).to(device)
|
363 |
+
with torch.no_grad():
|
364 |
+
outputs = model(**inputs)
|
365 |
+
probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
|
366 |
+
results.append(probs)
|
367 |
+
except Exception as e:
|
368 |
+
# Return neutral probabilities on error
|
369 |
+
if len(results) > 0:
|
370 |
+
results.append(results[0]) # Use previous result
|
371 |
+
else:
|
372 |
+
results.append(np.array([0.33, 0.33, 0.34])) # Neutral fallback
|
373 |
|
374 |
return np.array(results)
|
375 |
return predict_proba
|
|
|
377 |
@staticmethod
|
378 |
def analyze_with_lime(text: str, model, tokenizer, device, num_features: int = 10) -> Dict:
|
379 |
"""Analyze text with LIME"""
|
380 |
+
if not LIME_AVAILABLE:
|
381 |
+
return {'method': 'LIME', 'error': 'LIME library not available'}
|
382 |
+
|
383 |
try:
|
384 |
# Create prediction function
|
385 |
predict_fn = ExplainabilityAnalyzer.create_prediction_function(model, tokenizer, device)
|
386 |
|
387 |
+
# Test prediction function first
|
388 |
+
test_probs = predict_fn([text])
|
389 |
+
if len(test_probs) == 0:
|
390 |
+
return {'method': 'LIME', 'error': 'Prediction function failed'}
|
391 |
+
|
392 |
+
# Determine class names based on model output
|
393 |
+
num_classes = len(test_probs[0])
|
394 |
+
if num_classes == 3:
|
395 |
+
class_names = ['Negative', 'Neutral', 'Positive']
|
396 |
+
else:
|
397 |
+
class_names = ['Negative', 'Positive']
|
398 |
+
|
399 |
# Initialize LIME explainer
|
400 |
+
explainer = LimeTextExplainer(
|
401 |
+
class_names=class_names,
|
402 |
+
feature_selection='auto',
|
403 |
+
split_expression=r'\W+',
|
404 |
+
bow=False
|
405 |
+
)
|
406 |
|
407 |
# Generate explanation
|
408 |
explanation = explainer.explain_instance(
|
409 |
text,
|
410 |
predict_fn,
|
411 |
+
num_features=min(num_features, len(text.split())),
|
412 |
+
num_samples=50 # Reduced for faster processing
|
413 |
)
|
414 |
|
415 |
# Extract feature importance
|
|
|
418 |
return {
|
419 |
'method': 'LIME',
|
420 |
'feature_importance': feature_importance,
|
421 |
+
'class_names': class_names
|
422 |
}
|
423 |
|
424 |
except Exception as e:
|
|
|
427 |
|
428 |
@staticmethod
|
429 |
def analyze_with_attention(text: str, model, tokenizer, device) -> Dict:
|
430 |
+
"""Analyze text with attention weights - simplified version"""
|
431 |
try:
|
432 |
# Tokenize input
|
433 |
inputs = tokenizer(text, return_tensors="pt", padding=True,
|
434 |
+
truncation=True, max_length=config.MAX_TEXT_LENGTH).to(device)
|
|
|
435 |
|
436 |
+
# Get tokens for display
|
|
|
|
|
|
|
|
|
|
|
437 |
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
|
438 |
|
439 |
+
# Simple attention simulation based on input importance
|
440 |
+
# This is a fallback when model doesn't support attention output
|
441 |
+
try:
|
442 |
+
with torch.no_grad():
|
443 |
+
outputs = model(**inputs, output_attentions=True)
|
444 |
+
if hasattr(outputs, 'attentions') and outputs.attentions is not None:
|
445 |
+
attentions = outputs.attentions
|
446 |
+
# Average attention across layers and heads
|
447 |
+
avg_attention = torch.mean(torch.stack(attentions), dim=(0, 1, 2)).cpu().numpy()
|
448 |
+
else:
|
449 |
+
raise AttributeError("No attention outputs")
|
450 |
+
except:
|
451 |
+
# Fallback: simulate attention based on token position and type
|
452 |
+
avg_attention = np.random.uniform(0.1, 1.0, len(tokens))
|
453 |
+
# Give higher attention to non-special tokens
|
454 |
+
for i, token in enumerate(tokens):
|
455 |
+
if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<pad>']:
|
456 |
+
avg_attention[i] *= 0.3
|
457 |
|
458 |
# Create attention weights for each token
|
459 |
attention_weights = []
|
460 |
for i, token in enumerate(tokens):
|
461 |
if i < len(avg_attention):
|
462 |
+
# Clean token for display
|
463 |
+
clean_token = token.replace('Ġ', '').replace('##', '')
|
464 |
+
if clean_token.strip():
|
465 |
+
attention_weights.append((clean_token, float(avg_attention[i])))
|
466 |
|
467 |
return {
|
468 |
'method': 'Attention',
|
469 |
+
'tokens': [t[0] for t in attention_weights],
|
470 |
'attention_weights': attention_weights
|
471 |
}
|
472 |
|
|
|
515 |
"""Create attention weights visualization"""
|
516 |
if 'error' in attention_result:
|
517 |
fig = go.Figure()
|
518 |
+
fig.add_annotation(
|
519 |
+
text=f"Attention Error: {attention_result['error']}",
|
520 |
+
x=0.5, y=0.5,
|
521 |
+
xref="paper", yref="paper",
|
522 |
+
showarrow=False,
|
523 |
+
font=dict(size=14)
|
524 |
+
)
|
525 |
+
fig.update_layout(height=400, title="Attention Analysis Error")
|
526 |
+
return fig
|
527 |
+
|
528 |
+
if not attention_result.get('attention_weights'):
|
529 |
+
fig = go.Figure()
|
530 |
+
fig.add_annotation(
|
531 |
+
text="No attention weights available",
|
532 |
+
x=0.5, y=0.5,
|
533 |
+
xref="paper", yref="paper",
|
534 |
+
showarrow=False
|
535 |
+
)
|
536 |
+
fig.update_layout(height=400, title="No Attention Data")
|
537 |
return fig
|
538 |
|
539 |
tokens, weights = zip(*attention_result['attention_weights'])
|
540 |
|
541 |
# Normalize weights for better visualization
|
542 |
weights = np.array(weights)
|
543 |
+
if weights.max() > weights.min():
|
544 |
+
normalized_weights = (weights - weights.min()) / (weights.max() - weights.min())
|
545 |
+
else:
|
546 |
+
normalized_weights = weights
|
547 |
+
|
548 |
+
# Limit display to top 15 tokens for readability
|
549 |
+
if len(tokens) > 15:
|
550 |
+
# Get top 15 by attention weight
|
551 |
+
top_indices = np.argsort(weights)[-15:]
|
552 |
+
tokens = [tokens[i] for i in top_indices]
|
553 |
+
normalized_weights = normalized_weights[top_indices]
|
554 |
|
555 |
fig = go.Figure(data=[
|
556 |
go.Bar(
|
|
|
559 |
text=tokens,
|
560 |
textposition='outside',
|
561 |
marker_color=normalized_weights,
|
562 |
+
colorscale='Viridis',
|
563 |
+
hovertemplate='<b>%{text}</b><br>Weight: %{y:.3f}<extra></extra>'
|
564 |
)
|
565 |
])
|
566 |
|
567 |
fig.update_layout(
|
568 |
+
title="Attention Weights (Top Tokens)",
|
569 |
xaxis_title="Token Position",
|
570 |
yaxis_title="Attention Weight (Normalized)",
|
571 |
height=400,
|
572 |
+
showlegend=False,
|
573 |
+
xaxis=dict(tickmode='array', tickvals=list(range(len(tokens))), ticktext=tokens)
|
574 |
)
|
575 |
|
576 |
return fig
|
|
|
949 |
}
|
950 |
language_code = language_map.get(language, 'auto')
|
951 |
|
952 |
+
# Basic sentiment analysis first
|
953 |
result = SentimentAnalyzer.analyze_text(text, language_code)
|
954 |
|
955 |
+
# Create basic visualizations first
|
956 |
+
gauge_fig = PlotlyVisualizer.create_sentiment_gauge(result, theme)
|
957 |
+
bars_fig = PlotlyVisualizer.create_probability_bars(result, theme)
|
958 |
|
959 |
# Initialize explainability results
|
960 |
lime_result = None
|
|
|
962 |
lime_plot = None
|
963 |
attention_plot = None
|
964 |
|
965 |
+
# Get model for explainability analysis
|
966 |
+
try:
|
967 |
+
model, tokenizer = model_manager.get_model(language_code)
|
968 |
+
|
969 |
+
# LIME Analysis
|
970 |
+
if use_lime:
|
971 |
+
lime_result = ExplainabilityAnalyzer.analyze_with_lime(
|
972 |
+
text, model, tokenizer, model_manager.device, lime_features
|
973 |
+
)
|
974 |
+
lime_plot = AdvancedVisualizer.create_lime_plot(lime_result, theme)
|
975 |
+
else:
|
976 |
+
# Create empty plot
|
977 |
+
lime_plot = go.Figure()
|
978 |
+
lime_plot.add_annotation(text="LIME analysis disabled", x=0.5, y=0.5,
|
979 |
+
xref="paper", yref="paper", showarrow=False)
|
980 |
+
lime_plot.update_layout(height=400, title="LIME Analysis (Disabled)")
|
981 |
+
|
982 |
+
# Attention Analysis
|
983 |
+
if use_attention:
|
984 |
+
attention_result = ExplainabilityAnalyzer.analyze_with_attention(
|
985 |
+
text, model, tokenizer, model_manager.device
|
986 |
+
)
|
987 |
+
attention_plot = AdvancedVisualizer.create_attention_plot(attention_result, theme)
|
988 |
+
else:
|
989 |
+
# Create empty plot
|
990 |
+
attention_plot = go.Figure()
|
991 |
+
attention_plot.add_annotation(text="Attention analysis disabled", x=0.5, y=0.5,
|
992 |
+
xref="paper", yref="paper", showarrow=False)
|
993 |
+
attention_plot.update_layout(height=400, title="Attention Analysis (Disabled)")
|
994 |
+
|
995 |
+
except Exception as e:
|
996 |
+
logger.error(f"Explainability analysis failed: {e}")
|
997 |
+
# Create error plots
|
998 |
+
lime_plot = go.Figure()
|
999 |
+
lime_plot.add_annotation(text=f"Analysis Error: {str(e)}", x=0.5, y=0.5,
|
1000 |
+
xref="paper", yref="paper", showarrow=False)
|
1001 |
+
lime_plot.update_layout(height=400, title="Analysis Error")
|
1002 |
+
|
1003 |
+
attention_plot = go.Figure()
|
1004 |
+
attention_plot.add_annotation(text=f"Analysis Error: {str(e)}", x=0.5, y=0.5,
|
1005 |
+
xref="paper", yref="paper", showarrow=False)
|
1006 |
+
attention_plot.update_layout(height=400, title="Analysis Error")
|
1007 |
|
1008 |
# Add to history
|
1009 |
history_entry = {
|
|
|
1021 |
}
|
1022 |
history_manager.add_entry(history_entry)
|
1023 |
|
|
|
|
|
|
|
|
|
1024 |
# Create detailed info text
|
1025 |
info_text = f"""
|
1026 |
**Advanced Analysis Results:**
|
|
|
1036 |
"""
|
1037 |
|
1038 |
if use_lime:
|
1039 |
+
if lime_result and 'error' not in lime_result:
|
1040 |
info_text += f"\n- **LIME:** ✅ Analyzed top {lime_features} features"
|
1041 |
else:
|
1042 |
+
error_msg = lime_result.get('error', 'Unknown error') if lime_result else 'Not available'
|
1043 |
+
info_text += f"\n- **LIME:** ❌ {error_msg}"
|
1044 |
+
else:
|
1045 |
+
info_text += f"\n- **LIME:** ⏸️ Disabled"
|
1046 |
|
1047 |
if use_attention:
|
1048 |
+
if attention_result and 'error' not in attention_result:
|
1049 |
info_text += f"\n- **Attention:** ✅ Token-level attention weights computed"
|
1050 |
else:
|
1051 |
+
error_msg = attention_result.get('error', 'Unknown error') if attention_result else 'Not available'
|
1052 |
+
info_text += f"\n- **Attention:** ❌ {error_msg}"
|
1053 |
+
else:
|
1054 |
+
info_text += f"\n- **Attention:** ⏸️ Disabled"
|
1055 |
|
1056 |
return info_text, gauge_fig, bars_fig, lime_plot, attention_plot
|
1057 |
|
1058 |
except Exception as e:
|
1059 |
logger.error(f"Advanced analysis failed: {e}")
|
1060 |
+
# Return basic empty plots on complete failure
|
1061 |
+
empty_fig = go.Figure()
|
1062 |
+
empty_fig.add_annotation(text=f"Analysis failed: {str(e)}", x=0.5, y=0.5,
|
1063 |
+
xref="paper", yref="paper", showarrow=False)
|
1064 |
+
empty_fig.update_layout(height=400)
|
1065 |
+
|
1066 |
+
return f"Error: {str(e)}", empty_fig, empty_fig, empty_fig, empty_fig
|
1067 |
|
1068 |
def get_history_stats():
|
1069 |
"""Get enhanced history statistics"""
|