alohaboy committed on
Commit
f09939b
·
0 Parent(s):

최소 파일만 포함한 완전 클린 Space push

Browse files
Files changed (3) hide show
  1. README.md +14 -0
  2. app.py +648 -0
  3. requirements.txt +9 -0
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Korean Hate Speech Mitigation Demo
3
+ emoji: "🛡️"
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: "4.44.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Korean Hate Speech Mitigation Demo
13
+
14
+ 이 Space는 한국어 혐오 표현 탐지 및 순화 데모입니다.
app.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
5
+ import numpy as np
6
+ from datetime import datetime
7
+ from TorchCRF import CRF
8
+
9
+ from bert_score import score as bert_score_fn
10
+ import re
11
+ from huggingface_hub import hf_hub_download
12
+
13
def calc_bertscore(orig_text, rewritten_text):
    """Return the BERTScore F1 between a rewrite and its original sentence.

    The rewritten text is passed as the candidate and the original text as
    the reference, scored with ``lang="ko"``.

    Args:
        orig_text: The original (reference) sentence.
        rewritten_text: The rewritten (candidate) sentence.

    Returns:
        The F1 component of the score, rounded to three decimals.
    """
    candidates = [rewritten_text]
    references = [orig_text]
    # Precision and recall are returned as well but only F1 is reported.
    _precision, _recall, f1 = bert_score_fn(candidates, references, lang="ko")
    return round(f1[0].item(), 3)
16
+
17
def calc_ppl(text):
    """Return a heuristic pseudo-perplexity for *text*.

    No language model is involved: the value is derived only from the number
    of whitespace-separated tokens and the share of distinct tokens.

    Args:
        text: The sentence to score.

    Returns:
        ``1.0`` when the text has fewer than two tokens or when any error
        occurs; otherwise ``50 * length_factor * complexity_factor`` rounded
        to three decimals, where ``length_factor`` is ``tokens / 10`` capped
        at 2 and ``complexity_factor`` is ``1 + 0.5 * distinct / tokens``.
    """
    try:
        words = text.split()
        n_words = len(words)
        if n_words < 2:
            return 1.0
        # Grows linearly with length and saturates at 20 tokens.
        length_factor = min(n_words / 10.0, 2.0)
        # Ranges from just above 1.0 (all repeats) to 1.5 (all distinct).
        distinct_ratio = len(set(words)) / n_words
        complexity_factor = 1.0 + distinct_ratio * 0.5
        return round(50.0 * length_factor * complexity_factor, 3)
    except Exception as e:
        # Best-effort metric: report the problem and fall back to the floor.
        print(f"PPL calculation error: {e}")
        return 1.0
