janmariakowalski commited on
Commit
5347f4c
·
verified ·
1 Parent(s): b5fff5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -23,7 +23,6 @@ except ImportError:
23
  MODEL_PATH = os.getenv("MODEL_PATH", "speakleash/sojka2")
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
  LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
26
- MAX_SEQ_LENGTH = 512
27
  # Thresholds are now hardcoded
28
  THRESHOLDS = {
29
  "self-harm": 0.5,
@@ -61,9 +60,6 @@ def load_model_and_tokenizer(model_path: str, device: str) -> Tuple[AutoModelFor
61
  is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
62
  if PeftModel and is_peft:
63
  logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
64
- # Logic to load PEFT model (kept for robustness)
65
- # This part assumes adapter_config.json contains base_model_name_or_path
66
- # Simplified for clarity, ensure your PEFT config is correct if you use it.
67
  try:
68
  from peft import PeftConfig
69
  peft_config = PeftConfig.from_pretrained(model_path)
@@ -107,7 +103,9 @@ def predict(text: str) -> Dict[str, Any]:
107
 
108
  with torch.no_grad():
109
  outputs = model(**inputs)
110
- predicted_values = outputs.logits.sigmoid().cpu().numpy()[0]
 
 
111
 
112
  clipped_values = np.clip(predicted_values, 0.0, 1.0)
113
  return {label: float(score) for label, score in zip(LABELS, clipped_values)}
@@ -132,7 +130,6 @@ def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
132
  if not unsafe_categories:
133
  verdict = "✅ Komunikat jest bezpieczny."
134
  else:
135
- # Sort by score to show the most likely category first
136
  highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
137
  verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści w kategorii: {highest_unsafe_category.upper()}"
138
 
@@ -140,8 +137,8 @@ def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
140
 
141
  # --- Gradio Interface ---
142
 
143
- # Custom theme inspired by the provided image
144
- theme = gr.themes.Default.set(
145
  primary_hue=gr.themes.colors.blue,
146
  secondary_hue=gr.themes.colors.indigo,
147
  neutral_hue=gr.themes.colors.slate,
@@ -168,7 +165,7 @@ with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important;
168
  <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
169
  <a href="#" style="text-decoration: none; color: inherit;">O projekcie</a>
170
  <a href="#" style="text-decoration: none; color: inherit;">Opis kategorii</a>
171
- <button class="gr-button gr-button-primary gr-button-lg"
172
  style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
173
  Testuj Sójkę
174
  </button>
@@ -200,10 +197,7 @@ with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important;
200
  label="Wprowadź tekst do analizy",
201
  placeholder="Tutaj wpisz tekst..."
202
  )
203
- submit_btn = gr.Button("Opis kategorii", variant="primary", elem_id="opis_kategorii_btn")
204
-
205
- # Use a more descriptive name for the submit button that matches its function
206
- submit_btn.value = "Analizuj tekst"
207
 
208
  output_verdict = gr.Label(label="Wynik analizy", value="Czekam na tekst do analizy...")
209
  output_scores = gr.Label(label="Szczegółowe wyniki", visible=False)
 
23
  MODEL_PATH = os.getenv("MODEL_PATH", "speakleash/sojka2")
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
  LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
 
26
  # Thresholds are now hardcoded
27
  THRESHOLDS = {
28
  "self-harm": 0.5,
 
60
  is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
61
  if PeftModel and is_peft:
62
  logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
 
 
 
63
  try:
64
  from peft import PeftConfig
65
  peft_config = PeftConfig.from_pretrained(model_path)
 
103
 
104
  with torch.no_grad():
105
  outputs = model(**inputs)
106
+ # Using sigmoid for multi-label classification outputs
107
+ probabilities = torch.sigmoid(outputs.logits)
108
+ predicted_values = probabilities.cpu().numpy()[0]
109
 
110
  clipped_values = np.clip(predicted_values, 0.0, 1.0)
111
  return {label: float(score) for label, score in zip(LABELS, clipped_values)}
 
130
  if not unsafe_categories:
131
  verdict = "✅ Komunikat jest bezpieczny."
132
  else:
 
133
  highest_unsafe_category = max(unsafe_categories, key=unsafe_categories.get)
134
  verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści w kategorii: {highest_unsafe_category.upper()}"
135
 
 
137
 
138
  # --- Gradio Interface ---
139
 
140
+ # Custom theme inspired by the provided image - THIS IS THE CORRECTED LINE
141
+ theme = gr.themes.Default(
142
  primary_hue=gr.themes.colors.blue,
143
  secondary_hue=gr.themes.colors.indigo,
144
  neutral_hue=gr.themes.colors.slate,
 
165
  <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
166
  <a href="#" style="text-decoration: none; color: inherit;">O projekcie</a>
167
  <a href="#" style="text-decoration: none; color: inherit;">Opis kategorii</a>
168
+ <button id="test-sojka-btn" class="gr-button gr-button-primary gr-button-lg"
169
  style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
170
  Testuj Sójkę
171
  </button>
 
197
  label="Wprowadź tekst do analizy",
198
  placeholder="Tutaj wpisz tekst..."
199
  )
200
+ submit_btn = gr.Button("Analizuj tekst", variant="primary")
 
 
 
201
 
202
  output_verdict = gr.Label(label="Wynik analizy", value="Czekam na tekst do analizy...")
203
  output_scores = gr.Label(label="Szczegółowe wyniki", visible=False)