Update to latest Korean hate speech detection and mitigation system
README.md
CHANGED
@@ -1,14 +1,18 @@
---
-title:
-emoji:
+title: Hate Speech Mitigation Demo
+emoji: 🛡️
colorFrom: indigo
colorTo: blue
sdk: gradio
-sdk_version: "4.
+sdk_version: "4.27.0"
app_file: app.py
pinned: false
---

-#
-This Space is a Korean hate speech
+# Hate Speech Mitigation Demo
+
+This Space is a demo for mitigating Korean hate speech (혐오 표현).
+
+- Gradio-based interface
+- Electra + CRF based hate speech detection
+- LLM-based sentence mitigation (see the sketch below)
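The bullets above summarize the moving parts of the Space. Below is a minimal, hedged sketch of how they fit together, reusing the helpers exactly as app.py calls them; the internals of `detect_spans` (and its return format) live in the repo's `detector` module and are assumed here, not guaranteed.

```python
# Minimal sketch of the detect -> rewrite flow described above.
# detect_spans is assumed to return {"category": ..., "spans": [{"text": ...}, ...], "targets": [...]}
# exactly as app.py below uses it.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from detector import detect_spans

def mitigate(text: str, llm_name: str = "Bllossom/llama-3.2-Korean-Bllossom-3B") -> str:
    det = detect_spans(text)
    if det["category"] == "normal":
        return text  # nothing to rewrite
    tok = AutoTokenizer.from_pretrained(llm_name)
    llm = AutoModelForCausalLM.from_pretrained(llm_name, torch_dtype=torch.float32)
    prompt = (
        "문장을 완화해주세요.\n\n"
        f"- 원문: {text}\n"
        f"- 혐오 표현: {[s['text'] for s in det['spans']]}\n\n"
        "혐오 표현을 더 평화로운 표현으로 완화해주세요."
    )
    inputs = tok(prompt, return_tensors="pt")
    out = llm.generate(**inputs, do_sample=True, top_p=0.9, max_new_tokens=64,
                       pad_token_id=tok.eos_token_id)
    # return only the newly generated continuation, not the echoed prompt
    return tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
```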
app.py
CHANGED
@@ -1,5 +1,13 @@
import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
import numpy as np
@@ -8,636 +16,292 @@ from TorchCRF import CRF

from bert_score import score as bert_score_fn
import re
-from huggingface_hub import hf_hub_download

def calc_bertscore(orig_text, rewritten_text):
    P, R, F1 = bert_score_fn([rewritten_text], [orig_text], lang="ko")
    return round(F1[0].item(), 3)

def calc_ppl(text):
-    return 1.0

def calc_toxicity_reduction(orig_text, rewritten_text, detector_model, detector_tokenizer):
-        orig_attention_mask = orig_enc["attention_mask"].to(device)
-        with torch.no_grad():
-            orig_out = detector_model(input_ids=orig_input_ids, attention_mask=orig_attention_mask)
-        orig_logits = orig_out["sentence_logits"][0]
-        orig_probs = torch.softmax(orig_logits, dim=-1)
-        orig_toxicity = 1.0 - orig_probs[0].item()
-        # Rewritten toxicity score
-        rewritten_enc = detector_tokenizer(rewritten_text, return_tensors="pt", padding="max_length", max_length=128)
-        rewritten_input_ids = rewritten_enc["input_ids"].to(device)
-        rewritten_attention_mask = rewritten_enc["attention_mask"].to(device)
-        with torch.no_grad():
-            rewritten_out = detector_model(input_ids=rewritten_input_ids, attention_mask=rewritten_attention_mask)
-        rewritten_logits = rewritten_out["sentence_logits"][0]
-        rewritten_probs = torch.softmax(rewritten_logits, dim=-1)
-        rewritten_toxicity = 1.0 - rewritten_probs[0].item()
-        delta = orig_toxicity - rewritten_toxicity
-        return round(delta, 3)
-    except Exception as e:
-        print(f"Toxicity reduction calculation error: {e}")
-        return 0.0

-        bio_loss = -log_likelihood
-        tgt_dropped = self.dropout(pooled_output)
-        target_logits = self.target_head(tgt_dropped)
-        loss = 0.0
-        if sentence_labels is not None:
-            cls_loss = nn.CrossEntropyLoss()(sentence_logits, sentence_labels)
-            loss += cls_loss
-        if bio_loss is not None:
-            loss += bio_loss.sum()
-        if targets is not None:
-            bce_loss = nn.BCEWithLogitsLoss()(target_logits, targets)
-            loss += 2.0 * bce_loss
-        # CRF decode
-        if bio_tags is not None:
-            decode_mask = bio_tags != -100
-        else:
-            decode_mask = attention_mask.bool()
-        print("[DEBUG] bio_tags:", bio_tags)
-        print("[DEBUG] attention_mask.shape:", attention_mask.shape)
-        print("[DEBUG] decode_mask.shape:", decode_mask.shape)
-        print("[DEBUG] decode_mask[:, 0]:", decode_mask[:, 0] if decode_mask.dim() > 1 else decode_mask[0])
-        print("[DEBUG] bio_feats.shape:", bio_feats.shape)
-        bio_preds = self.crf.viterbi_decode(bio_feats, mask=decode_mask)
-        return {
-            'loss': loss,
-            'sentence_logits': sentence_logits,
-            'bio_logits': bio_feats,
-            'bio_preds': bio_preds,
-            'target_logits': target_logits
-        }
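The removed forward pass above depends on a `CRF` layer from the TorchCRF package for the BIO span head, but its construction is not visible in this hunk. A small hedged sketch of the usage pattern (the constructor signature is assumed from the TorchCRF package; the call shapes follow the code above):

```python
# Hedged sketch of the TorchCRF usage the forward pass above relies on.
import torch
from TorchCRF import CRF

num_bio_labels = 5                                    # O, B-SOFT, I-SOFT, B-HARD, I-HARD
crf = CRF(num_bio_labels)                             # assumed constructor: CRF(num_labels)

batch, seq_len = 2, 8
bio_feats = torch.randn(batch, seq_len, num_bio_labels)         # emission scores from the encoder head
bio_tags = torch.randint(0, num_bio_labels, (batch, seq_len))
mask = torch.ones(batch, seq_len, dtype=torch.bool)             # first timestep must be 1, hence the decode_mask[:, 0] debug print above

log_likelihood = crf(bio_feats, bio_tags, mask)       # per-example log likelihood
bio_loss = -log_likelihood                            # summed into the multi-task loss above
bio_preds = crf.viterbi_decode(bio_feats, mask)       # best tag sequence per example
```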
-        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
-            state_dict = checkpoint['model_state_dict']
-        else:
-            state_dict = checkpoint
-        for k, v in state_dict.items():
-            new_key = key_map.get(k, k)
-            new_state_dict[new_key] = v
-        self.model.load_state_dict(new_state_dict, strict=True)
-        self.model.to(self.device)
-        self.model.eval()
-
-        # Blossom LLM loading
-        print("Blossom LLM loading...")
-        self.llm_model_name = "Bllossom/llama-3.2-Korean-Bllossom-3B"
-        self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
-        self.llm_model = AutoModelForCausalLM.from_pretrained(
-            self.llm_model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto"
-        )
-        print("LLM loading complete!")
-
-        self.label_names = ["normal", "offensive", "L1_hate", "L2_hate"]
-        self.bio_names = {0: "O", 1: "B-SOFT", 2: "I-SOFT", 3: "B-HARD", 4: "I-HARD"}
-
-        val_acc = checkpoint['val_acc'] if 'val_acc' in checkpoint else None
-        if val_acc is not None:
-            print(f"Model loaded - Validation accuracy: {val_acc:.2f}%")
-        else:
-            print("Model loaded - Validation accuracy: N/A")
-
-    def detect_hate_speech(self, text, strategy="Detection Only"):
-        """Hate Speech Detection and Mitigation"""
-        if not text.strip():
-            return "Please enter text", ""
-        if len(text.strip()) < 2:
-            return "Input text is too short. Please enter at least 2 characters.", ""
-
-        if strategy == "Detection Only":
-            result_msg, mitigation, debug_info = self._detection_only(text)
-            print("[DEBUG] Input text:", text)
-            print("[DEBUG] sentence_logits:", debug_info.get('sentence_logits'))
-            print("[DEBUG] sentence_probs:", debug_info.get('sentence_probs'))
-            print("[DEBUG] sentence_pred:", debug_info.get('sentence_pred'))
-            print("[DEBUG] label:", debug_info.get('label'))
-            print("[DEBUG] confidence:", debug_info.get('confidence'))
-            return result_msg, mitigation
-        elif strategy == "Guided":
-            return self._guided_mitigation(text)
-        elif strategy == "Guided+Reflect":
-            return self._guided_reflect_mitigation(text)
-        elif strategy == "Unguided":
-            return self._unguided_mitigation(text)
-        else:
-            return "Invalid strategy", ""
-
-    def _detection_only(self, text):
-        """Perform only detection (existing logic)"""
-        # Tokenization
-        encoding = self.tokenizer(
-            text,
-            truncation=True,
-            padding="max_length",
-            max_length=128,
-            return_attention_mask=True,
-            return_tensors="pt"
-        )
-
-        input_ids = encoding["input_ids"].to(self.device)
-        attention_mask = encoding["attention_mask"].to(self.device)
-        print("[DEBUG] attention_mask[:, 0] =", attention_mask[:, 0])
-
-        # Prediction
-        with torch.no_grad():
-            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
-            sentence_logits = outputs["sentence_logits"]
-            bio_logits = outputs["bio_logits"]
-
-        # Sentence classification result
-        sentence_probs = torch.softmax(sentence_logits, dim=1)
-        sentence_pred = torch.argmax(sentence_logits, dim=1).item()
-        sentence_prob = sentence_probs[0][sentence_pred].item()
-
-        # BIO tagging result
-        bio_preds = torch.argmax(bio_logits, dim=2)[0]
-
-        # Find hate/aggressive tokens
-        hate_tokens = []
-        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
-
-        # Tokenize original text to get offset mapping
-        tokenized = self.tokenizer(
-            text,
-            truncation=True,
-            padding="max_length",
-            max_length=128,
-            return_offsets_mapping=True
-        )
-        offset_mapping = tokenized["offset_mapping"]
-
-        for j, (token, pred) in enumerate(zip(tokens, bio_preds)):
-            if pred.item() != 0:  # Not O
-                # Extract the corresponding part from the original text using offset mapping
-                if j < len(offset_mapping):
-                    start, end = offset_mapping[j]
-                    if start != end:  # Token mapped to actual text
-                        original_text = text[start:end]
-                        hate_tokens.append((j, original_text, self.bio_names[pred.item()]))
-                    else:
-                        # Special token handling
-                        if token.startswith('Ġ'):
-                            decoded_token = token[1:]  # Remove Ġ
-                        elif token in ['[CLS]', '[SEP]', '[PAD]', '[UNK]']:
-                            decoded_token = token
-                        else:
-                            decoded_token = token
-                        hate_tokens.append((j, decoded_token, self.bio_names[pred.item()]))
-                else:
-                    # Fallback
-                    if token.startswith('Ġ'):
-                        decoded_token = token[1:]
-                    elif token in ['[CLS]', '[SEP]', '[PAD]', '[UNK]']:
-                        decoded_token = token
-                    else:
-                        decoded_token = token
-                    hate_tokens.append((j, decoded_token, self.bio_names[pred.item()]))
-
-        # Determine label
-        label = self.label_names[sentence_pred]
-        # If hate_tokens contain B-HARD or I-HARD, raise the label to L2_hate
-        if any(bio_label in ["B-HARD", "I-HARD"] for _, _, bio_label in hate_tokens):
-            label = "L2_hate"
-        # Construct result message
-        result_msg = f"Detection result: {label}\nConfidence: {sentence_prob:.2f}"
-        if hate_tokens:
-            result_msg += f"\nIdentified hate/aggressive expressions: {hate_tokens}"
-        mitigation = "Performed only detection."
-        debug_info = {
-            'sentence_logits': sentence_logits,
-            'sentence_probs': sentence_probs,
-            'sentence_pred': sentence_pred,
-            'label': label,
-            'confidence': sentence_prob,
-            'hate_tokens': hate_tokens
-        }
-        return result_msg, mitigation, debug_info
-
-    def _unguided_mitigation(self, text):
-        """Unguided Mode: Only Using Generation"""
-        try:
-            # Blossom LLM prompt
-            prompt = f"""Please remove hate speech or aggressive expressions from the following sentence, while maintaining the original intent (criticism, complaint, opinion, etc.).

-            if clean_lines:
-                mitigated_text = clean_lines[0]
-
-            # Result message
-            result_msg = f"🤖 **Blossom LLM Mitigation Result**\n\n"
-            result_msg += f"**Original:** {text}\n\n"
-            result_msg += f"**Mitigated Sentence:** {mitigated_text}"
-
-            # Mitigation info
-            mitigation = "**Unguided Mode:** Blossom LLM detected and mitigated harmful expressions autonomously."
-
-            return result_msg, mitigation
-
-        except Exception as e:
-            error_msg = f"❌ **Blossom LLM Error**\n\nError occurred: {str(e)}"
-            return error_msg, "An error occurred during LLM processing."
-
-    def _guided_mitigation(self, text):
-        """Guided Mode: Mitigate based on KcELECTRA detection result using Blossom LLM"""
-        try:
-            # First, perform detection with KcELECTRA
-            detection_result, _, debug_info = self._detection_only(text)
-            label = debug_info.get('label', 'normal')
-            hate_tokens = debug_info.get('hate_tokens', [])
-
-            # Construct Blossom LLM prompt
-            if label == "normal":
-                prompt = f"""The following sentence is classified as a normal sentence. Please improve it by expressing it more politely and respectfully, while maintaining the original intent.\n\nOriginal: {text}\n\nImproved sentence:"""
-            else:
-                label_desc = {
-                    "offensive": "Aggressive",
-                    "L1_hate": "Mild Hate",
-                    "L2_hate": "Severe Hate"
-                }
-                hate_tokens_str = ""
-                if hate_tokens:
-                    hate_tokens_str = "\nExpressions causing issues:\n" + "\n".join([f"• {token} ({bio_label})" for _, token, bio_label in hate_tokens[:5]])
-                prompt = f"""The following sentence is classified as {label_desc.get(label, "harmful")} expression. \nPlease remove hate speech or aggressive expressions, while maintaining the original intent (criticism, complaint, opinion, etc.).\n\nOriginal: {text}\nClassification: {label_desc.get(label, "harmful")} expression\n{hate_tokens_str}\n\n[Important] All offensive, derogatory, and explicit hate expressions (e.g., 씨발, 좆, 병신) must be deleted.\n\nMitigated sentence:"""
-            # LLM inference
-            inputs = self.llm_tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
-            with torch.no_grad():
-                outputs = self.llm_model.generate(
-                    **inputs,
-                    do_sample=True,
-                    top_k=50,
-                    top_p=0.9,
-                    max_new_tokens=300,
-                    pad_token_id=self.llm_tokenizer.pad_token_id,
-                    eos_token_id=self.llm_tokenizer.eos_token_id
-                )
-            full_response = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
-            mitigated_text = full_response.replace(prompt, "").strip()
-            if len(mitigated_text) < 10:
-                mitigated_text = full_response
-            if "Mitigated sentence:" in mitigated_text:
-                mitigated_text = mitigated_text.split("Mitigated sentence:")[-1].strip()
-            lines = mitigated_text.split('\n')
-            clean_lines = []
-            for line in lines:
-                line = line.strip()
-                if line and not line.startswith('**') and not line.startswith('Original:') and not line.startswith('Classification:'):
-                    clean_lines.append(line)
-            if clean_lines:
-                mitigated_text = clean_lines[0]
-            result_msg = f"🎯 **Guided Mitigation Result**\n\n"
-            result_msg += f"**KcELECTRA Detection Result:**\n{detection_result}\n\n"
-            result_msg += f"**Blossom LLM Mitigation Result:**\n{mitigated_text}"
-            mitigation = "**Guided Mode:** Blossom LLM performed specific mitigation based on KcELECTRA's detection information."
-            return result_msg, mitigation
-        except Exception as e:
-            error_msg = f"❌ **Guided Mitigation Error**\n\nError occurred: {str(e)}"
-            return error_msg, "An error occurred during guided mitigation processing."
-
-    def _guided_reflect_mitigation(self, text):
-        """Guided+Reflect Mode: iterative refinement + critic evaluation"""
-        try:
-            detection_result, _, debug_info = self._detection_only(text)
-            label = debug_info.get('label', 'normal')
-            hate_tokens = debug_info.get('hate_tokens', [])
-            # Step 1: Initial mitigation
-            if label == "normal":
-                initial_prompt = f"""The following sentence is classified as a normal sentence. Please improve it by expressing it more politely and respectfully, while maintaining the original intent.\n\nOriginal: {text}\n\nImproved sentence:"""
-            else:
-                label_desc = {
-                    "offensive": "Aggressive",
-                    "L1_hate": "Mild Hate",
-                    "L2_hate": "Severe Hate"
-                }
-                hate_tokens_str = ""
-                if hate_tokens:
-                    hate_tokens_str = "\nExpressions causing issues:\n" + "\n".join([f"• {token} ({bio_label})" for _, token, bio_label in hate_tokens[:5]])
-                initial_prompt = f"""The following sentence is classified as {label_desc.get(label, "harmful")} expression. \nExpressions containing offensive words (e.g., 좆, 씨발, 병신) must be deleted.\nOther aggressive or inappropriate expressions should be mitigated by expressing them more politely and inclusively.\n\nOriginal: {text}\nClassification: {label_desc.get(label, "harmful")} expression\n{hate_tokens_str}\n\nMitigated sentence:"""
-            # Iterative mitigation and evaluation
-            max_iter = 5
-            metrics_history = []
-            best_candidate = None
-            best_score = -float('inf')
-            current_input = text
-            for i in range(max_iter):
-                # Generate candidate
-                inputs = self.llm_tokenizer(initial_prompt, return_tensors="pt").to(self.llm_model.device)
-                with torch.no_grad():
-                    outputs = self.llm_model.generate(
-                        **inputs,
-                        do_sample=True,
-                        top_k=50,
-                        top_p=0.9,
-                        max_new_tokens=300,
-                        pad_token_id=self.llm_tokenizer.pad_token_id,
-                        eos_token_id=self.llm_tokenizer.eos_token_id
-                    )
-                candidate = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
-                mitigated_text = candidate.replace(initial_prompt, "").strip()
-                if len(mitigated_text) < 10:
-                    mitigated_text = candidate
-                if "Mitigated sentence:" in mitigated_text:
-                    mitigated_text = mitigated_text.split("Mitigated sentence:")[-1].strip()
-                lines = mitigated_text.split('\n')
-                clean_lines = []
-                for line in lines:
-                    line = line.strip()
-                    if line and not line.startswith('**') and not line.startswith('Original:') and not line.startswith('Classification:'):
-                        clean_lines.append(line)
-                if clean_lines:
-                    mitigated_text = clean_lines[0]
-                # Exclude candidates containing offensive words
-                if contains_badword(mitigated_text):
-                    continue
-                # Evaluation
-                toxicity = calc_toxicity_reduction(text, mitigated_text, self.model, self.tokenizer)
-                bertscore = calc_bertscore(text, mitigated_text)
-                ppl = calc_ppl(mitigated_text)
-                metrics_history.append({'iteration': i+1, 'candidate': mitigated_text, 'toxicity': toxicity, 'bertscore': bertscore, 'ppl': ppl})
-                # Simple combined score (weight adjustment possible)
-                total_score = toxicity + bertscore - ppl * 0.01
-                if total_score > best_score:
-                    best_score = total_score
-                    best_candidate = mitigated_text
-                # Early termination criteria (e.g., toxicity > 0.3, bertscore > 0.7, ppl < 100)
-                if toxicity > 0.3 and bertscore > 0.7 and ppl < 100:
-                    break
-            # Log output
-            iter_log_str = ""
-            for log in metrics_history:
-                iter_log_str += f"\nIteration {log['iteration']}:\n- Candidate: {log['candidate']}\n- Toxicity reduction: {log['toxicity']}, bertscore: {log['bertscore']}, ppl: {log['ppl']}"
-            # Result message
-            result_msg = f"🔄 **Guided+Reflect Mitigation Result**\n\n"
-            result_msg += f"**Detection Result:**\n{detection_result}\n\n"
-            result_msg += f"**Iterative Mitigation Log:**{iter_log_str}\n\n"
-            result_msg += f"**Best Mitigation:** {best_candidate}"
-            mitigation = "**Guided+Reflect Mode:** Selected the optimal candidate after iterative mitigation and evaluation (maximum 5 iterations)."
-            return result_msg, mitigation
-        except Exception as e:
-            error_msg = f"❌ **Guided+Reflect Mitigation Error**\n\nError occurred: {str(e)}"
-            return error_msg, "An error occurred during guided+reflect mitigation processing."
-
-    def _suggest_mitigation(self, label, confidence, hate_tokens):
-        """Suggest mitigation for hate speech expressions"""
-        if label == "normal":
-            return "✅ **Mitigation Suggestion**: This sentence does not require correction."
-
-        mitigation = f"**🔧 Mitigation Suggestion for Hate Speech:**\n\n"
-
-        if label == "offensive":
-            mitigation += "**Aggressive Expression Mitigation Options:**\n"
-            mitigation += "• Try to change aggressive expressions to more polite expressions\n"
-            mitigation += "• Use objective expressions instead of emotional expressions\n"
-            mitigation += "• Reconstruct with a mind to be considerate\n"
-            mitigation += "• When criticizing, provide specific and constructive feedback"
-        elif label == "L1_hate":
-            mitigation += "**Implicit Hate Expression Mitigation Options:**\n"
-            mitigation += "• Remove expressions that discriminate or show prejudice\n"
-            mitigation += "• Avoid generalizing about specific groups\n"
-            mitigation += "• Use more inclusive and respectful expressions\n"
-            mitigation += "• Change to expressions that acknowledge diversity"
-        else:  # L2_hate
-            mitigation += "**Explicit Hate Expression Mitigation Options:**\n"
-            mitigation += "• Completely remove severe hate expressions\n"
-            mitigation += "• Do not use violent or threatening expressions\n"
-            mitigation += "• Use expressions that respect everyone's dignity\n"
-            mitigation += "• Do not use expressions that discriminate or promote hate\n"
-            mitigation += "• If necessary, seek professional help"
-
-        return mitigation

-    def

-def create_demo():
-    with gr.Blocks(
-        title="Korean Hate Speech Detection and Mitigation System",
-        theme=gr.themes.Soft(),
-        css="""
-        .gradio-container {
-            max-width: 800px;
-            margin: 0 auto;
-        }
-        .result-box {
-            border-radius: 10px;
-            padding: 15px;
-            margin: 10px 0;
-        }
-        .normal { background-color: #d4edda; border: 1px solid #c3e6cb; }
-        .offensive { background-color: #fff3cd; border: 1px solid #ffeaa7; }
-        .hate { background-color: #f8d7da; border: 1px solid #f5c6cb; }
-        """
-    ) as demo:
-        gr.Markdown("""
-        # Korean Hate Speech Detection and Mitigation System
-
-        This system detects hate speech in Korean text and provides mitigation suggestions.
-
-        **🟢 Normal**:
-        - It is a normal sentence.
-
-        **🟡 Offensive**
-        - For example: "Don't say such a stupid thing", "How can you do such a stupid thing"
-
-        **🟠 L1_hate (Implicit Hate)**: Mild hate expression
-        - **Implicit hate expression** toward protected attribute groups
-        - For example: "Those people are all the same", "Prejudicial expression towards a specific group"
-
-        **🔴 L2_hate (Explicit Hate)**: Severe hate expression
-        - **Explicit hate expression** toward protected attribute groups
-
-        **🤖 Mitigation Mode:**
-        - 🔍 **Detection Only**: Hate Speech Detection Only
-        - 🎯 **Guided**: Guided Mitigation
-        - 🔄 **Guided+Reflect**: After Guided Mitigation, Iterative Refinement
-        - 🤖 **Unguided**: LLM generates text without any guidance
-        """)
-
-        with gr.Row():
-            with gr.Column(scale=2):
-                input_text = gr.Textbox(
-                    label="Enter text to detect hate speech & mitigate",
-                    lines=3
-                )
-
-                strategy = gr.Radio(
-                    ["Detection Only", "Guided", "Guided+Reflect", "Unguided"],
-                    value="Detection Only",
-                    label="Select Mitigation Mode",
-                    container=True
-                )
-
-                analyze_btn = gr.Button("🔍 Detect & Mitigate", variant="primary", size="lg")
-
-            # with gr.Column(scale=1):
-            #     gr.Markdown("""
-            #     **🧪 Test Examples:**
-            #     **🟢 Normal:**
-            #     - "Hello, today's weather is nice."
-            #     - "This movie was really fun."
-            #     **🟡 Offensive:**
-            #     - "How can you do such a stupid thing"
-            #     - "Don't say such a stupid thing"
-            #     **🟠 L1_hate (Implicit):**
-            #     - "Those people are all the same"
-            #     - "Prejudicial expression towards a specific group"
-            #     **🔴 L2_hate (Explicit):**
-            #     - "All women are useless"
-            #     - "People with disabilities are a burden to society"
-            #     """)
-
-        with gr.Row():
-            with gr.Column():
-                result_output = gr.Markdown(
-                    label="Mitigation Result",
-                    value="Input text and click the above button."
-                )
-
-            with gr.Column():
-                mitigation_output = gr.Markdown(
-                    label="Mitigation Suggestion",
-                    value="Based on the analysis result, mitigation suggestions will be provided."
-                )
-
-        # Event handlers
-        analyze_btn.click(
-            fn=service.detect_hate_speech,
-            inputs=[input_text, strategy],
-            outputs=[result_output, mitigation_output]
-        )
-
-        # Allow analysis via Enter key
-        input_text.submit(
-            fn=service.detect_hate_speech,
-            inputs=[input_text, strategy],
-            outputs=[result_output, mitigation_output]
-        )
-
-        # gr.Markdown("""
-        # ---
-        # **Model Information:**
-        # - Detection Model: KcELECTRA-base (Validation Accuracy: 67.67%)
-        # - Mitigation Model: Blossom LLM (llama-3.2-Korean-Bllossom-3B)
-        # - Training Data: K-HATERS Dataset
-        # """)
-
-    return demo
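Before the new file: a quick worked example of the candidate-selection arithmetic used by the removed Guided+Reflect loop above (the metric values are invented; the weights and early-stop thresholds are the ones in that code):

```python
candidates = [
    {"toxicity": 0.15, "bertscore": 0.82, "ppl": 140.0},   # fluent, but barely less toxic
    {"toxicity": 0.42, "bertscore": 0.74, "ppl": 65.0},    # satisfies all three early-stop thresholds
]
for c in candidates:
    c["total"] = c["toxicity"] + c["bertscore"] - 0.01 * c["ppl"]
    c["early_stop"] = c["toxicity"] > 0.3 and c["bertscore"] > 0.7 and c["ppl"] < 100
print(candidates[0]["total"], candidates[1]["total"])        # -0.43 vs. 0.51 -> the second candidate wins
```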
import gradio as gr
import torch
+from detector import detect_spans, KHatersModelCRF
+from generator import critic_score, mitigate_with_strategy
+from huggingface_hub import hf_hub_download  # added for loading the checkpoint from the Hub
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
import numpy as np

from bert_score import score as bert_score_fn
import re
+# --- placeholder helper functions for score calculation ---
+# 1. BERTScore (semantic similarity)
def calc_bertscore(orig_text, rewritten_text):
    P, R, F1 = bert_score_fn([rewritten_text], [orig_text], lang="ko")
    return round(F1[0].item(), 3)

+# 2. PPL (perplexity: linguistic fluency)
+# Uses a Korean causal LM such as KoGPT2 (model/tokenizer are loaded only once)
+_kogpt2_tokenizer = None
+_kogpt2_model = None
+
def calc_ppl(text):
+    global _kogpt2_tokenizer, _kogpt2_model
+    if _kogpt2_tokenizer is None or _kogpt2_model is None:
+        _kogpt2_tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2-base-v2")
+        _kogpt2_model = AutoModelForCausalLM.from_pretrained("skt/kogpt2-base-v2")
+        _kogpt2_model.eval()
+    encodings = _kogpt2_tokenizer(text, return_tensors="pt")
+    input_ids = encodings.input_ids
+    with torch.no_grad():
+        outputs = _kogpt2_model(input_ids, labels=input_ids)
+    loss = outputs.loss
+    ppl = torch.exp(loss).item()
+    return round(ppl, 3)
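A quick usage sketch for the two helpers above (the example strings are hypothetical; the first call to calc_ppl downloads skt/kogpt2-base-v2, and bert_score downloads its multilingual scoring model on first use):

```python
original = "야, 이런 것도 못해?"                           # hypothetical blunt input
rewritten = "이 부분은 조금만 더 신경 써 주시면 좋겠어요."    # hypothetical polite rewrite

print("PPL(rewritten):", calc_ppl(rewritten))                  # lower = more fluent under KoGPT2
print("BERTScore F1:", calc_bertscore(original, rewritten))    # closer to 1 = meaning better preserved
```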
+# 3. ΔTox (toxicity reduction)
def calc_toxicity_reduction(orig_text, rewritten_text, detector_model, detector_tokenizer):
+    # score both sentences with critic_score (softmax score of the "normal" class)
+    orig_score = critic_score(orig_text, detector_model, detector_tokenizer)
+    rewritten_score = critic_score(rewritten_text, detector_model, detector_tokenizer)
+    delta = orig_score - rewritten_score
+    return round(delta, 3)
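critic_score is imported from the repo's generator module and is not shown in this diff; the comment above suggests it returns the detector's softmax probability for the "normal" class. A hedged sketch under that assumption, which also makes the sign convention explicit (a rewrite that looks cleaner to the detector yields a negative delta above, the opposite of the old 1 − P(normal) formulation):

```python
import torch

def critic_score_sketch(text, detector_model, detector_tokenizer):
    # Assumption: mirrors generator.critic_score by returning P(normal) from the
    # detector's sentence head (index 0 in ["normal", "offensive", "L1_hate", "L2_hate"]).
    enc = detector_tokenizer(text, return_tensors="pt", truncation=True,
                             padding="max_length", max_length=128)
    with torch.no_grad():
        out = detector_model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    probs = torch.softmax(out["sentence_logits"][0], dim=-1)
    return probs[0].item()
```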
+# prepare the detector
+base_model_name = "beomi/KcELECTRA-base"
+detector_tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)
+num_labels, num_bio_labels, num_targets = 4, 5, 9  # 4 sentence labels, 5 BIO tags, 9 target groups (pre-fix values were 5, 3, 9)
+detector_model = KHatersModelCRF(base_model_name, num_labels, num_bio_labels, num_targets)
+# ckpt = torch.load("/root/PROJECT-ROOT/backend/kcelectra_crf_ckpt/best_model.pt", map_location="cpu")
+ckpt_path = hf_hub_download(repo_id="alohaboy/hate_detector_ko", filename="best_model.pt", use_auth_token=True)
+ckpt = torch.load(ckpt_path, map_location="cpu")
+if "model_state_dict" in ckpt:
+    state_dict = ckpt["model_state_dict"]
+else:
+    state_dict = ckpt
+# patch to rename old-format state_dict keys to the new CRF key names
+crf_key_map = {
+    "crf.trans_matrix": "crf.transitions",
+    "crf.start_trans": "crf.start_transitions",
+    "crf.end_trans": "crf.end_transitions",
+}
+for old_key, new_key in crf_key_map.items():
+    if old_key in state_dict:
+        state_dict[new_key] = state_dict.pop(old_key)
+# delete any old-format keys that are still left over
+for k in ["crf.trans_matrix", "crf.start_trans", "crf.end_trans"]:
+    if k in state_dict:
+        del state_dict[k]
+detector_model.load_state_dict(state_dict)
+# detector_model.to("cuda")  # move to GPU (removed; the Space runs on CPU)
+detector_model.eval()

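use_auth_token=True still works, but newer huggingface_hub releases prefer the token argument. A hedged alternative for the checkpoint download above (HF_TOKEN is a hypothetical environment variable holding a read token for the private repo):

```python
import os
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="alohaboy/hate_detector_ko",
    filename="best_model.pt",
    token=os.environ.get("HF_TOKEN"),   # falls back to the cached login when unset
)
```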
+# prepare the LLM
+LLM_MODEL_NAME = "Bllossom/llama-3.2-Korean-Bllossom-3B"
+llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, use_auth_token=True)
+llm_model = AutoModelForCausalLM.from_pretrained(
+    LLM_MODEL_NAME,
+    torch_dtype=torch.float32,  # works on CPU
+    device_map="cpu",           # run on the CPU
+    use_auth_token=True
+)
+llm_model.to("cpu")
+print("llm_model device:", llm_model.device)

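The prompts below are passed to the instruction-tuned Bllossom model as plain strings. If the tokenizer ships a chat template (plausible for a Llama-3.2 derivative, but not verified here), transformers' apply_chat_template can wrap the same instruction in the model's expected chat format:

```python
def chat_generate(instruction: str, max_new_tokens: int = 64) -> str:
    # assumes llm_tokenizer has a chat template; otherwise keep the plain-string prompts below
    messages = [{"role": "user", "content": instruction}]
    input_ids = llm_tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(llm_model.device)
    gen_ids = llm_model.generate(
        input_ids,
        do_sample=True, top_p=0.9, max_new_tokens=max_new_tokens,
        pad_token_id=llm_tokenizer.eos_token_id,
    )
    return llm_tokenizer.decode(gen_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
```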
+def phrase_replacement(text, spans):
+    # mitigate only the detected hate spans (example: replace them with "[완화]")
+    new_text = text
+    for span in spans:
+        new_text = new_text.replace(span['text'], "[완화]")
+    return new_text
+
+def unguided_rewrite(text):
+    # plain LLM paraphrase without an explicit mitigation prompt (example)
+    prompt = f"다음 문장을 더 부드럽게 바꿔주세요: {text}"
+    inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
+    gen_ids = llm_model.generate(
+        **inputs,
+        do_sample=True,
+        top_k=50,
+        top_p=0.9,
+        num_return_sequences=1,
+        max_new_tokens=64,
+        pad_token_id=llm_tokenizer.pad_token_id,
+        eos_token_id=llm_tokenizer.eos_token_id
+    )
+    return llm_tokenizer.decode(gen_ids[0], skip_special_tokens=True)

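phrase_replacement only needs spans shaped like the dicts detect_spans produces. A tiny usage sketch with a hypothetical span list (the placeholder token "[완화]" is the one hard-coded above):

```python
example_spans = [{"text": "바보"}]                 # hypothetical detected span
print(phrase_replacement("야, 바보야. 왜 그래?", example_spans))
# -> "야, [완화]야. 왜 그래?"
```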
+def hate_speech_demo(text, strategy):
+    det = detect_spans(text)
+    detector_result = {
+        "완화문장": text,
+        "유해카테고리": det["category"],
+        "탐지스팬들": det["spans"],
+        "타겟예측": det["targets"]
+    }
+    # apply the selected mitigation strategy
+    if det["category"] == "normal":
+        output = text
+        candidate_info = "-"
+    elif strategy in ("Guided", "Full Rewrite"):  # "Guided" is the name exposed in the UI below
+        num_return_sequences = 5
+        current_text = text
+        best_candidate = current_text
+        best_score = 0.0
+        all_candidates = []
+        all_scores = []
+        for _ in range(3):  # max_iter=3
+            prompt = (
+                f"문장을 완화해주세요.\n\n"
+                f"- 원문: {current_text}\n"
+                f"- 혐오 표현: {[span['text'] for span in det['spans']]}\n"
+                f"- 혐오 유형: {det['category']}\n"
+                f"- 타겟 그룹: {det['targets']}\n\n"
+                f"위 내용을 고려하여, 혐오 표현을 더 평화로운 표현으로 완화해주세요."
+            )
+            inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
+            gen_ids = llm_model.generate(
+                **inputs,
+                do_sample=True,
+                top_k=50,
+                top_p=0.9,
+                num_return_sequences=num_return_sequences,
+                max_new_tokens=64,
+                pad_token_id=llm_tokenizer.pad_token_id,
+                eos_token_id=llm_tokenizer.eos_token_id
+            )
+            candidates = [llm_tokenizer.decode(g, skip_special_tokens=True) for g in gen_ids]
+            scores = [critic_score(c, detector_model, detector_tokenizer) for c in candidates]
+            all_candidates.extend(candidates)
+            all_scores.extend(scores)
+            best_idx = int(torch.tensor(scores).argmax())
+            if scores[best_idx] >= 0.7:
+                best_candidate = candidates[best_idx]
+                best_score = scores[best_idx]
+                break
+            best_candidate = candidates[best_idx]
+            best_score = scores[best_idx]
+            current_text = best_candidate
+        output = best_candidate
+        candidate_info = "\n".join([f"[{i+1}] {c} (score={s:.3f})" for i, (c, s) in enumerate(zip(all_candidates, all_scores))])
+    elif strategy == "Phrase Replacement":
+        output = phrase_replacement(text, det["spans"])
+        candidate_info = "-"
+    elif strategy == "Unguided":
+        output = unguided_rewrite(text)
+        candidate_info = "-"
+    else:
+        output = text
+        candidate_info = "-"
+    # compute the scores
+    tox_score = calc_toxicity_reduction(text, output, detector_model, detector_tokenizer)
+    ppl_score = calc_ppl(output)
+    bert_score = calc_bertscore(text, output)
+    soft_or_hard = det.get("soft_or_hard", "-")
+    return (
+        det["category"],
+        str(det["spans"]),
+        str(det["targets"]),
+        output,
+        f"Toxicity Reduction: {tox_score}",
+        f"PPL: {ppl_score}",
+        f"BERTScore: {bert_score}",
+        candidate_info,
+        soft_or_hard
+    )

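hate_speech_demo returns nine values in the order shown above, while the Gradio outputs list further down consumes only eight of them; that is why the hate_speech_multilingual wrapper reorders the tuple and drops the soft/hard flag. A small unpacking sketch (the input is hypothetical, and running it loads the models):

```python
(category, spans, targets, rewritten,
 tox_msg, ppl_msg, bert_msg, cand_info, soft_or_hard) = hate_speech_demo(
    "예시 문장입니다", "Phrase Replacement"
)
print(category, rewritten, tox_msg)
```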
+with gr.Blocks(theme=gr.themes.Monochrome(primary_hue="blue", secondary_hue="slate")) as demo:
+    gr.HTML("""
+    <style>
+    .modern-input textarea { font-size: 1.1em; border-radius: 8px; }
+    .modern-btn { background: linear-gradient(90deg,#2563eb,#60a5fa); color: white; border-radius: 8px; font-weight: bold; }
+    .modern-label { font-size: 1.05em; font-weight: 600; color: #2563eb; }
+    .modern-badge { background: #e0e7ff; color: #1e293b; border-radius: 6px; padding: 0.2em 0.6em; font-weight: 500; }
+    .modern-output textarea { background: #f1f5f9; border-radius: 8px; font-size: 1.1em; }
+    .modern-score input { color: #0ea5e9; font-weight: bold; font-size: 1.1em; }
+    .modern-candidate textarea { background: #f8fafc; border-radius: 8px; }
+    </style>
+    """)
+    gr.Markdown("""
+    <div style='text-align:center; margin-bottom: 1em;'>
+    <h1 style='color:#2563eb;'>Hate Speech Mitigation Demo</h1>
+    <p style='font-size:1.1em;'>Enter a sentence and select a mitigation strategy to see the results and various scores.</p>
+    </div>
+    """)
+    with gr.Row():
+        lang = gr.Radio(["Korean", "English"], value="Korean", label="Language", container=True)
+        input_box = gr.Textbox(label="Input text", lines=2, value="야, 병신아. 이런 것도 못해?", elem_classes="modern-input")
+        strategy = gr.Radio([
+            "Guided",
+            "Phrase Replacement",
+            "Unguided"
+        ], value="Guided", label="Mitigation Strategy", container=True)  # default must be one of the listed choices
+    run_btn = gr.Button("Non-Toxic", variant="primary", elem_classes="modern-btn")
+    with gr.Row():
+        out1 = gr.Label(label="Hate type", elem_classes="modern-label")
+        out2 = gr.HighlightedText(label="Detected spans", elem_classes="modern-badge")
+        out3 = gr.Label(label="Target", elem_classes="modern-label")
+        out5 = gr.Number(label="Toxicity Reduction", elem_classes="modern-score")
+        out6 = gr.Number(label="PPL", elem_classes="modern-score")
+        out7 = gr.Number(label="BERTScore", elem_classes="modern-score")
+    with gr.Accordion("Candidate Info", open=False):
+        out8 = gr.Textbox(label="Candidates", lines=3, elem_classes="modern-candidate")
+    out4 = gr.Textbox(label="Output", lines=2, elem_classes="modern-output")

+    def hate_speech_multilingual(text, strategy, lang):
+        if lang == "Korean":
+            # reorder the Korean result to match the outputs list below (the soft/hard flag is dropped)
+            category, spans, targets, rewritten, tox, ppl, bert, cand_info, _soft_or_hard = hate_speech_demo(text, strategy)
+            return category, spans, targets, tox, ppl, bert, cand_info, rewritten
+        else:
+            # set up the English hate speech detection pipeline
+            from english_detector import detect_spans as detect_spans_en, EnglishElectraHateDetector
+            from transformers import AutoTokenizer
+            import torch
+            # build the label mappings
+            tag2id = {"O": 0, "B-HATE": 1, "I-HATE": 2, "B-OFF": 3, "I-OFF": 4}
+            id2tag = {v: k for k, v in tag2id.items()}
+            sev2id = {"NORMAL": 0, "O1": 1, "O2": 2, "H1": 3, "H2": 4}
+            id2sev = {v: k for k, v in sev2id.items()}
+            # example category list (the full HateXplain category set should be reflected in practice)
+            cat_list = ["African", "Arab", "Asian", "Atheist", "Buddhist", "Christian", "Female", "Hispanic", "Homosexual_gay_or_lesbian", "Immigrant", "Jewish", "Male", "Other_religions", "Physical_disability", "Transgender", "Other"]
+            cat2id = {c: i for i, c in enumerate(cat_list)}
+            id2cat = {v: k for k, v in cat2id.items()}
+            # prepare the model/tokenizer (note: this reloads the checkpoint on every call)
+            base_model_name = "google/electra-base-discriminator"
+            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+            model = EnglishElectraHateDetector(base_model_name, num_severity=5, num_bio_labels=5, num_targets=len(cat2id))
+            ckpt = torch.load("/root/PROJECT-ROOT/backend/english_detector_ckpt/best_model.pt", map_location="cuda")
+            model.load_state_dict(ckpt)
+            model.eval()
+            model.to("cuda")  # only the detection model is moved to the GPU
+            # run detection
+            det = detect_spans_en(text, model, tokenizer, tag2id, id2tag, sev2id, id2sev, cat2id, id2cat, device="cuda")
+            # convert for HighlightedText: [(text, class)]
+            spans_for_highlight = [(span["text"], span["label"]) for span in det["spans"] if span.get("text")]
+            # convert the targets
+            targets_str = ", ".join(det["targets"]) if det["targets"] else "-"
+            # per-strategy mitigation (the original text is returned as-is here; hook up an English generator if needed)
+            output = text
+            candidate_info = "-"
+            # score placeholders
+            tox_score = "-"
+            ppl_score = "-"
+            bert_score = "-"
+            soft_or_hard = det.get("soft_or_hard", "-")
+            return (
+                det["category"] if "category" in det else det.get("severity", "-"),
+                spans_for_highlight,
+                targets_str,
+                tox_score,
+                ppl_score,
+                bert_score,
+                candidate_info,
+                output
+            )

+    run_btn.click(
+        hate_speech_multilingual,
+        inputs=[input_box, strategy, lang],
+        outputs=[out1, out2, out3, out5, out6, out7, out8, out4]
+    )

+demo.launch(share=True)
+# leftover module-level debug prints (these variables are not defined at this scope), kept disabled:
+# print("input_ids:", input_ids)
+# print("attention_mask:", attention_mask)
+# print("decode_mask:", decode_mask)
+# print("bio_feats.shape:", bio_feats.shape)