Update app.py
app.py
CHANGED
@@ -23,7 +23,6 @@ except ImportError:
 MODEL_PATH = os.getenv("MODEL_PATH", "speakleash/sojka2")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
-MAX_SEQ_LENGTH = 512
 # Thresholds are now hardcoded
 THRESHOLDS = {
     "self-harm": 0.5,
@@ -61,9 +60,6 @@ def load_model_and_tokenizer(model_path: str, device: str) -> Tuple[AutoModelFor
     is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
     if PeftModel and is_peft:
         logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
-        # Logic to load PEFT model (kept for robustness)
-        # This part assumes adapter_config.json contains base_model_name_or_path
-        # Simplified for clarity, ensure your PEFT config is correct if you use it.
         try:
             from peft import PeftConfig
             peft_config = PeftConfig.from_pretrained(model_path)
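For readers following the PEFT branch above: the visible context only shows the adapter detection and the start of the try block. A minimal sketch of what such a branch typically does, assuming adapter_config.json carries base_model_name_or_path and that the classification head has the five labels listed earlier; names and keyword arguments here are illustrative, not necessarily what app.py uses:

    # Hedged sketch: load the base model named in the adapter config, then attach the adapter.
    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    def load_with_adapter(model_path: str, device: str):
        peft_config = PeftConfig.from_pretrained(model_path)
        base_name = peft_config.base_model_name_or_path  # assumed to be set in adapter_config.json
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base_name,
            num_labels=5,  # matches the five LABELS defined at the top of app.py
            problem_type="multi_label_classification",
        )
        model = PeftModel.from_pretrained(base_model, model_path)  # attach adapter weights
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model.to(device).eval(), tokenizer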
@@ -107,7 +103,9 @@ def predict(text: str) -> Dict[str, Any]:
 
     with torch.no_grad():
         outputs = model(**inputs)
-
+    # Using sigmoid for multi-label classification outputs
+    probabilities = torch.sigmoid(outputs.logits)
+    predicted_values = probabilities.cpu().numpy()[0]
 
     clipped_values = np.clip(predicted_values, 0.0, 1.0)
     return {label: float(score) for label, score in zip(LABELS, clipped_values)}
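The three added lines convert raw logits to per-label probabilities with an element-wise sigmoid, the standard choice for multi-label classification: each category is scored independently rather than normalized with softmax. A minimal sketch of the full prediction path they sit in, assuming tokenizer, model, DEVICE and LABELS are the module-level objects shown above; the truncation length is an assumption, since MAX_SEQ_LENGTH was removed in the first hunk:

    def predict_sketch(text: str) -> dict:
        # Tokenize and move tensors to the configured device (512 is an assumed cap).
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        # Independent sigmoid per label: multi-label scores, not a softmax distribution.
        probabilities = torch.sigmoid(outputs.logits)
        predicted_values = probabilities.cpu().numpy()[0]
        clipped_values = np.clip(predicted_values, 0.0, 1.0)
        return {label: float(score) for label, score in zip(LABELS, clipped_values)}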
@@ -132,7 +130,6 @@ def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
     if not unsafe_categories:
         verdict = "✅ Komunikat jest bezpieczny."
     else:
-        # Sort by score to show the most likely category first
         highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
         verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści w kategorii: {highest_unsafe_category.upper()}"
 
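The hunk above shows only the verdict step; how unsafe_categories is built is outside the visible context. A hedged sketch of the likely logic, assuming each score is simply compared against the hardcoded THRESHOLDS map from the first hunk:

    def build_verdict(scores: dict) -> str:
        # Keep only categories whose score reaches their threshold (comparison operator assumed).
        unsafe_categories = {
            label: score
            for label, score in scores.items()
            if score >= THRESHOLDS.get(label, 0.5)
        }
        if not unsafe_categories:
            return "✅ Komunikat jest bezpieczny."
        # Report the highest-scoring flagged category, as in the diff above.
        top = max(unsafe_categories, key=unsafe_categories.get)
        return f"⚠️ Wykryto potencjalnie szkodliwe treści w kategorii: {top.upper()}"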
@@ -140,8 +137,8 @@ def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
 
 # --- Gradio Interface ---
 
-# Custom theme inspired by the provided image
-theme = gr.themes.Default
+# Custom theme inspired by the provided image - THIS IS THE CORRECTED LINE
+theme = gr.themes.Default(
     primary_hue=gr.themes.colors.blue,
     secondary_hue=gr.themes.colors.indigo,
     neutral_hue=gr.themes.colors.slate,
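The old line theme = gr.themes.Default referenced the theme class without calling it, so the keyword arguments on the following lines were a syntax error; the new side adds the missing opening parenthesis. The corrected construct looks like the sketch below; the closing parenthesis and any further keyword arguments are outside the visible context:

    # Sketch of the fixed theme construction; only the hues shown in the diff are known.
    theme = gr.themes.Default(
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.indigo,
        neutral_hue=gr.themes.colors.slate,
    )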
@@ -168,7 +165,7 @@ with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important;
             <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
                 <a href="#" style="text-decoration: none; color: inherit;">O projekcie</a>
                 <a href="#" style="text-decoration: none; color: inherit;">Opis kategorii</a>
-                <button class="gr-button gr-button-primary gr-button-lg"
+                <button id="test-sojka-btn" class="gr-button gr-button-primary gr-button-lg"
                     style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
                     Testuj Sójkę
                 </button>
@@ -200,10 +197,7 @@ with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important;
                 label="Wprowadź tekst do analizy",
                 placeholder="Tutaj wpisz tekst..."
             )
-            submit_btn = gr.Button("
-
-            # Use a more descriptive name for the submit button that matches its function
-            submit_btn.value = "Analizuj tekst"
+            submit_btn = gr.Button("Analizuj tekst", variant="primary")
 
             output_verdict = gr.Label(label="Wynik analizy", value="Czekam na tekst do analizy...")
             output_scores = gr.Label(label="Szczegółowe wyniki", visible=False)
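The removed lines show a truncated gr.Button(" call followed by an attempt to set submit_btn.value afterwards; the replacement builds the button with its label and variant in a single call. A minimal sketch of how such a button is usually wired to the prediction handler; the component names passed to click() are assumptions, not taken from app.py:

    # Create the button with its final label, then connect it to the handler.
    submit_btn = gr.Button("Analizuj tekst", variant="primary")
    submit_btn.click(
        fn=gradio_predict,                       # handler shown in the hunks above
        inputs=input_text,                       # assumed name of the input Textbox
        outputs=[output_verdict, output_scores],
    )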