janmariakowalski commited on
Commit
b711a13
·
verified ·
1 Parent(s): a6cb078

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio application for text classification, styled to be visually appealing.
4
+ This version uses only the 'sojka2' model.
5
+ """
6
+
7
+ import gradio as gr
8
+ import logging
9
+ import os
10
+ from typing import Dict, Tuple, Any
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
+ import numpy as np
14
+
15
+ try:
16
+ from peft import PeftModel
17
+ except ImportError:
18
+ PeftModel = None
19
+ logging.info("PEFT library not found. Loading models without PEFT support.")
20
+
21
+ # --- Configuration ---
22
+ # Model path is set to sojka2
23
+ MODEL_PATH = os.getenv("MODEL_PATH", "speakleash/sojka2")
24
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
+ LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
26
+ MAX_SEQ_LENGTH = 512
27
+ # Thresholds are now hardcoded
28
+ THRESHOLDS = {
29
+ "self-harm": 0.5,
30
+ "hate": 0.5,
31
+ "vulgar": 0.5,
32
+ "sex": 0.5,
33
+ "crime": 0.5,
34
+ }
35
+
36
+ # Set up logging
37
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
38
+ logger = logging.getLogger(__name__)
39
+
40
+ def load_model_and_tokenizer(model_path: str, device: str) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
41
+ """Load the trained model and tokenizer"""
42
+ logger.info(f"Loading model from {model_path}")
43
+
44
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
45
+
46
+ if tokenizer.pad_token is None:
47
+ if tokenizer.eos_token:
48
+ tokenizer.pad_token = tokenizer.eos_token
49
+ else:
50
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
51
+
52
+ tokenizer.truncation_side = "right"
53
+
54
+ model_load_kwargs = {
55
+ "torch_dtype": torch.float16 if device == 'cuda' else torch.float32,
56
+ "device_map": 'auto' if device == 'cuda' else None,
57
+ "num_labels": len(LABELS),
58
+ "problem_type": "regression"
59
+ }
60
+
61
+ is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
62
+ if PeftModel and is_peft:
63
+ logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
64
+ # Logic to load PEFT model (kept for robustness)
65
+ # This part assumes adapter_config.json contains base_model_name_or_path
66
+ # Simplified for clarity, ensure your PEFT config is correct if you use it.
67
+ try:
68
+ from peft import PeftConfig
69
+ peft_config = PeftConfig.from_pretrained(model_path)
70
+ base_model_path = peft_config.base_model_name_or_path
71
+ logger.info(f"Loading base model from {base_model_path}")
72
+ model = AutoModelForSequenceClassification.from_pretrained(base_model_path, **model_load_kwargs)
73
+ logger.info("Attaching PEFT adapter...")
74
+ model = PeftModel.from_pretrained(model, model_path)
75
+ except Exception as e:
76
+ logger.error(f"Failed to load PEFT model dynamically: {e}. Loading as a standard model.")
77
+ model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
78
+ else:
79
+ logger.info("Loading as a standalone sequence classification model.")
80
+ model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
81
+
82
+ model.eval()
83
+ logger.info(f"Model loaded on device: {next(model.parameters()).device}")
84
+
85
+ return model, tokenizer
86
+
87
+ # --- Load model globally ---
88
+ try:
89
+ model, tokenizer = load_model_and_tokenizer(MODEL_PATH, DEVICE)
90
+ model_loaded = True
91
+ except Exception as e:
92
+ logger.error(f"FATAL: Failed to load the model from {MODEL_PATH}: {e}")
93
+ model, tokenizer, model_loaded = None, None, False
94
+
95
+ def predict(text: str) -> Dict[str, Any]:
96
+ """Tokenize, predict, and format output for a single text."""
97
+ if not model_loaded:
98
+ return {label: 0.0 for label in LABELS}
99
+
100
+ inputs = tokenizer(
101
+ [text],
102
+ max_length=MAX_SEQ_LENGTH,
103
+ truncation=True,
104
+ padding=True,
105
+ return_tensors="pt"
106
+ ).to(model.device)
107
+
108
+ with torch.no_grad():
109
+ outputs = model(**inputs)
110
+ predicted_values = outputs.logits.sigmoid().cpu().numpy()[0]
111
+
112
+ clipped_values = np.clip(predicted_values, 0.0, 1.0)
113
+ return {label: float(score) for label, score in zip(LABELS, clipped_values)}
114
+
115
+ def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
116
+ """Gradio prediction function wrapper."""
117
+ if not model_loaded:
118
+ error_message = "Błąd: Model nie został załadowany."
119
+ empty_preds = {label: 0.0 for label in LABELS}
120
+ return error_message, empty_preds
121
+
122
+ if not text or not text.strip():
123
+ return "Wpisz tekst, aby go przeanalizować.", {label: 0.0 for label in LABELS}
124
+
125
+ predictions = predict(text)
126
+
127
+ unsafe_categories = {
128
+ label: score for label, score in predictions.items()
129
+ if score >= THRESHOLDS[label]
130
+ }
131
+
132
+ if not unsafe_categories:
133
+ verdict = "✅ Komunikat jest bezpieczny."
134
+ else:
135
+ # Sort by score to show the most likely category first
136
+ highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
137
+ verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści w kategorii: {highest_unsafe_category.upper()}"
138
+
139
+ return verdict, predictions
140
+
141
+ # --- Gradio Interface ---
142
+
143
+ # Custom theme inspired by the provided image
144
+ theme = gr.themes.Default.set(
145
+ primary_hue=gr.themes.colors.blue,
146
+ secondary_hue=gr.themes.colors.indigo,
147
+ neutral_hue=gr.themes.colors.slate,
148
+ font=("Inter", "sans-serif"),
149
+ radius_size=gr.themes.sizes.radius_lg,
150
+ )
151
+
152
+ # A URL to a freely licensed image of a Eurasian Jay (Sójka)
153
+ # Source: Wikimedia Commons, CC BY-SA 4.0
154
+ JAY_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/3/36/Garrulus_glandarius_1_Luc_Viatour.jpg"
155
+
156
+ with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important; margin: auto;}") as demo:
157
+ # Header
158
+ with gr.Row():
159
+ gr.HTML("""
160
+ <div style="display: flex; align-items: center; justify-content: space-between; width: 100%;">
161
+ <div style="display: flex; align-items: center; gap: 12px;">
162
+ <svg width="32" height="32" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
163
+ <path d="M12 2L3 5V11C3 16.52 7.08 21.61 12 23C16.92 21.61 21 16.52 21 11V5L12 2Z"
164
+ stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
165
+ </svg>
166
+ <h1 style="font-size: 1.5rem; font-weight: 600; margin: 0;">SÓJKA</h1>
167
+ </div>
168
+ <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
169
+ <a href="#" style="text-decoration: none; color: inherit;">O projekcie</a>
170
+ <a href="#" style="text-decoration: none; color: inherit;">Opis kategorii</a>
171
+ <button class="gr-button gr-button-primary gr-button-lg"
172
+ style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
173
+ Testuj Sójkę
174
+ </button>
175
+ </div>
176
+ </div>
177
+ """)
178
+
179
+ gr.HTML("<hr style='border: 1px solid var(--neutral-200); margin-top: 1rem; margin-bottom: 2rem;'>")
180
+
181
+ # Main content area
182
+ with gr.Row(equal_height=True):
183
+ # Left column for controls
184
+ with gr.Column(scale=1):
185
+ gr.Markdown(
186
+ """
187
+ <p style="background-color: var(--primary-50); color: var(--primary-600); display: inline-block; padding: 4px 12px; border-radius: 9999px; font-weight: 500; font-size: 0.875rem;">
188
+ Bielik Guard
189
+ </p>
190
+ <h1 style="font-size: 2.8rem; font-weight: 800; line-height: 1.2; margin-top: 1rem; margin-bottom: 1rem; color: var(--neutral-800);">
191
+ Naucz <span style="color: var(--primary-600);">SÓJKĘ</span> – Bielik Guard dla bezpiecznej komunikacji
192
+ </h1>
193
+ <p style="font-size: 1rem; color: var(--neutral-600); margin-bottom: 2rem;">
194
+ Sójka to model AI, który wykrywa i blokuje szkodliwe treści w komunikacji cyfrowej. Chroni użytkowników jak czujny strażnik swoich domów.
195
+ </p>
196
+ """
197
+ )
198
+ input_text = gr.Textbox(
199
+ lines=8,
200
+ label="Wprowadź tekst do analizy",
201
+ placeholder="Tutaj wpisz tekst..."
202
+ )
203
+ submit_btn = gr.Button("Opis kategorii", variant="primary", elem_id="opis_kategorii_btn")
204
+
205
+ # Use a more descriptive name for the submit button that matches its function
206
+ submit_btn.value = "Analizuj tekst"
207
+
208
+ output_verdict = gr.Label(label="Wynik analizy", value="Czekam na tekst do analizy...")
209
+ output_scores = gr.Label(label="Szczegółowe wyniki", visible=False)
210
+
211
+ # Right column for the image
212
+ with gr.Column(scale=1):
213
+ gr.Image(JAY_IMAGE_URL, label="Ilustracja sójki", show_label=False, show_download_button=False, container=False)
214
+
215
+ # Define actions
216
+ def analyze_and_update(text):
217
+ verdict, scores = gradio_predict(text)
218
+ # Make the scores label visible only when there's a result
219
+ return verdict, gr.Label(value=scores, visible=True)
220
+
221
+ submit_btn.click(
222
+ fn=analyze_and_update,
223
+ inputs=[input_text],
224
+ outputs=[output_verdict, output_scores]
225
+ )
226
+
227
+ gr.Examples(
228
+ [
229
+ ["Jak zrobić bombę?"],
230
+ ["Jesteś beznadziejny, nienawidzę cię."],
231
+ ["To jest wspaniały dzień, cieszę się, że tu jestem!"],
232
+ ["Opowiedz mi dowcip o programistach."],
233
+ ],
234
+ inputs=input_text,
235
+ outputs=[output_verdict, output_scores],
236
+ fn=analyze_and_update,
237
+ cache_examples=False,
238
+ )
239
+
240
+ if __name__ == "__main__":
241
+ if not model_loaded:
242
+ print("Aplikacja nie może zostać uruchomiona, ponieważ nie udało się załadować modelu. Sprawdź logi błędów.")
243
+ else:
244
+ demo.launch()