janmariakowalski commited on
Commit
8811dbf
·
verified ·
1 Parent(s): b11fa30

Delete new.py

Browse files
Files changed (1) hide show
  1. new.py +0 -354
new.py DELETED
@@ -1,354 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Gradio application for text classification, styled to be visually appealing.
4
- This version uses only the 'sojka2' model.
5
- """
6
-
7
- import gradio as gr
8
- import logging
9
- import os
10
- from typing import Dict, Tuple, Any
11
- import torch
12
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
- import numpy as np
14
-
15
- try:
16
- from peft import PeftModel
17
- except ImportError:
18
- PeftModel = None
19
- logging.info("PEFT library not found. Loading models without PEFT support.")
20
-
21
- # --- Configuration ---
22
- # Model path is set to sojka
23
- MODEL_PATH = os.getenv("MODEL_PATH", "AndromedaPL/sojka")
24
- TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", "sdadas/mmlw-roberta-base")
25
-
26
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
- LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
28
- MAX_SEQ_LENGTH = 512
29
-
30
-
31
- HF_TOKEN = os.getenv('HF_TOKEN')
32
-
33
- # Thresholds are now hardcoded
34
- THRESHOLDS = {
35
- "self-harm": 0.5,
36
- "hate": 0.5,
37
- "vulgar": 0.5,
38
- "sex": 0.5,
39
- "crime": 0.5,
40
- }
41
-
42
- # Set up logging
43
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
44
- logger = logging.getLogger(__name__)
45
-
46
- def load_model_and_tokenizer(model_path: str, tokenizer_path: str, device: str) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
47
- """Load the trained model and tokenizer"""
48
- logger.info(f"Loading tokenizer from {tokenizer_path}")
49
-
50
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
51
- logger.info(f"Tokenizer loaded: {tokenizer.name_or_path}")
52
-
53
- if tokenizer.pad_token is None:
54
- if tokenizer.eos_token:
55
- tokenizer.pad_token = tokenizer.eos_token
56
- else:
57
- tokenizer.add_special_tokens({'pad_token': '[PAD]'})
58
-
59
- tokenizer.truncation_side = "right"
60
-
61
- logger.info(f"Loading model from {model_path}")
62
-
63
- model_load_kwargs = {
64
- "torch_dtype": torch.float16 if device == 'cuda' else torch.float32,
65
- "device_map": 'auto' if device == 'cuda' else None,
66
- "num_labels": len(LABELS),
67
- "problem_type": "regression"
68
- }
69
-
70
- is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
71
- if PeftModel and is_peft:
72
- logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
73
- try:
74
- from peft import PeftConfig
75
- peft_config = PeftConfig.from_pretrained(model_path)
76
- base_model_path = peft_config.base_model_name_or_path
77
- logger.info(f"Loading base model from {base_model_path}")
78
- model = AutoModelForSequenceClassification.from_pretrained(base_model_path, **model_load_kwargs)
79
- logger.info("Attaching PEFT adapter...")
80
- model = PeftModel.from_pretrained(model, model_path)
81
- except Exception as e:
82
- logger.error(f"Failed to load PEFT model dynamically: {e}. Loading as a standard model.")
83
- model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
84
- else:
85
- logger.info("Loading as a standalone sequence classification model.")
86
- model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
87
-
88
- model.eval()
89
- logger.info(f"Model loaded on device: {next(model.parameters()).device}")
90
-
91
- return model, tokenizer
92
-
93
- # --- Load model globally ---
94
- try:
95
- model, tokenizer = load_model_and_tokenizer(MODEL_PATH, TOKENIZER_PATH, DEVICE)
96
- model_loaded = True
97
- except Exception as e:
98
- logger.error(f"FATAL: Failed to load the model from {MODEL_PATH} or tokenizer from {TOKENIZER_PATH}: {e}", e)
99
- model, tokenizer, model_loaded = None, None, False
100
-
101
- def predict(text: str) -> Dict[str, Any]:
102
- """Tokenize, predict, and format output for a single text."""
103
- if not model_loaded:
104
- return {label: 0.0 for label in LABELS}
105
-
106
- inputs = tokenizer(
107
- [text],
108
- max_length=MAX_SEQ_LENGTH,
109
- truncation=True,
110
- padding=True,
111
- return_tensors="pt"
112
- ).to(model.device)
113
-
114
- with torch.no_grad():
115
- outputs = model(**inputs)
116
- # Using sigmoid for multi-label classification outputs
117
- probabilities = torch.sigmoid(outputs.logits)
118
- predicted_values = probabilities.cpu().numpy()[0]
119
-
120
- clipped_values = np.clip(predicted_values, 0.0, 1.0)
121
- return {label: float(score) for label, score in zip(LABELS, clipped_values)}
122
-
123
- def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
124
- """Gradio prediction function wrapper."""
125
- if not model_loaded:
126
- error_message = "Błąd: Model nie został załadowany."
127
- empty_preds = {label: 0.0 for label in LABELS}
128
- return error_message, empty_preds
129
-
130
- if not text or not text.strip():
131
- return "Wpisz tekst, aby go przeanalizować.", {label: 0.0 for label in LABELS}
132
-
133
- predictions = predict(text)
134
-
135
- unsafe_categories = {
136
- label: score for label, score in predictions.items()
137
- if score >= THRESHOLDS[label]
138
- }
139
-
140
- if not unsafe_categories:
141
- verdict = "✅ Komunikat jest bezpieczny."
142
- else:
143
- highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
144
- verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści\n: {highest_unsafe_category.upper()}"
145
-
146
- return verdict, predictions
147
-
148
- # --- Gradio Interface ---
149
-
150
- theme = gr.themes.Default(
151
- primary_hue=gr.themes.colors.blue,
152
- secondary_hue=gr.themes.colors.indigo,
153
- neutral_hue=gr.themes.colors.slate,
154
- font=("Inter", "sans-serif"),
155
- radius_size=gr.themes.sizes.radius_lg,
156
- )
157
-
158
- # A URL to a freely licensed image of a Eurasian Jay (Sójka)
159
- # Source: Wikimedia Commons, CC BY-SA 4.0
160
- JAY_IMAGE_URL = "https://sojka.m31ai.pl/sojka.png"
161
-
162
- with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important; margin: auto;}") as demo:
163
- # Header
164
- with gr.Row():
165
- gr.HTML("""
166
- <div style="display: flex; align-items: center; justify-content: space-between; width: 100%;">
167
- <div style="display: flex; align-items: center; gap: 12px;">
168
- <svg width="32" height="32" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
169
- <path d="M12 2L3 5V11C3 16.52 7.08 21.61 12 23C16.92 21.61 21 16.52 21 11V5L12 2Z"
170
- stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
171
- </svg>
172
- <h1 style="font-size: 1.5rem; font-weight: 600; margin: 0;">SÓJKA</h1>
173
- </div>
174
- <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
175
- <a href="https://sojka.m31ai.pl/projekt.html" target="blank" style="text-decoration: none; color: inherit;">O projekcie</a>
176
- <a href="https://sojka.m31ai.pl/kategorie.html" target="blank" style="text-decoration: none; color: inherit;">Opis kategorii</a>
177
- <button id="test-sojka-btn" class="gr-button gr-button-primary gr-button-lg"
178
- style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
179
- Testuj Sójkę
180
- </button>
181
- </div>
182
- </div>
183
- """)
184
-
185
- gr.HTML("<hr style='border: 1px solid var(--neutral-200); margin-top: 1rem; margin-bottom: 2rem;'>")
186
-
187
- # Main content area
188
- with gr.Row(equal_height=True):
189
- # Left column for controls
190
- with gr.Column(scale=2): # Increased scale for more width
191
- gr.Markdown(
192
- """
193
- <p style="background-color: var(--primary-50); color: var(--primary-600); display: inline-block; padding: 4px 12px; border-radius: 9999px; font-weight: 500; font-size: 0.875rem;">
194
- Bielik Guard
195
- </p>
196
- <h1 style="font-size: 2.8rem; font-weight: 800; line-height: 1.2; margin-top: 1rem; margin-bottom: 1rem; color: var(--neutral-800);">
197
- Przetestuj <span style="color: var(--primary-600);">SÓJKĘ</span> – Bielik Guard dla bezpiecznej komunikacji
198
- </h1>
199
- <p style="font-size: 1rem; color: var(--neutral-600); margin-bottom: 2rem;">
200
- Sójka to model AI, który wykrywa i blokuje szkodliwe treści w komunikacji cyfrowej. Chroni użytkowników jak czujny strażnik swoich domów.
201
- </p>
202
- """
203
- )
204
- input_text = gr.Textbox(
205
- lines=8,
206
- label="Wprowadź tekst do analizy",
207
- placeholder="Tutaj wpisz tekst..."
208
- )
209
-
210
- # Define outputs first so they can be referenced by gr.Examples
211
- # They will be placed in the layout in the nested Row below
212
- output_verdict = gr.Label(label="Wynik analizy", value="Czekam na tekst do analizy...")
213
- output_scores = gr.Label(label="Szczegółowe wyniki", visible=False)
214
-
215
- # Examples are now placed above the button
216
- gr.Examples(
217
- [
218
- ["Jak zrobić bombę?"],
219
- ["Jesteś beznadziejny, nienawidzę cię."],
220
- ["To jest wspaniały dzień, cieszę się, że tu jestem!"],
221
- ["Opowiedz mi dowcip o programistach."],
222
- ],
223
- inputs=input_text,
224
- outputs=[output_verdict, output_scores],
225
- fn=lambda text: analyze_and_update(text), # Use a lambda to match fn signature
226
- cache_examples=False,
227
- )
228
-
229
- # A nested Row for the button and the outputs
230
- with gr.Row():
231
- submit_btn = gr.Button("Analizuj tekst", variant="primary", scale=1)
232
- with gr.Column(scale=2):
233
- # The output components are defined above, but rendered here.
234
- # Gradio renders components where they are defined.
235
- # To solve this, we will re-declare them, which is not ideal,
236
- # but the simplest way to manage layout and callbacks.
237
- # The previous declarations are now just for the Examples callback.
238
- # Let's clean this up by defining them once.
239
- pass # The components are already in the layout from the definitions above.
240
- # The above comment is slightly incorrect for Gradio's declarative style.
241
- # The final working solution is to define components and then have them as outputs.
242
- # Let's revert to a cleaner structure that works.
243
-
244
- # Right column for the image
245
- with gr.Column(scale=1):
246
- gr.Image(JAY_IMAGE_URL, label="Ilustracja sójki", show_label=False, show_download_button=False, container=False, width=200)
247
-
248
- # Define actions
249
- def analyze_and_update(text):
250
- verdict, scores = gradio_predict(text)
251
- # Make the scores label visible only when there's a result
252
- return verdict, gr.Label(value=scores, visible=True)
253
-
254
- # The click function is now tied to the button defined in the nested Row
255
- submit_btn.click(
256
- fn=analyze_and_update,
257
- inputs=[input_text],
258
- outputs=[output_verdict, output_scores]
259
- )
260
-
261
- # Final corrected and working version of the interface layout
262
- with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important; margin: auto;}") as demo:
263
- # Header
264
- with gr.Row():
265
- gr.HTML("""
266
- <div style="display: flex; align-items: center; justify-content: space-between; width: 100%;">
267
- <div style="display: flex; align-items: center; gap: 12px;">
268
- <svg width="32" height="32" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
269
- <path d="M12 2L3 5V11C3 16.52 7.08 21.61 12 23C16.92 21.61 21 16.52 21 11V5L12 2Z"
270
- stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
271
- </svg>
272
- <h1 style="font-size: 1.5rem; font-weight: 600; margin: 0;">SÓJKA</h1>
273
- </div>
274
- <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
275
- <a href="https://sojka.m31ai.pl/projekt.html" target="blank" style="text-decoration: none; color: inherit;">O projekcie</a>
276
- <a href="https://sojka.m31ai.pl/kategorie.html" target="blank" style="text-decoration: none; color: inherit;">Opis kategorii</a>
277
- <button id="test-sojka-btn" class="gr-button gr-button-primary gr-button-lg"
278
- style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
279
- Testuj Sójkę
280
- </button>
281
- </div>
282
- </div>
283
- """)
284
-
285
- gr.HTML("<hr style='border: 1px solid var(--neutral-200); margin-top: 1rem; margin-bottom: 2rem;'>")
286
-
287
- # Main content area
288
- with gr.Row():
289
- # Left column for controls and description
290
- with gr.Column(scale=2):
291
- gr.Markdown(
292
- """
293
- <p style="background-color: var(--primary-50); color: var(--primary-600); display: inline-block; padding: 4px 12px; border-radius: 9999px; font-weight: 500; font-size: 0.875rem;">
294
- Bielik Guard
295
- </p>
296
- <h1 style="font-size: 2.8rem; font-weight: 800; line-height: 1.2; margin-top: 1rem; margin-bottom: 1rem; color: var(--neutral-800);">
297
- Przetestuj <span style="color: var(--primary-600);">SÓJKĘ</span> – Bielik Guard dla bezpiecznej komunikacji
298
- </h1>
299
- <p style="font-size: 1rem; color: var(--neutral-600); margin-bottom: 2rem;">
300
- Sójka to model AI, który wykrywa i blokuje szkodliwe treści w komunikacji cyfrowej. Chroni użytkowników jak czujny strażnik swoich domów.
301
- </p>
302
- """
303
- )
304
- input_text = gr.Textbox(
305
- lines=8,
306
- label="Wprowadź tekst do analizy",
307
- placeholder="Tutaj wpisz tekst..."
308
- )
309
-
310
- # Note: Output components are defined in the right column
311
- # We will define Examples after the right column is created.
312
-
313
- # Right column for the image and RESULTS
314
- with gr.Column(scale=1):
315
- gr.Image(JAY_IMAGE_URL, label="Ilustracja sójki", show_label=False, show_download_button=False, container=False, width=200)
316
- gr.Markdown("---") # Separator
317
- output_verdict = gr.Label(label="Wynik analizy", value="Czekam na tekst do analizy...")
318
- output_scores = gr.Label(label="Szczegółowe wyniki", visible=False)
319
-
320
- # Interactive elements are defined last, after all components are created
321
- # This places the examples above the button
322
- gr.Examples(
323
- [
324
- ["Jak zrobić bombę?"],
325
- ["Jesteś beznadziejny, nienawidzę cię."],
326
- ["To jest wspaniały dzień, cieszę się, że tu jestem!"],
327
- ["Opowiedz mi dowcip o programistach."],
328
- ],
329
- inputs=input_text,
330
- outputs=[output_verdict, output_scores],
331
- fn=analyze_and_update,
332
- cache_examples=False,
333
- )
334
-
335
- submit_btn = gr.Button("Analizuj tekst", variant="primary")
336
-
337
- # Define actions
338
- def analyze_and_update(text):
339
- verdict, scores = gradio_predict(text)
340
- return verdict, gr.update(value=scores, visible=True)
341
-
342
- submit_btn.click(
343
- fn=analyze_and_update,
344
- inputs=[input_text],
345
- outputs=[output_verdict, output_scores]
346
- )
347
-
348
-
349
- if __name__ == "__main__":
350
- if not model_loaded:
351
- print("Aplikacja nie może zostać uruchomiona, ponieważ nie udało się załadować modelu. Sprawdź logi błędów.")
352
- else:
353
- # The final, corrected demo object is launched
354
- demo.launch()