alohaboy committed on
Commit
2d7f457
·
1 Parent(s): e8dd03b

Update to latest Korean hate speech detection and mitigation system

Files changed (2)
  1. README.md +9 -5
  2. app.py +280 -616
README.md CHANGED
@@ -1,14 +1,18 @@
1
  ---
2
- title: Korean Hate Speech Mitigation Demo
3
- emoji: "🛡️"
4
  colorFrom: indigo
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: "4.44.0"
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- # Korean Hate Speech Mitigation Demo
13
 
14
- This Space is a demo for Korean hate speech detection and mitigation.
 
 
 
 
 
1
  ---
2
+ title: Hate Speech Mitigation Demo
3
+ emoji: 🛡️
4
  colorFrom: indigo
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: "4.27.0"
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # Hate Speech Mitigation Demo
13
 
14
+ This Space is a demo for Korean hate speech mitigation.
15
+
16
+ - Gradio-based interface
17
+ - Electra + CRF-based hate speech detection
18
+ - LLM-based sentence mitigation
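The bullets above correspond to the pipeline that the updated `app.py` (diffed below) wires together: `detect_spans` from `detector.py` (the Electra + CRF detector) returns a category, hate spans, and targets, and the Bllossom LLM rewrites the sentence. A minimal sketch of that flow, assuming the helper modules behave the way `app.py` uses them (their implementations are not part of this commit):

```python
# Minimal sketch of the detection -> mitigation flow. Assumptions: detector.py ships with
# the Space and detect_spans(text) returns {"category", "spans", "targets"} as app.py uses it.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from detector import detect_spans  # Electra + CRF span detector (assumed interface)

LLM_NAME = "Bllossom/llama-3.2-Korean-Bllossom-3B"
llm_tok = AutoTokenizer.from_pretrained(LLM_NAME)
llm = AutoModelForCausalLM.from_pretrained(LLM_NAME, torch_dtype=torch.float32)

def mitigate(text: str) -> str:
    det = detect_spans(text)
    if det["category"] == "normal":
        return text  # nothing to rewrite
    prompt = (
        "문장을 순화해주세요.\n\n"
        f"- 원문: {text}\n"
        f"- 혐오 표현: {[s['text'] for s in det['spans']]}\n"
        "순화된 문장:"
    )
    ids = llm_tok(prompt, return_tensors="pt")
    out = llm.generate(**ids, do_sample=True, top_p=0.9, max_new_tokens=64,
                       eos_token_id=llm_tok.eos_token_id)
    return llm_tok.decode(out[0], skip_special_tokens=True)
```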
app.py CHANGED
@@ -1,5 +1,13 @@
1
  import gradio as gr
2
  import torch
 
 
 
 
 
 
 
 
3
  import torch.nn as nn
4
  from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
5
  import numpy as np
@@ -8,636 +16,292 @@ from TorchCRF import CRF
8
 
9
  from bert_score import score as bert_score_fn
10
  import re
11
- from huggingface_hub import hf_hub_download
12
 
 
 
 
13
  def calc_bertscore(orig_text, rewritten_text):
14
  P, R, F1 = bert_score_fn([rewritten_text], [orig_text], lang="ko")
15
  return round(F1[0].item(), 3)
16
 
 
 
 
 
 
17
  def calc_ppl(text):
18
- try:
19
- tokens = text.split()
20
- if len(tokens) < 2:
21
- return 1.0
22
- word_count = len(tokens)
23
- base_ppl = 50.0
24
- length_factor = min(word_count / 10.0, 2.0)
25
- complexity_factor = 1.0 + (len(set(tokens)) / word_count) * 0.5
26
- ppl = base_ppl * length_factor * complexity_factor
27
- return round(ppl, 3)
28
- except Exception as e:
29
- print(f"PPL calculation error: {e}")
30
- return 1.0
31
 
 
32
  def calc_toxicity_reduction(orig_text, rewritten_text, detector_model, detector_tokenizer):
33
- try:
34
- # Original toxicity score
35
- orig_enc = detector_tokenizer(orig_text, return_tensors="pt", padding="max_length", max_length=128)
36
- device = next(detector_model.parameters()).device
37
- orig_input_ids = orig_enc["input_ids"].to(device)
38
- orig_attention_mask = orig_enc["attention_mask"].to(device)
39
- with torch.no_grad():
40
- orig_out = detector_model(input_ids=orig_input_ids, attention_mask=orig_attention_mask)
41
- orig_logits = orig_out["sentence_logits"][0]
42
- orig_probs = torch.softmax(orig_logits, dim=-1)
43
- orig_toxicity = 1.0 - orig_probs[0].item()
44
- # Rewritten toxicity score
45
- rewritten_enc = detector_tokenizer(rewritten_text, return_tensors="pt", padding="max_length", max_length=128)
46
- rewritten_input_ids = rewritten_enc["input_ids"].to(device)
47
- rewritten_attention_mask = rewritten_enc["attention_mask"].to(device)
48
- with torch.no_grad():
49
- rewritten_out = detector_model(input_ids=rewritten_input_ids, attention_mask=rewritten_attention_mask)
50
- rewritten_logits = rewritten_out["sentence_logits"][0]
51
- rewritten_probs = torch.softmax(rewritten_logits, dim=-1)
52
- rewritten_toxicity = 1.0 - rewritten_probs[0].item()
53
- delta = orig_toxicity - rewritten_toxicity
54
- return round(delta, 3)
55
- except Exception as e:
56
- print(f"Toxicity reduction calculation error: {e}")
57
- return 0.0
58
 
59
- class HateSpeechDetector(nn.Module):
60
- def __init__(self, model_name="beomi/KcELECTRA-base", num_sentence_labels=4, num_bio_labels=5, num_targets=9):
61
- super().__init__()
62
- self.config = AutoConfig.from_pretrained(model_name)
63
- self.encoder = AutoModel.from_pretrained(model_name, config=self.config)
64
- hidden_size = self.config.hidden_size
65
- self.dropout = nn.Dropout(0.1)
66
- self.classifier = nn.Linear(hidden_size, num_sentence_labels) # Sentence classification
67
- self.bio_linear = nn.Linear(hidden_size, num_bio_labels) # BIO tagging
68
- self.crf = CRF(num_bio_labels)
69
- self.target_head = nn.Linear(hidden_size, num_targets) # Target classification
 
70
 
71
- def forward(self, input_ids, attention_mask, bio_tags=None, sentence_labels=None, targets=None):
72
- outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
73
- sequence_output = outputs.last_hidden_state
74
- pooled_output = sequence_output[:, 0, :]
75
- dropped = self.dropout(pooled_output)
76
- sentence_logits = self.classifier(dropped)
77
- bio_feats = self.bio_linear(sequence_output)
78
- bio_loss = None
79
- if bio_tags is not None:
80
- mask = bio_tags != -100
81
- log_likelihood = self.crf.forward(bio_feats, bio_tags, mask=mask)
82
- bio_loss = -log_likelihood
83
- tgt_dropped = self.dropout(pooled_output)
84
- target_logits = self.target_head(tgt_dropped)
85
- loss = 0.0
86
- if sentence_labels is not None:
87
- cls_loss = nn.CrossEntropyLoss()(sentence_logits, sentence_labels)
88
- loss += cls_loss
89
- if bio_loss is not None:
90
- loss += bio_loss.sum()
91
- if targets is not None:
92
- bce_loss = nn.BCEWithLogitsLoss()(target_logits, targets)
93
- loss += 2.0 * bce_loss
94
- # CRF decode
95
- if bio_tags is not None:
96
- decode_mask = bio_tags != -100
97
- else:
98
- decode_mask = attention_mask.bool()
99
- print("[DEBUG] bio_tags:", bio_tags)
100
- print("[DEBUG] attention_mask.shape:", attention_mask.shape)
101
- print("[DEBUG] decode_mask.shape:", decode_mask.shape)
102
- print("[DEBUG] decode_mask[:, 0]:", decode_mask[:, 0] if decode_mask.dim() > 1 else decode_mask[0])
103
- print("[DEBUG] bio_feats.shape:", bio_feats.shape)
104
- bio_preds = self.crf.viterbi_decode(bio_feats, mask=decode_mask)
105
- return {
106
- 'loss': loss,
107
- 'sentence_logits': sentence_logits,
108
- 'bio_logits': bio_feats,
109
- 'bio_preds': bio_preds,
110
- 'target_logits': target_logits
111
- }
112
 
113
- class HateSpeechDetectorService:
114
- def __init__(self):
115
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
- self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base")
117
- self.model = HateSpeechDetector()
118
-
119
- # Model loading
120
- MODEL_CKPT_PATH = hf_hub_download(repo_id="alohaboy/hate_detector_ko", filename="best_model.pt")
121
- checkpoint = torch.load(MODEL_CKPT_PATH, map_location=self.device)
122
- # state_dict key conversion
123
- key_map = {
124
- 'sentence_classifier.weight': 'classifier.weight',
125
- 'sentence_classifier.bias': 'classifier.bias',
126
- 'bio_classifier.weight': 'bio_linear.weight',
127
- 'bio_classifier.bias': 'bio_linear.bias',
128
- # CRF related keys (reverse)
129
- 'crf.transitions': 'crf.trans_matrix',
130
- 'crf.start_transitions': 'crf.start_trans',
131
- 'crf.end_transitions': 'crf.end_trans',
132
- }
133
- new_state_dict = {}
134
- # If checkpoint is a dict and model_state_dict key exists, load from it
135
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
136
- state_dict = checkpoint['model_state_dict']
137
- else:
138
- state_dict = checkpoint
139
- for k, v in state_dict.items():
140
- new_key = key_map.get(k, k)
141
- new_state_dict[new_key] = v
142
- self.model.load_state_dict(new_state_dict, strict=True)
143
- self.model.to(self.device)
144
- self.model.eval()
145
-
146
- # Blossom LLM loading
147
- print("Blossom LLM loading...")
148
- self.llm_model_name = "Bllossom/llama-3.2-Korean-Bllossom-3B"
149
- self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
150
- self.llm_model = AutoModelForCausalLM.from_pretrained(
151
- self.llm_model_name,
152
- torch_dtype=torch.bfloat16,
153
- device_map="auto"
154
- )
155
- print("LLM loading complete!")
156
-
157
- self.label_names = ["normal", "offensive", "L1_hate", "L2_hate"]
158
- self.bio_names = {0: "O", 1: "B-SOFT", 2: "I-SOFT", 3: "B-HARD", 4: "I-HARD"}
159
-
160
- val_acc = checkpoint['val_acc'] if 'val_acc' in checkpoint else None
161
- if val_acc is not None:
162
- print(f"Model loaded - Validation accuracy: {val_acc:.2f}%")
163
- else:
164
- print("Model loaded - Validation accuracy: N/A")
165
-
166
- def detect_hate_speech(self, text, strategy="Detection Only"):
167
- """Hate Speech Detection and Mitigation"""
168
- if not text.strip():
169
- return "Please enter text", ""
170
- if len(text.strip()) < 2:
171
- return "Input text is too short. Please enter at least 2 characters.", ""
172
-
173
- if strategy == "Detection Only":
174
- result_msg, mitigation, debug_info = self._detection_only(text)
175
- print("[DEBUG] Input text:", text)
176
- print("[DEBUG] sentence_logits:", debug_info.get('sentence_logits'))
177
- print("[DEBUG] sentence_probs:", debug_info.get('sentence_probs'))
178
- print("[DEBUG] sentence_pred:", debug_info.get('sentence_pred'))
179
- print("[DEBUG] label:", debug_info.get('label'))
180
- print("[DEBUG] confidence:", debug_info.get('confidence'))
181
- return result_msg, mitigation
182
- elif strategy == "Guided":
183
- return self._guided_mitigation(text)
184
- elif strategy == "Guided+Reflect":
185
- return self._guided_reflect_mitigation(text)
186
- elif strategy == "Unguided":
187
- return self._unguided_mitigation(text)
188
- else:
189
- return "Invalid strategy", ""
190
-
191
- def _detection_only(self, text):
192
- """Perform only detection (existing logic)"""
193
- # Tokenization
194
- encoding = self.tokenizer(
195
- text,
196
- truncation=True,
197
- padding="max_length",
198
- max_length=128,
199
- return_attention_mask=True,
200
- return_tensors="pt"
201
- )
202
-
203
- input_ids = encoding["input_ids"].to(self.device)
204
- attention_mask = encoding["attention_mask"].to(self.device)
205
- print("[DEBUG] attention_mask[:, 0] =", attention_mask[:, 0])
206
-
207
- # Prediction
208
- with torch.no_grad():
209
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
210
- sentence_logits = outputs["sentence_logits"]
211
- bio_logits = outputs["bio_logits"]
212
-
213
- # Sentence classification result
214
- sentence_probs = torch.softmax(sentence_logits, dim=1)
215
- sentence_pred = torch.argmax(sentence_logits, dim=1).item()
216
- sentence_prob = sentence_probs[0][sentence_pred].item()
217
-
218
- # BIO tagging result
219
- bio_preds = torch.argmax(bio_logits, dim=2)[0]
220
-
221
- # Find hate/aggressive tokens
222
- hate_tokens = []
223
- tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
224
-
225
- # Tokenize original text to get offset mapping
226
- tokenized = self.tokenizer(
227
- text,
228
- truncation=True,
229
- padding="max_length",
230
- max_length=128,
231
- return_offsets_mapping=True
232
- )
233
- offset_mapping = tokenized["offset_mapping"]
234
-
235
- for j, (token, pred) in enumerate(zip(tokens, bio_preds)):
236
- if pred.item() != 0: # Not O
237
- # Extract the corresponding part from the original text using offset mapping
238
- if j < len(offset_mapping):
239
- start, end = offset_mapping[j]
240
- if start != end: # Token mapped to actual text
241
- original_text = text[start:end]
242
- hate_tokens.append((j, original_text, self.bio_names[pred.item()]))
243
- else:
244
- # Special token handling
245
- if token.startswith('Ġ'):
246
- decoded_token = token[1:] # Remove Ġ
247
- elif token in ['[CLS]', '[SEP]', '[PAD]', '[UNK]']:
248
- decoded_token = token
249
- else:
250
- decoded_token = token
251
- hate_tokens.append((j, decoded_token, self.bio_names[pred.item()]))
252
- else:
253
- # Fallback
254
- if token.startswith('Ġ'):
255
- decoded_token = token[1:]
256
- elif token in ['[CLS]', '[SEP]', '[PAD]', '[UNK]']:
257
- decoded_token = token
258
- else:
259
- decoded_token = token
260
- hate_tokens.append((j, decoded_token, self.bio_names[pred.item()]))
261
-
262
- # Determine label
263
- label = self.label_names[sentence_pred]
264
- # If hate_tokens contain B-HARD, I-HARD, increase label to L2_hate
265
- if any(bio_label in ["B-HARD", "I-HARD"] for _, _, bio_label in hate_tokens):
266
- label = "L2_hate"
267
- # Construct result message
268
- result_msg = f"Detection result: {label}\nConfidence: {sentence_prob:.2f}"
269
- if hate_tokens:
270
- result_msg += f"\nIdentified hate/aggressive expressions: {hate_tokens}"
271
- mitigation = "Performed only detection."
272
- debug_info = {
273
- 'sentence_logits': sentence_logits,
274
- 'sentence_probs': sentence_probs,
275
- 'sentence_pred': sentence_pred,
276
- 'label': label,
277
- 'confidence': sentence_prob,
278
- 'hate_tokens': hate_tokens
279
- }
280
- return result_msg, mitigation, debug_info
281
-
282
- def _unguided_mitigation(self, text):
283
- """Unguided Mode: Only Using Generation"""
284
- try:
285
- # Blossom LLM prompt
286
- prompt = f"""Please remove hate speech or aggressive expressions from the following sentence, while maintaining the original intent (criticism, complaint, opinion, etc.).
287
 
288
- Original: {text}
 
289
 
290
- Mitigated sentence:"""
291
-
292
- # LLM inference
293
- inputs = self.llm_tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
294
-
295
- with torch.no_grad():
296
- outputs = self.llm_model.generate(
297
- **inputs,
298
- do_sample=True,
299
- top_k=50,
300
- top_p=0.9,
301
- max_new_tokens=300,
302
- pad_token_id=self.llm_tokenizer.pad_token_id,
303
- eos_token_id=self.llm_tokenizer.eos_token_id
304
- )
305
-
306
- # Decode result
307
- full_response = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
308
-
309
- # Remove prompt part and extract mitigated sentence
310
- mitigated_text = full_response.replace(prompt, "").strip()
311
-
312
- # Handle truncated sentences
313
- if len(mitigated_text) < 10: # Too short, use original response
314
- mitigated_text = full_response
315
-
316
- # Prevent repetitive output: extract only the first mitigated sentence
317
- if "Mitigated sentence:" in mitigated_text:
318
- mitigated_text = mitigated_text.split("Mitigated sentence:")[-1].strip()
319
-
320
- # Use only the first meaningful line if multiple lines
321
- lines = mitigated_text.split('\n')
322
- clean_lines = []
323
- for line in lines:
324
- line = line.strip()
325
- if line and not line.startswith('**') and not line.startswith('Original:'):
326
- clean_lines.append(line)
327
-
328
- if clean_lines:
329
- mitigated_text = clean_lines[0]
330
-
331
- # Result message
332
- result_msg = f"🤖 **Blossom LLM Mitigation Result**\n\n"
333
- result_msg += f"**Original:** {text}\n\n"
334
- result_msg += f"**Mitigated Sentence:** {mitigated_text}"
335
-
336
- # Mitigation info
337
- mitigation = "**Unguided Mode:** Blossom LLM detected and mitigated harmful expressions autonomously."
338
-
339
- return result_msg, mitigation
340
-
341
- except Exception as e:
342
- error_msg = f"❌ **Blossom LLM Error**\n\nError occurred: {str(e)}"
343
- return error_msg, "An error occurred during LLM processing."
344
-
345
- def _guided_mitigation(self, text):
346
- """Guided Mode: Mitigate based on KcELECTRA detection result using Blossom LLM"""
347
- try:
348
- # First, perform detection with KcELECTRA
349
- detection_result, _, debug_info = self._detection_only(text)
350
- label = debug_info.get('label', 'normal')
351
- hate_tokens = debug_info.get('hate_tokens', [])
352
-
353
- # Construct Blossom LLM prompt
354
- if label == "normal":
355
- prompt = f"""The following sentence is classified as a normal sentence. Please improve it by expressing it more politely and respectfully, while maintaining the original intent.\n\nOriginal: {text}\n\nImproved sentence:"""
356
- else:
357
- label_desc = {
358
- "offensive": "Aggressive",
359
- "L1_hate": "Mild Hate",
360
- "L2_hate": "Severe Hate"
361
- }
362
- hate_tokens_str = ""
363
- if hate_tokens:
364
- hate_tokens_str = "\nExpressions causing issues:\n" + "\n".join([f"• {token} ({bio_label})" for _, token, bio_label in hate_tokens[:5]])
365
- prompt = f"""The following sentence is classified as {label_desc.get(label, "harmful")} expression. \nPlease remove hate speech or aggressive expressions, while maintaining the original intent (criticism, complaint, opinion, etc.).\n\nOriginal: {text}\nClassification: {label_desc.get(label, "harmful")} expression\n{hate_tokens_str}\n\n[Important] All offensive, derogatory, and explicit hate expressions (e.g., 씨발, 좆, 병신) must be deleted.\n\nMitigated sentence:"""
366
- # LLM inference
367
- inputs = self.llm_tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
368
- with torch.no_grad():
369
- outputs = self.llm_model.generate(
370
- **inputs,
371
- do_sample=True,
372
- top_k=50,
373
- top_p=0.9,
374
- max_new_tokens=300,
375
- pad_token_id=self.llm_tokenizer.pad_token_id,
376
- eos_token_id=self.llm_tokenizer.eos_token_id
377
- )
378
- full_response = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
379
- mitigated_text = full_response.replace(prompt, "").strip()
380
- if len(mitigated_text) < 10:
381
- mitigated_text = full_response
382
- if "Mitigated sentence:" in mitigated_text:
383
- mitigated_text = mitigated_text.split("Mitigated sentence:")[-1].strip()
384
- lines = mitigated_text.split('\n')
385
- clean_lines = []
386
- for line in lines:
387
- line = line.strip()
388
- if line and not line.startswith('**') and not line.startswith('Original:') and not line.startswith('Classification:'):
389
- clean_lines.append(line)
390
- if clean_lines:
391
- mitigated_text = clean_lines[0]
392
- result_msg = f"🎯 **Guided Mitigation Result**\n\n"
393
- result_msg += f"**KcELECTRA Detection Result:**\n{detection_result}\n\n"
394
- result_msg += f"**Blossom LLM Mitigation Result:**\n{mitigated_text}"
395
- mitigation = "**Guided Mode:** Blossom LLM performed specific mitigation based on KcELECTRA's detection information."
396
- return result_msg, mitigation
397
- except Exception as e:
398
- error_msg = f"❌ **Guided Mitigation Error**\n\nError occurred: {str(e)}"
399
- return error_msg, "An error occurred during guided mitigation processing."
400
-
401
- def _guided_reflect_mitigation(self, text):
402
- """Guided+Reflect Mode: iterative refinement + critic evaluation"""
403
- try:
404
- detection_result, _, debug_info = self._detection_only(text)
405
- label = debug_info.get('label', 'normal')
406
- hate_tokens = debug_info.get('hate_tokens', [])
407
- # Step 1: Initial mitigation
408
- if label == "normal":
409
- initial_prompt = f"""The following sentence is classified as a normal sentence. Please improve it by expressing it more politely and respectfully, while maintaining the original intent.\n\nOriginal: {text}\n\nImproved sentence:"""
410
- else:
411
- label_desc = {
412
- "offensive": "Aggressive",
413
- "L1_hate": "Mild Hate",
414
- "L2_hate": "Severe Hate"
415
- }
416
- hate_tokens_str = ""
417
- if hate_tokens:
418
- hate_tokens_str = "\nExpressions causing issues:\n" + "\n".join([f"• {token} ({bio_label})" for _, token, bio_label in hate_tokens[:5]])
418
- initial_prompt = f"""The following sentence is classified as {label_desc.get(label, "harmful")} expression. \nExpressions containing offensive words (e.g., 좃, 씨발, 병신) must be deleted.\nOther aggressive or inappropriate expressions should be mitigated by expressing them more politely and inclusively.\n\nOriginal: {text}\nClassification: {label_desc.get(label, "harmful")} expression\n{hate_tokens_str}\n\nMitigated sentence:"""
420
- # Iterative mitigation and evaluation
421
- max_iter = 5
422
- metrics_history = []
423
- best_candidate = None
424
- best_score = -float('inf')
425
- current_input = text
426
- for i in range(max_iter):
427
- # Generate candidate
428
- inputs = self.llm_tokenizer(initial_prompt, return_tensors="pt").to(self.llm_model.device)
429
- with torch.no_grad():
430
- outputs = self.llm_model.generate(
431
- **inputs,
432
- do_sample=True,
433
- top_k=50,
434
- top_p=0.9,
435
- max_new_tokens=300,
436
- pad_token_id=self.llm_tokenizer.pad_token_id,
437
- eos_token_id=self.llm_tokenizer.eos_token_id
438
- )
439
- candidate = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
440
- mitigated_text = candidate.replace(initial_prompt, "").strip()
441
- if len(mitigated_text) < 10:
442
- mitigated_text = candidate
443
- if "Mitigated sentence:" in mitigated_text:
444
- mitigated_text = mitigated_text.split("Mitigated sentence:")[-1].strip()
445
- lines = mitigated_text.split('\n')
446
- clean_lines = []
447
- for line in lines:
448
- line = line.strip()
449
- if line and not line.startswith('**') and not line.startswith('Original:') and not line.startswith('Classification:'):
450
- clean_lines.append(line)
451
- if clean_lines:
452
- mitigated_text = clean_lines[0]
453
- # Exclude candidates containing offensive words
454
- if contains_badword(mitigated_text):
455
- continue
456
- # Evaluation
457
- toxicity = calc_toxicity_reduction(text, mitigated_text, self.model, self.tokenizer)
458
- bertscore = calc_bertscore(text, mitigated_text)
459
- ppl = calc_ppl(mitigated_text)
460
- metrics_history.append({'iteration': i+1, 'candidate': mitigated_text, 'toxicity': toxicity, 'bertscore': bertscore, 'ppl': ppl})
461
- # Simple combined score (weight adjustment possible)
462
- total_score = toxicity + bertscore - ppl * 0.01
463
- if total_score > best_score:
464
- best_score = total_score
465
- best_candidate = mitigated_text
466
- # Early termination criteria (e.g., toxicity>0.3, bertscore>0.7, ppl<100)
467
- if toxicity > 0.3 and bertscore > 0.7 and ppl < 100:
468
- break
469
- # Log output
470
- iter_log_str = ""
471
- for log in metrics_history:
472
- iter_log_str += f"\nIteration {log['iteration']}:\n- Candidate: {log['candidate']}\n- Toxicity reduction: {log['toxicity']}, bertscore: {log['bertscore']}, ppl: {log['ppl']}"
473
- # Result message
474
- result_msg = f"🔄 **Guided+Reflect Mitigation Result**\n\n"
475
- result_msg += f"**Detection Result:**\n{detection_result}\n\n"
476
- result_msg += f"**Iterative Mitigation Log:**{iter_log_str}\n\n"
477
- result_msg += f"**Best Mitigation:** {best_candidate}"
478
- mitigation = "**Guided+Reflect Mode:** Selected the optimal candidate after iterative mitigation and evaluation (maximum 5 iterations)."
479
- return result_msg, mitigation
480
- except Exception as e:
481
- error_msg = f"❌ **Guided+Reflect Mitigation Error**\n\nError occurred: {str(e)}"
482
- return error_msg, "An error occurred during guided+reflect mitigation processing."
483
-
484
- def _suggest_mitigation(self, label, confidence, hate_tokens):
485
- """Suggest mitigation for hate speech expressions"""
486
- if label == "normal":
487
- return "✅ **Mitigation Suggestion**: This sentence does not require correction."
488
-
489
- mitigation = f"**🔧 Mitigation Suggestion for Hate Speech:**\n\n"
490
-
491
- if label == "offensive":
492
- mitigation += "**Aggressive Expression Mitigation Options:**\n"
493
- mitigation += "• Try to change aggressive expressions to more polite expressions\n"
494
- mitigation += "• Use objective expressions instead of emotional expressions\n"
495
- mitigation += "• Rephrase it with consideration for the reader\n"
496
- mitigation += "• When criticizing, provide specific and constructive feedback"
497
- elif label == "L1_hate":
498
- mitigation += "**Implicit Hate Expression Mitigation Options:**\n"
499
- mitigation += "• Remove expressions that discriminate or show prejudice\n"
500
- mitigation += "• Avoid generalizing about specific groups\n"
501
- mitigation += "• Use more inclusive and respectful expressions\n"
502
- mitigation += "• Change to expressions that acknowledge diversity"
503
- else: # L2_hate
504
- mitigation += "**Explicit Hate Expression Mitigation Options:**\n"
505
- mitigation += "• Completely remove severe hate expressions\n"
506
- mitigation += "• Do not use violent or threatening expressions\n"
507
- mitigation += "• Use expressions that respect everyone's dignity\n"
508
- mitigation += "• Remove expressions that discriminate or promote hate\n"
509
- mitigation += "• If necessary, seek professional help"
510
-
511
- return mitigation
512
 
513
- def contains_badword(text):
514
- badwords = ["좃", "씨발", "병신", "개새끼", "염병", "좆", "ㅅㅂ", "ㅄ", "ㅂㅅ", "ㅗ", "ㅉ"]
515
- return any(bad in text for bad in badwords)
 
516
 
517
- # Service initialization
518
- service = HateSpeechDetectorService()
 
 
 
519
 
520
- # Gradio interface
521
- def create_demo():
522
- with gr.Blocks(
523
- title="Korean Hate Speech Detection and Mitigation System",
524
- theme=gr.themes.Soft(),
525
- css="""
526
- .gradio-container {
527
- max-width: 800px;
528
- margin: 0 auto;
529
- }
530
- .result-box {
531
- border-radius: 10px;
532
- padding: 15px;
533
- margin: 10px 0;
534
- }
535
- .normal { background-color: #d4edda; border: 1px solid #c3e6cb; }
536
- .offensive { background-color: #fff3cd; border: 1px solid #ffeaa7; }
537
- .hate { background-color: #f8d7da; border: 1px solid #f5c6cb; }
538
- """
539
- ) as demo:
540
- gr.Markdown("""
541
- # Korean Hate Speech Detection and Mitigation System
542
-
543
- This system detects hate speech in Korean text and provides mitigation suggestions.
544
-
545
-
546
- **๐ŸŸข Normal**:
547
- - It is a normal sentence.
548
-
549
- **๐ŸŸก Offensive**
550
-
551
- - For example: "Don't say such a stupid thing", "How can you do such a stupid thing"
552
-
553
- **๐ŸŸ  L1_hate (Implicit Hate)**: Mild hate expression
554
- - **Implicit hate expression** for protected attribute groups
555
- - For example: "Those people are all the same", "Prejudicial expression towards a specific group"
556
-
557
- **๐Ÿ”ด L2_hate (Explicit Hate)**: Severe hate expression
558
- - **Explicit hate expression** for protected attribute groups
559
-
560
- **๐Ÿค– Mitigation Mode:**
561
- - ๐Ÿ” **Detection Only**: Hate Speech Detection Only
562
- - ๐ŸŽฏ **Guided**: Guided Mitigation
563
- - ๐Ÿ”„ **Guided+Reflect**: After Guided Mitigation, Iterative Refinement
564
- - ๐Ÿค– **Unguided**: LLM generates text without any guidance
565
- """)
566
-
567
- with gr.Row():
568
- with gr.Column(scale=2):
569
- input_text = gr.Textbox(
570
- label="Enter text to detect hate speech & mitigate",
571
- lines=3
572
- )
573
-
574
- strategy = gr.Radio(
575
- ["Detection Only", "Guided", "Guided+Reflect", "Unguided"],
576
- value="Detection Only",
577
- label="Select Mitigation Mode",
578
- container=True
579
- )
580
-
581
- analyze_btn = gr.Button("🔍 Detect & Mitigate", variant="primary", size="lg")
582
-
583
- # with gr.Column(scale=1):
584
- # gr.Markdown("""
585
- # **🧪 Test Examples:**
586
-
587
- # **🟢 Normal:**
588
- # - "Hello, today's weather is nice."
589
- # - "This movie was really fun."
590
-
591
- # **🟡 Offensive:**
592
- # - "How can you do such a stupid thing"
593
- # - "Don't say such a stupid thing"
594
-
595
- # **🟠 L1_hate (Implicit):**
596
- # - "Those people are all the same"
597
- # - "Prejudicial expression towards a specific group"
598
-
599
- # **🔴 L2_hate (Explicit):**
600
- # - "All women are useless"
601
- # - "People with disabilities are a burden to society"
602
- # """)
603
-
604
- with gr.Row():
605
- with gr.Column():
606
- result_output = gr.Markdown(
607
- label="Mitigation Button",
608
- value="Input text and click the above button."
609
- )
610
-
611
- with gr.Column():
612
- mitigation_output = gr.Markdown(
613
- label="Mitigation Suggestion",
614
- value="Based on the analysis result, mitigation suggestions will be provided."
615
- )
616
-
617
- # Event handlers
618
- analyze_btn.click(
619
- fn=service.detect_hate_speech,
620
- inputs=[input_text, strategy],
621
- outputs=[result_output, mitigation_output]
622
- )
623
-
624
- # Allow analysis via Enter key
625
- input_text.submit(
626
- fn=service.detect_hate_speech,
627
- inputs=[input_text, strategy],
628
- outputs=[result_output, mitigation_output]
629
- )
630
-
631
- # gr.Markdown("""
632
- # ---
633
- # **Model Information:**
634
- # - Detection Model: KcELECTRA-base (Validation Accuracy: 67.67%)
635
- # - Mitigation Model: Blossom LLM (llama-3.2-Korean-Bllossom-3B)
636
- # - Training Data: K-HATERS Dataset
637
- # """)
638
-
639
- return demo
640
 
641
- if __name__ == "__main__":
642
- demo = create_demo()
643
- demo.launch(show_error=True)
 
 
1
  import gradio as gr
2
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from detector import detect_spans, KHatersModelCRF
6
+ from generator import critic_score, mitigate_with_strategy
7
+ from bert_score import score
8
+ from huggingface_hub import hf_hub_download  # added
9
+
11
  import torch.nn as nn
12
  from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
13
  import numpy as np
 
16
 
17
  from bert_score import score as bert_score_fn
18
  import re
 
19
 
21
+ # --- Placeholder functions for score calculation ---
22
+ # 1. BERTScore (semantic similarity)
23
  def calc_bertscore(orig_text, rewritten_text):
24
  P, R, F1 = bert_score_fn([rewritten_text], [orig_text], lang="ko")
25
  return round(F1[0].item(), 3)
26
 
27
+ # 2. PPL (perplexity, linguistic naturalness)
28
+ # Uses a Korean causal LM such as KoGPT2 (model/tokenizer are loaded only once)
29
+ _kogpt2_tokenizer = None
30
+ _kogpt2_model = None
31
+
32
  def calc_ppl(text):
33
+ global _kogpt2_tokenizer, _kogpt2_model
34
+ if _kogpt2_tokenizer is None or _kogpt2_model is None:
35
+ _kogpt2_tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2-base-v2")
36
+ _kogpt2_model = AutoModelForCausalLM.from_pretrained("skt/kogpt2-base-v2")
37
+ _kogpt2_model.eval()
38
+ encodings = _kogpt2_tokenizer(text, return_tensors="pt")
39
+ input_ids = encodings.input_ids
40
+ with torch.no_grad():
41
+ outputs = _kogpt2_model(input_ids, labels=input_ids)
42
+ loss = outputs.loss
43
+ ppl = torch.exp(loss).item()
44
+ return round(ppl, 3)
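For reference, the perplexity returned here is the exponential of the mean token-level negative log-likelihood that the Hugging Face causal-LM loss computes:

```latex
\mathrm{PPL}(x) = \exp\!\left(-\frac{1}{N}\sum_{t=1}^{N}\log p_\theta\!\left(x_t \mid x_{<t}\right)\right)
```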
 
45
 
46
+ # 3. ΔTox (toxicity reduction)
47
  def calc_toxicity_reduction(orig_text, rewritten_text, detector_model, detector_tokenizer):
48
+ # Evaluate the "normal" softmax score of each text with the critic_score function
49
+ orig_score = critic_score(orig_text, detector_model, detector_tokenizer)
50
+ rewritten_score = critic_score(rewritten_text, detector_model, detector_tokenizer)
51
+ delta = orig_score - rewritten_score
52
+ return round(delta, 3)
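`critic_score` comes from `generator.py`, which is not shown in this commit. Judging from the comment above and from how the previous in-file detector was scored, a plausible sketch (an assumption, not the repository's actual implementation) is the softmax probability the detector assigns to the `normal` class:

```python
# Hypothetical critic_score sketch; the real implementation lives in generator.py.
# Assumption: the detector returns a dict with "sentence_logits" whose class index 0 is
# "normal", matching the HateSpeechDetector that this commit removes.
import torch

def critic_score(text, detector_model, detector_tokenizer):
    enc = detector_tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    device = next(detector_model.parameters()).device
    with torch.no_grad():
        out = detector_model(input_ids=enc["input_ids"].to(device),
                             attention_mask=enc["attention_mask"].to(device))
    probs = torch.softmax(out["sentence_logits"][0], dim=-1)
    return probs[0].item()  # higher = more likely "normal", i.e. less toxic
```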
 
53
 
54
+ # Prepare the detector
55
+ base_model_name = "beomi/KcELECTRA-base"
56
+ detector_tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)
57
+ num_labels, num_bio_labels, num_targets = 4, 5, 9  # ← fixed!
+ detector_model = KHatersModelCRF(base_model_name, num_labels, num_bio_labels, num_targets)
+ # ckpt = torch.load("/root/PROJECT-ROOT/backend/kcelectra_crf_ckpt/best_model.pt", map_location="cpu")
+ ckpt_path = hf_hub_download(repo_id="alohaboy/hate_detector_ko", filename="best_model.pt", use_auth_token=True)
+ ckpt = torch.load(ckpt_path, map_location="cpu")
68
+ if "model_state_dict" in ckpt:
69
+ state_dict = ckpt["model_state_dict"]
70
+ else:
71
+ state_dict = ckpt
72
+ # state_dict key conversion patch (old version → new version)
73
+ crf_key_map = {
74
+ "crf.trans_matrix": "crf.transitions",
75
+ "crf.start_trans": "crf.start_transitions",
76
+ "crf.end_trans": "crf.end_transitions",
77
+ }
78
+ for old_key, new_key in crf_key_map.items():
79
+ if old_key in state_dict:
80
+ state_dict[new_key] = state_dict.pop(old_key)
81
+ # Delete old-version keys if any remain
+ for k in ["crf.trans_matrix", "crf.start_trans", "crf.end_trans"]:
88
+ if k in state_dict:
89
+ del state_dict[k]
90
+ detector_model.load_state_dict(state_dict)
91
+ # detector_model.to("cuda")  # move to GPU (removed)
92
+ detector_model.eval()
93
 
94
+ # Prepare the LLM
95
+ LLM_MODEL_NAME = "Bllossom/llama-3.2-Korean-Bllossom-3B"
96
+ llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, use_auth_token=True)
97
+ llm_model = AutoModelForCausalLM.from_pretrained(
98
+ LLM_MODEL_NAME,
99
+ torch_dtype=torch.float32, # runs on CPU
100
+ device_map="cpu", # use the CPU
101
+ use_auth_token=True
102
+ )
103
+ llm_model.to("cpu")
104
+ print("llm_model device:", llm_model.device)
 
105
 
106
+ def phrase_replacement(text, spans):
107
+ # Mitigate only the detected hate spans (example: replace them with "[순화]")
108
+ new_text = text
109
+ for span in spans:
110
+ new_text = new_text.replace(span['text'], "[순화]")
111
+ return new_text
112
+
113
+ def unguided_rewrite(text):
114
+ # Simple paraphrase via the LLM, without an explicit "mitigation" prompt (example)
115
+ prompt = f"다음 문장을 더 부드럽게 바꿔주세요: {text}"
116
+ inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
117
+ gen_ids = llm_model.generate(
118
+ **inputs,
119
+ do_sample=True,
120
+ top_k=50,
121
+ top_p=0.9,
122
+ num_return_sequences=1,
123
+ max_new_tokens=64,
124
+ pad_token_id=llm_tokenizer.pad_token_id,
125
+ eos_token_id=llm_tokenizer.eos_token_id
126
+ )
127
+ return llm_tokenizer.decode(gen_ids[0], skip_special_tokens=True)
 
128
 
129
+ def hate_speech_demo(text, strategy):
130
+ det = detect_spans(text)
131
+ detector_result = {
132
+ "샘플문장": text,
133
+ "유해카테고리": det["category"],
134
+ "탐지스팬들": det["spans"],
135
+ "타겟예측": det["targets"]
136
+ }
137
+ # Strategy-specific mitigation
138
+ if det["category"] == "normal":
139
+ output = text
140
+ candidate_info = "-"
141
+ elif strategy == "Full Rewrite":
142
+ num_return_sequences = 5
143
+ current_text = text
144
+ best_candidate = current_text
145
+ best_score = 0.0
146
+ all_candidates = []
147
+ all_scores = []
148
+ for _ in range(3): # max_iter=3
149
+ prompt = (
150
+ f"문장을 순화해주세요.\n\n"
151
+ f"- 원문: {current_text}\n"
152
+ f"- 혐오 표현: {[span['text'] for span in det['spans']]}\n"
153
+ f"- 혐오 유형: {det['category']}\n"
154
+ f"- 타겟 그룹: {det['targets']}\n\n"
155
+ f"위 내용을 고려하여, 혐오 표현을 더 평화로운 표현으로 순화해주세요."
156
+ )
157
+ inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
158
+ gen_ids = llm_model.generate(
159
+ **inputs,
160
+ do_sample=True,
161
+ top_k=50,
162
+ top_p=0.9,
163
+ num_return_sequences=num_return_sequences,
164
+ max_new_tokens=64,
165
+ pad_token_id=llm_tokenizer.pad_token_id,
166
+ eos_token_id=llm_tokenizer.eos_token_id
167
+ )
168
+ candidates = [llm_tokenizer.decode(g, skip_special_tokens=True) for g in gen_ids]
169
+ scores = [critic_score(c, detector_model, detector_tokenizer) for c in candidates]
170
+ all_candidates.extend(candidates)
171
+ all_scores.extend(scores)
172
+ best_idx = int(torch.tensor(scores).argmax())
173
+ if scores[best_idx] >= 0.7:
174
+ best_candidate = candidates[best_idx]
175
+ best_score = scores[best_idx]
176
+ break
177
+ best_candidate = candidates[best_idx]
178
+ best_score = scores[best_idx]
179
+ current_text = best_candidate
180
+ output = best_candidate
181
+ candidate_info = "\n".join([f"[{i+1}] {c} (score={s:.3f})" for i, (c, s) in enumerate(zip(all_candidates, all_scores))])
182
+ elif strategy == "Phrase Replacement":
183
+ output = phrase_replacement(text, det["spans"])
184
+ candidate_info = "-"
185
+ elif strategy == "Unguided":
186
+ output = unguided_rewrite(text)
187
+ candidate_info = "-"
188
+ else:
189
+ output = text
190
+ candidate_info = "-"
191
+ # Score calculation
192
+ tox_score = calc_toxicity_reduction(text, output, detector_model, detector_tokenizer)
193
+ ppl_score = calc_ppl(output)
194
+ bert_score = calc_bertscore(text, output)
195
+ soft_or_hard = det.get("soft_or_hard", "-")
196
+ return (
197
+ det["category"],
198
+ str(det["spans"]),
199
+ str(det["targets"]),
200
+ output,
201
+ f"Toxicity Reduction: {tox_score}",
202
+ f"PPL: {ppl_score}",
203
+ f"BERTScore: {bert_score}",
204
+ candidate_info,
205
+ soft_or_hard
206
+ )
207
 
208
+ with gr.Blocks(theme=gr.themes.Monochrome(primary_hue="blue", secondary_hue="slate")) as demo:
209
+ gr.HTML("""
210
+ <style>
211
+ .modern-input textarea { font-size: 1.1em; border-radius: 8px; }
212
+ .modern-btn { background: linear-gradient(90deg,#2563eb,#60a5fa); color: white; border-radius: 8px; font-weight: bold; }
213
+ .modern-label { font-size: 1.05em; font-weight: 600; color: #2563eb; }
214
+ .modern-badge { background: #e0e7ff; color: #1e293b; border-radius: 6px; padding: 0.2em 0.6em; font-weight: 500; }
215
+ .modern-output textarea { background: #f1f5f9; border-radius: 8px; font-size: 1.1em; }
216
+ .modern-score input { color: #0ea5e9; font-weight: bold; font-size: 1.1em; }
217
+ .modern-candidate textarea { background: #f8fafc; border-radius: 8px; }
218
+ </style>
219
+ """)
220
+ gr.Markdown("""
221
+ <div style='text-align:center; margin-bottom: 1em;'>
222
+ <h1 style='color:#2563eb;'>Hate Speech Mitigation Demo</h1>
223
+ <p style='font-size:1.1em;'>Enter a sentence and select a mitigation strategy to see the results and various scores.</p>
224
+ </div>
225
+ """)
226
+ with gr.Row():
227
+ lang = gr.Radio(["Korean", "English"], value="Korean", label="Language", container=True)
228
+ input_box = gr.Textbox(label="Input text", lines=2, value="야, 병신아. 이런 것도 못해?", elem_classes="modern-input")
229
+ strategy = gr.Radio([
230
+ "Full Rewrite",
231
+ "Phrase Replacement",
232
+ "Unguided"
233
+ ], value="Full Rewrite", label="Mitigation Strategy", container=True)
234
+ run_btn = gr.Button("Non-Toxic", variant="primary", elem_classes="modern-btn")
235
+ with gr.Row():
236
+ out1 = gr.Label(label="Hate type", elem_classes="modern-label")
237
+ out2 = gr.HighlightedText(label="Detected spans", elem_classes="modern-badge")
238
+ out3 = gr.Label(label="Target", elem_classes="modern-label")
239
+ out5 = gr.Number(label="Toxicity Reduction", elem_classes="modern-score")
240
+ out6 = gr.Number(label="PPL", elem_classes="modern-score")
241
+ out7 = gr.Number(label="BERTScore", elem_classes="modern-score")
242
+ with gr.Accordion("Candidate Info", open=False):
243
+ out8 = gr.Textbox(label="Candidates", lines=3, elem_classes="modern-candidate")
244
+ out4 = gr.Textbox(label="Output", lines=2, elem_classes="modern-output")
 
245
 
246
+ def hate_speech_multilingual(text, strategy, lang):
247
+ if lang == "Korean":
248
+ res = hate_speech_demo(text, strategy)[:8]  # excluding soft/hard
+ return (res[0], res[1], res[2], res[4], res[5], res[6], res[7], res[3])  # reorder to match the click() outputs
249
+ else:
250
+ # Prepare the English hate speech detection pipeline
251
+ from english_detector import detect_spans as detect_spans_en, EnglishElectraHateDetector
252
+ from transformers import AutoTokenizer
253
+ import torch
254
+ # Build label mappings
255
+ tag2id = {"O": 0, "B-HATE": 1, "I-HATE": 2, "B-OFF": 3, "I-OFF": 4}
256
+ id2tag = {v: k for k, v in tag2id.items()}
257
+ sev2id = {"NORMAL": 0, "O1": 1, "O2": 2, "H1": 3, "H2": 4}
258
+ id2sev = {v: k for k, v in sev2id.items()}
259
+ # Example categories (the full HateXplain category set should be reflected in practice)
260
+ cat_list = ["African", "Arab", "Asian", "Atheist", "Buddhist", "Christian", "Female", "Hispanic", "Homosexual_gay_or_lesbian", "Immigrant", "Jewish", "Male", "Other_religions", "Physical_disability", "Transgender", "Other"]
261
+ cat2id = {c: i for i, c in enumerate(cat_list)}
262
+ id2cat = {v: k for k, v in cat2id.items()}
263
+ # Prepare the model/tokenizer
264
+ base_model_name = "google/electra-base-discriminator"
265
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
266
+ model = EnglishElectraHateDetector(base_model_name, num_severity=5, num_bio_labels=5, num_targets=len(cat2id))
267
+ ckpt = torch.load("/root/PROJECT-ROOT/backend/english_detector_ckpt/best_model.pt", map_location="cuda")
268
+ model.load_state_dict(ckpt)
269
+ model.eval()
270
+ model.to("cuda") # only the detection model is placed on the GPU
271
+ # Run detection
272
+ det = detect_spans_en(text, model, tokenizer, tag2id, id2tag, sev2id, id2sev, cat2id, id2cat, device="cuda")
273
+ # Convert for HighlightedText: [(text, class)]
274
+ spans_for_highlight = [(span["text"], span["label"]) for span in det["spans"] if span.get("text")]
275
+ # Convert targets
276
+ targets_str = ", ".join(det["targets"]) if det["targets"] else "-"
277
+ # Strategy-specific mitigation (here the original text is returned as-is; hook up an English generator if needed)
278
+ output = text
279
+ candidate_info = "-"
280
+ # Score placeholders
281
+ tox_score = "-"
282
+ ppl_score = "-"
283
+ bert_score = "-"
284
+ soft_or_hard = det.get("soft_or_hard", "-")
285
+ return (
286
+ det["category"] if "category" in det else det.get("severity", "-"),
287
+ spans_for_highlight,
288
+ targets_str,
289
+ tox_score,
290
+ ppl_score,
291
+ bert_score,
292
+ candidate_info,
293
+ output
294
+ )
295
 
296
+ run_btn.click(
297
+ hate_speech_multilingual,
298
+ inputs=[input_box, strategy, lang],
299
+ outputs=[out1, out2, out3, out5, out6, out7, out8, out4]
300
+ )
301
 
302
+ demo.launch(share=True)
 
303
 
304
+ # Debug prints (input_ids, attention_mask, decode_mask, bio_feats are not defined at module scope):
+ # print("input_ids:", input_ids)
305
+ # print("attention_mask:", attention_mask)
306
+ # print("decode_mask:", decode_mask)
307
+ # print("bio_feats.shape:", bio_feats.shape)