31
+
32
def _toxicity_score(text, detector_model, detector_tokenizer, device):
    """Return ``1 - P(class 0)`` for *text* from the detector's sentence head."""
    enc = detector_tokenizer(
        text,
        return_tensors="pt",
        # Truncate as well as pad so the input never exceeds 128 tokens,
        # matching how the detection path tokenises its input.
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    with torch.no_grad():
        out = detector_model(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
        )
    probs = torch.softmax(out["sentence_logits"][0], dim=-1)
    return 1.0 - probs[0].item()


def calc_toxicity_reduction(orig_text, rewritten_text, detector_model, detector_tokenizer):
    """Return how much the detector's toxicity score dropped after rewriting.

    Toxicity is measured as one minus the softmax probability of sentence
    class 0 (the service's label list puts "normal" first).

    Args:
        orig_text: The original sentence.
        rewritten_text: The rewritten sentence.
        detector_model: Module whose output mapping has a
            ``"sentence_logits"`` entry; its first parameter's device is used.
        detector_tokenizer: Tokenizer matching ``detector_model``.

    Returns:
        ``toxicity(orig) - toxicity(rewritten)`` rounded to three decimals
        (positive means the rewrite is less toxic), or ``0.0`` if scoring
        fails for any reason.
    """
    try:
        device = next(detector_model.parameters()).device
        before = _toxicity_score(orig_text, detector_model, detector_tokenizer, device)
        after = _toxicity_score(rewritten_text, detector_model, detector_tokenizer, device)
        return round(before - after, 3)
    except Exception as e:
        # Best-effort metric: never let scoring break the mitigation loop.
        print(f"Toxicity reduction calculation error: {e}")
        return 0.0
58
+
59
class HateSpeechDetector(nn.Module):
    """Shared-encoder model with sentence, BIO-span and target heads.

    A pretrained encoder feeds three linear heads: a sentence classifier and
    a target classifier on the first-token representation, and a per-token
    BIO projection whose scores are passed through a CRF layer.
    """

    def __init__(self, model_name="beomi/KcELECTRA-base", num_sentence_labels=4, num_bio_labels=5, num_targets=9):
        """Build the encoder and the three heads.

        Args:
            model_name: Hugging Face model id of the encoder.
            num_sentence_labels: Output size of the sentence classifier.
            num_bio_labels: Number of BIO tags (also the CRF label count).
            num_targets: Output size of the target head.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name, config=self.config)
        hidden_size = self.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_sentence_labels)  # Sentence classification
        self.bio_linear = nn.Linear(hidden_size, num_bio_labels)  # BIO tagging
        self.crf = CRF(num_bio_labels)
        self.target_head = nn.Linear(hidden_size, num_targets)  # Target classification

    def forward(self, input_ids, attention_mask, bio_tags=None, sentence_labels=None, targets=None):
        """Run the encoder and all heads; add up losses for supplied labels.

        Args:
            input_ids: Token ids for the encoder.
            attention_mask: Attention mask for the encoder; also used as the
                CRF decode mask when ``bio_tags`` is not given.
            bio_tags: Optional gold BIO tags; positions equal to ``-100`` are
                masked out of the CRF.
            sentence_labels: Optional gold sentence labels (cross-entropy).
            targets: Optional gold target vector (BCE-with-logits, weight 2).

        Returns:
            Dict with ``loss`` (``0.0`` when no labels were given),
            ``sentence_logits``, ``bio_logits`` (pre-CRF emission scores),
            ``bio_preds`` (CRF Viterbi decoding) and ``target_logits``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        # The first token's hidden state serves as the sentence representation.
        pooled_output = sequence_output[:, 0, :]

        sentence_logits = self.classifier(self.dropout(pooled_output))
        bio_feats = self.bio_linear(sequence_output)

        # With gold tags the CRF mask follows the labels; otherwise it
        # follows the attention mask.
        if bio_tags is not None:
            crf_mask = bio_tags != -100
            bio_loss = -self.crf.forward(bio_feats, bio_tags, mask=crf_mask)
        else:
            crf_mask = attention_mask.bool()
            bio_loss = None

        # A second, independent dropout draw is applied for the target head.
        target_logits = self.target_head(self.dropout(pooled_output))

        loss = 0.0
        if sentence_labels is not None:
            loss += nn.CrossEntropyLoss()(sentence_logits, sentence_labels)
        if bio_loss is not None:
            loss += bio_loss.sum()
        if targets is not None:
            loss += 2.0 * nn.BCEWithLogitsLoss()(target_logits, targets)

        bio_preds = self.crf.viterbi_decode(bio_feats, mask=crf_mask)
        return {
            'loss': loss,
            'sentence_logits': sentence_logits,
            'bio_logits': bio_feats,
            'bio_preds': bio_preds,
            'target_logits': target_logits
        }
112
+
113
class HateSpeechDetectorService:
    """Detection and LLM-based mitigation service behind the demo UI.

    Loads the KcELECTRA-based detector and the Bllossom causal LM once, then
    serves four strategies through :meth:`detect_hate_speech`:
    "Detection Only", "Guided", "Guided+Reflect" and "Unguided".
    """

    # Human-readable names for the non-normal sentence labels used in prompts.
    _LABEL_DESC = {
        "offensive": "Aggressive",
        "L1_hate": "Mild Hate",
        "L2_hate": "Severe Hate",
    }
    # Generated lines starting with these are prompt echoes or markup.
    _SKIPPED_LINE_PREFIXES = ("**", "Original:", "Classification:")
    # Answer markers the prompts end with; stripped if the model repeats them.
    _ANSWER_MARKERS = ("Mitigated sentence:", "Improved sentence:")

    def __init__(self):
        """Download and load the detector checkpoint and the LLM."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base")
        self.model = HateSpeechDetector()

        # Model loading
        MODEL_CKPT_PATH = hf_hub_download(repo_id="alohaboy/hate_detector_ko", filename="best_model.pt")
        # NOTE(review): torch.load unpickles the downloaded file; only use
        # checkpoints from a trusted repository.
        checkpoint = torch.load(MODEL_CKPT_PATH, map_location=self.device)
        # The checkpoint uses different parameter names than this module.
        key_map = {
            'sentence_classifier.weight': 'classifier.weight',
            'sentence_classifier.bias': 'classifier.bias',
            'bio_classifier.weight': 'bio_linear.weight',
            'bio_classifier.bias': 'bio_linear.bias',
            # CRF related keys (reverse)
            'crf.transitions': 'crf.trans_matrix',
            'crf.start_transitions': 'crf.start_trans',
            'crf.end_transitions': 'crf.end_trans',
        }
        # Accept either a bare state dict or a dict wrapping one.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        new_state_dict = {key_map.get(k, k): v for k, v in state_dict.items()}
        self.model.load_state_dict(new_state_dict, strict=True)
        self.model.to(self.device)
        self.model.eval()

        # Blossom LLM loading
        print("Blossom LLM loading...")
        self.llm_model_name = "Bllossom/llama-3.2-Korean-Bllossom-3B"
        self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            self.llm_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        print("LLM loading complete!")

        self.label_names = ["normal", "offensive", "L1_hate", "L2_hate"]
        self.bio_names = {0: "O", 1: "B-SOFT", 2: "I-SOFT", 3: "B-HARD", 4: "I-HARD"}

        val_acc = checkpoint.get('val_acc') if isinstance(checkpoint, dict) else None
        if val_acc is not None:
            print(f"Model loaded - Validation accuracy: {val_acc:.2f}%")
        else:
            print("Model loaded - Validation accuracy: N/A")

    def detect_hate_speech(self, text, strategy="Detection Only"):
        """Run the selected strategy on *text*.

        Args:
            text: User input.
            strategy: One of "Detection Only", "Guided", "Guided+Reflect",
                "Unguided".

        Returns:
            ``(result_message, mitigation_message)``. Empty or too-short
            input and unknown strategies yield an explanatory message and an
            empty second element.
        """
        if not text or not text.strip():
            return "Please enter text", ""
        if len(text.strip()) < 2:
            return "Input text is too short. Please enter at least 2 characters.", ""

        if strategy == "Detection Only":
            result_msg, mitigation, _ = self._detection_only(text)
            return result_msg, mitigation

        handlers = {
            "Guided": self._guided_mitigation,
            "Guided+Reflect": self._guided_reflect_mitigation,
            "Unguided": self._unguided_mitigation,
        }
        handler = handlers.get(strategy)
        if handler is None:
            return "Invalid strategy", ""
        return handler(text)

    def _detection_only(self, text):
        """Classify *text* and list the spans tagged as hateful/aggressive.

        Returns:
            ``(result_message, mitigation_message, debug_info)`` where
            ``debug_info`` holds ``sentence_logits``, ``sentence_probs``,
            ``sentence_pred``, ``label``, ``confidence`` and ``hate_tokens``
            (a list of ``(token_index, source_text, bio_label)``).
        """
        # One tokenizer pass yields both the tensors and the character offsets.
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=128,
            return_attention_mask=True,
            return_offsets_mapping=True,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        offsets = encoding["offset_mapping"][0].tolist()

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        sentence_logits = outputs["sentence_logits"]
        bio_logits = outputs["bio_logits"]

        # Sentence classification result
        sentence_probs = torch.softmax(sentence_logits, dim=1)
        sentence_pred = torch.argmax(sentence_logits, dim=1).item()
        sentence_prob = sentence_probs[0][sentence_pred].item()

        # BIO tagging result (argmax over the per-token scores).
        # NOTE(review): the model also returns CRF-decoded "bio_preds";
        # confirm whether those should be preferred here.
        tag_ids = torch.argmax(bio_logits, dim=2)[0].tolist()
        is_real = attention_mask[0].tolist()

        hate_tokens = []
        for idx, (tag_id, real, (start, end)) in enumerate(zip(tag_ids, is_real, offsets)):
            # Skip "O" tags, padding, and tokens with an empty source span
            # (special tokens): their tags carry no information about the
            # text and must not be reported or escalate the label.
            if tag_id == 0 or not real or start == end:
                continue
            hate_tokens.append((idx, text[start:end], self.bio_names[tag_id]))

        label = self.label_names[sentence_pred]
        # Any HARD span escalates the sentence label to L2_hate.
        if any(bio_label in ("B-HARD", "I-HARD") for _, _, bio_label in hate_tokens):
            label = "L2_hate"

        result_msg = f"Detection result: {label}\nConfidence: {sentence_prob:.2f}"
        if hate_tokens:
            result_msg += f"\nIdentified hate/aggressive expressions: {hate_tokens}"
        mitigation = "Performed only detection."
        debug_info = {
            'sentence_logits': sentence_logits,
            'sentence_probs': sentence_probs,
            'sentence_pred': sentence_pred,
            'label': label,
            'confidence': sentence_prob,
            'hate_tokens': hate_tokens
        }
        return result_msg, mitigation, debug_info

    def _generate(self, prompt):
        """Sample a completion for *prompt* and return only the new text."""
        inputs = self.llm_tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
        pad_id = self.llm_tokenizer.pad_token_id
        if pad_id is None:
            # Tokenizers without a pad token: pad with EOS instead.
            pad_id = self.llm_tokenizer.eos_token_id
        with torch.no_grad():
            outputs = self.llm_model.generate(
                **inputs,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                max_new_tokens=300,
                pad_token_id=pad_id,
                eos_token_id=self.llm_tokenizer.eos_token_id
            )
        # Drop the prompt by token count; string-replacing the prompt is
        # unreliable because decoding may not reproduce it exactly.
        prompt_len = inputs["input_ids"].shape[1]
        return self.llm_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    @classmethod
    def _first_clean_line(cls, generated):
        """Return the first meaningful line of a generation.

        Repeated answer markers are stripped; markup lines and echoed
        "Original:" / "Classification:" lines are skipped. Falls back to the
        stripped input when no line qualifies.
        """
        for raw_line in generated.splitlines():
            line = raw_line.strip()
            for marker in cls._ANSWER_MARKERS:
                if line.startswith(marker):
                    line = line[len(marker):].strip()
            if line and not line.startswith(cls._SKIPPED_LINE_PREFIXES):
                return line
        return generated.strip()

    @staticmethod
    def _format_hate_tokens(hate_tokens):
        """Format up to five detected spans as a bullet list for a prompt."""
        if not hate_tokens:
            return ""
        bullets = "\n".join(f"• {token} ({bio_label})" for _, token, bio_label in hate_tokens[:5])
        return "\nExpressions causing issues:\n" + bullets

    @staticmethod
    def _normal_prompt(text):
        """Prompt used when the detector labels the sentence as normal."""
        return (
            "The following sentence is classified as a normal sentence. "
            "Please improve it by expressing it more politely and respectfully, "
            "while maintaining the original intent.\n\n"
            f"Original: {text}\n\n"
            "Improved sentence:"
        )

    def _unguided_mitigation(self, text):
        """Unguided mode: the LLM rewrites *text* without detector hints."""
        try:
            prompt = (
                "Please remove hate speech or aggressive expressions from the following sentence, "
                "while maintaining the original intent (criticism, complaint, opinion, etc.).\n\n"
                f"Original: {text}\n\n"
                "Mitigated sentence:"
            )
            mitigated_text = self._first_clean_line(self._generate(prompt))

            result_msg = f"🤖 **Blossom LLM Mitigation Result**\n\n"
            result_msg += f"**Original:** {text}\n\n"
            result_msg += f"**Mitigated Sentence:** {mitigated_text}"
            mitigation = "**Unguided Mode:** Blossom LLM detected and mitigated harmful expressions autonomously."
            return result_msg, mitigation
        except Exception as e:
            # Surface the failure in the UI instead of crashing the handler.
            error_msg = f"❌ **Blossom LLM Error**\n\nError occurred: {str(e)}"
            return error_msg, "An error occurred during LLM processing."

    def _guided_mitigation(self, text):
        """Guided mode: rewrite *text* using the detector's label and spans."""
        try:
            detection_result, _, debug_info = self._detection_only(text)
            label = debug_info.get('label', 'normal')
            hate_tokens = debug_info.get('hate_tokens', [])

            if label == "normal":
                prompt = self._normal_prompt(text)
            else:
                desc = self._LABEL_DESC.get(label, "harmful")
                prompt = (
                    f"The following sentence is classified as {desc} expression. \n"
                    "Please remove hate speech or aggressive expressions, while maintaining "
                    "the original intent (criticism, complaint, opinion, etc.).\n\n"
                    f"Original: {text}\n"
                    f"Classification: {desc} expression\n"
                    f"{self._format_hate_tokens(hate_tokens)}\n\n"
                    "[Important] All offensive, derogatory, and explicit hate expressions "
                    "(e.g., 씨발, 좆, 병신) must be deleted.\n\n"
                    "Mitigated sentence:"
                )
            mitigated_text = self._first_clean_line(self._generate(prompt))

            result_msg = f"🎯 **Guided Mitigation Result**\n\n"
            result_msg += f"**KcELECTRA Detection Result:**\n{detection_result}\n\n"
            result_msg += f"**Blossom LLM Mitigation Result:**\n{mitigated_text}"
            mitigation = "**Guided Mode:** Blossom LLM performed specific mitigation based on KcELECTRA's detection information."
            return result_msg, mitigation
        except Exception as e:
            error_msg = f"❌ **Guided Mitigation Error**\n\nError occurred: {str(e)}"
            return error_msg, "An error occurred during guided mitigation processing."

    def _guided_reflect_mitigation(self, text):
        """Guided+Reflect mode: sample several rewrites and keep the best.

        Up to five candidates are sampled from the same guided prompt. Each
        is scored by toxicity reduction, BERTScore and the heuristic
        perplexity; sampling stops early once a candidate clears all three
        thresholds. Candidates that are empty or contain a listed bad word
        are discarded without scoring.
        """
        try:
            detection_result, _, debug_info = self._detection_only(text)
            label = debug_info.get('label', 'normal')
            hate_tokens = debug_info.get('hate_tokens', [])

            if label == "normal":
                initial_prompt = self._normal_prompt(text)
            else:
                desc = self._LABEL_DESC.get(label, "harmful")
                initial_prompt = (
                    f"The following sentence is classified as {desc} expression. \n"
                    "Expressions containing offensive words (e.g., 좃, 씨발, 병신) must be deleted.\n"
                    "Other aggressive or inappropriate expressions should be mitigated by "
                    "expressing them more politely and inclusively.\n\n"
                    f"Original: {text}\n"
                    f"Classification: {desc} expression\n"
                    f"{self._format_hate_tokens(hate_tokens)}\n\n"
                    "Mitigated sentence:"
                )

            max_iter = 5
            metrics_history = []
            best_candidate = None
            best_score = -float('inf')
            for i in range(max_iter):
                mitigated_text = self._first_clean_line(self._generate(initial_prompt))
                # Exclude unusable candidates before spending time on scoring.
                if not mitigated_text or contains_badword(mitigated_text):
                    continue
                toxicity = calc_toxicity_reduction(text, mitigated_text, self.model, self.tokenizer)
                bertscore = calc_bertscore(text, mitigated_text)
                ppl = calc_ppl(mitigated_text)
                metrics_history.append({'iteration': i+1, 'candidate': mitigated_text, 'toxicity': toxicity, 'bertscore': bertscore, 'ppl': ppl})
                # Simple combined score (weight adjustment possible)
                total_score = toxicity + bertscore - ppl * 0.01
                if total_score > best_score:
                    best_score = total_score
                    best_candidate = mitigated_text
                # Early termination once a candidate is good on every metric.
                if toxicity > 0.3 and bertscore > 0.7 and ppl < 100:
                    break

            iter_log_str = "".join(
                f"\nIteration {log['iteration']}:\n- Candidate: {log['candidate']}\n- Toxicity reduction: {log['toxicity']}, bertscore: {log['bertscore']}, ppl: {log['ppl']}"
                for log in metrics_history
            )
            if best_candidate is None:
                # Every sample was rejected; say so rather than printing "None".
                best_candidate = "(no acceptable candidate was generated)"

            result_msg = f"🔄 **Guided+Reflect Mitigation Result**\n\n"
            result_msg += f"**Detection Result:**\n{detection_result}\n\n"
            result_msg += f"**Iterative Mitigation Log:**{iter_log_str}\n\n"
            result_msg += f"**Best Mitigation:** {best_candidate}"
            mitigation = "**Guided+Reflect Mode:** Selected the optimal candidate after iterative mitigation and evaluation (maximum 5 iterations)."
            return result_msg, mitigation
        except Exception as e:
            error_msg = f"❌ **Guided+Reflect Mitigation Error**\n\nError occurred: {str(e)}"
            return error_msg, "An error occurred during guided+reflect mitigation processing."

    def _suggest_mitigation(self, label, confidence, hate_tokens):
        """Return static, label-specific rewriting advice as Markdown.

        ``confidence`` and ``hate_tokens`` are accepted for interface
        compatibility but are not used.
        """
        if label == "normal":
            return "✅ **Mitigation Suggestion**: This sentence does not require correction."

        mitigation = f"**🔧 Mitigation Suggestion for Hate Speech:**\n\n"

        if label == "offensive":
            mitigation += "**Aggressive Expression Mitigation Options:**\n"
            mitigation += "• Try to change aggressive expressions to more polite expressions\n"
            mitigation += "• Use objective expressions instead of emotional expressions\n"
            mitigation += "• Reconstruct with a mind to be considerate\n"
            mitigation += "• When criticizing, provide specific and constructive feedback"
        elif label == "L1_hate":
            mitigation += "**Implicit Hate Expression Mitigation Options:**\n"
            mitigation += "• Remove expressions that discriminate or show prejudice\n"
            mitigation += "• Avoid generalizing about specific groups\n"
            mitigation += "• Use more inclusive and respectful expressions\n"
            mitigation += "• Change to expressions that acknowledge diversity"
        else:  # L2_hate
            mitigation += "**Explicit Hate Expression Mitigation Options:**\n"
            mitigation += "• Completely remove severe hate expressions\n"
            mitigation += "• Do not use violent or threatening expressions\n"
            mitigation += "• Use expressions that respect everyone's dignity\n"
            # Fixed wording: the original advised changing *to* hateful expressions.
            mitigation += "• Replace expressions that discriminate or promote hate\n"
            mitigation += "• If necessary, seek professional help"

        return mitigation
512
+
513
def contains_badword(text):
    """Return True if *text* contains any entry of the built-in profanity list.

    Matching is a plain substring check against a fixed list of Korean
    profanities and their abbreviations.
    """
    blocklist = ("좃", "씨발", "병신", "개새끼", "염병", "좆", "ㅅㅂ", "ㅄ", "ㅂㅅ", "ㅗ", "ㅉ")
    for word in blocklist:
        if word in text:
            return True
    return False
516
+
517
# Service initialization.
# Runs at import time: the constructor downloads the detector checkpoint and
# loads both the detector and the LLM, so importing this module is slow.
service = HateSpeechDetectorService()
519
+
520
+ # Gradio interface
521
def create_demo():
    """Build and return the Gradio Blocks UI for the demo.

    The layout has an input textbox, a strategy selector and a button on
    top, and two Markdown panes below (result and mitigation info). Both the
    button click and pressing Enter in the textbox call
    ``service.detect_hate_speech``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(
        title="Korean Hate Speech Detection and Mitigation System",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 800px;
            margin: 0 auto;
        }
        .result-box {
            border-radius: 10px;
            padding: 15px;
            margin: 10px 0;
        }
        .normal { background-color: #d4edda; border: 1px solid #c3e6cb; }
        .offensive { background-color: #fff3cd; border: 1px solid #ffeaa7; }
        .hate { background-color: #f8d7da; border: 1px solid #f5c6cb; }
        """
    ) as demo:
        gr.Markdown("""
        # Korean Hate Speech Detection and Mitigation System

        This system detects hate speech in Korean text and provides mitigation suggestions.


        **🟢 Normal**:
        - It is a normal sentence.

        **🟡 Offensive**

        - For example: "Don't say such a stupid thing", "How can you do such a stupid thing"

        **🟠 L1_hate (Implicit Hate)**: Mild hate expression
        - **Implicit hate expression** for protected attribute groups
        - For example: "Those people are all the same", "Prejudicial expression towards a specific group"

        **🔴 L2_hate (Explicit Hate)**: Severe hate expression
        - **Explicit hate expression** for protected attribute groups

        **🤖 Mitigation Mode:**
        - 🔍 **Detection Only**: Hate Speech Detection Only
        - 🎯 **Guided**: Guided Mitigation
        - 🔄 **Guided+Reflect**: After Guided Mitigation, Iterative Refinement
        - 🤖 **Unguided**: LLM generates text without any guidance
        """)

        with gr.Row():
            with gr.Column(scale=2):
                input_text = gr.Textbox(
                    label="Enter text to detect hate speech & mitigate",
                    lines=3
                )

                strategy = gr.Radio(
                    ["Detection Only", "Guided", "Guided+Reflect", "Unguided"],
                    value="Detection Only",
                    label="Select Mitigation Mode",
                    container=True
                )

                analyze_btn = gr.Button("🔍 Detect & Mitigate", variant="primary", size="lg")

        with gr.Row():
            with gr.Column():
                # This pane shows the detection/mitigation result message.
                result_output = gr.Markdown(
                    label="Detection Result",
                    value="Input text and click the above button."
                )

            with gr.Column():
                mitigation_output = gr.Markdown(
                    label="Mitigation Suggestion",
                    value="Based on the analysis result, mitigation suggestions will be provided."
                )

        # The button click and the Enter key in the textbox share one handler.
        for register in (analyze_btn.click, input_text.submit):
            register(
                fn=service.detect_hate_speech,
                inputs=[input_text, strategy],
                outputs=[result_output, mitigation_output]
            )

    return demo
640
+
641
if __name__ == "__main__":
    import os

    # Hugging Face Spaces sets SPACE_ID. There the app must use Gradio's
    # default port, and share links are not supported, so the local-only
    # overrides are applied only when running outside a Space.
    on_spaces = bool(os.environ.get("SPACE_ID"))

    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=None if on_spaces else 7863,
        share=not on_spaces,
        show_error=True
    )
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.44.0
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ bert-score>=0.3.13
5
+ numpy>=1.21.0
6
+ scikit-learn>=1.0.0
7
+ accelerate>=0.20.0
8
+ TorchCRF==1.1.0
9
+