Commit
·
f09939b
0
Parent(s):
최소 파일만 포함한 완전 클린 Space push
Browse files- README.md +14 -0
- app.py +648 -0
- requirements.txt +9 -0
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Korean Hate Speech Mitigation Demo
|
3 |
+
emoji: "🛡️"
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: blue
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: "4.44.0"
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
# Korean Hate Speech Mitigation Demo
|
13 |
+
|
14 |
+
이 Space는 한국어 혐오 표현 탐지 및 순화 데모입니다.
|
app.py
ADDED
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
|
5 |
+
import numpy as np
|
6 |
+
from datetime import datetime
|
7 |
+
from TorchCRF import CRF
|
8 |
+
|
9 |
+
from bert_score import score as bert_score_fn
|
10 |
+
import re
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
|
13 |
+
def calc_bertscore(orig_text, rewritten_text):
|
14 |
+
P, R, F1 = bert_score_fn([rewritten_text], [orig_text], lang="ko")
|
15 |
+
return round(F1[0].item(), 3)
|
16 |
+
|
17 |
+
def calc_ppl(text):
|
18 |
+
try:
|
19 |
+
tokens = text.split()
|
20 |
+
if len(tokens) < 2:
|
21 |
+
return 1.0
|
22 |
+
word_count = len(tokens)
|
23 |
+
base_ppl = 50.0
|
24 |
+
length_factor = min(word_count / 10.0, 2.0)
|
25 |
+
complexity_factor = 1.0 + (len(set(tokens)) / word_count) * 0.5
|
26 |
+
ppl = base_ppl * length_factor * complexity_factor
|
27 |
+
return round(ppl, 3)
|
28 |
+
except Exception as e:
|
29 |
+
print(f"PPL calculation error: {e}")
|
30 |
+
return 1.0
|
31 |
+
|
32 |
+
def calc_toxicity_reduction(orig_text, rewritten_text, detector_model, detector_tokenizer):
|
33 |
+
try:
|
34 |
+
# Original toxicity score
|
35 |
+
orig_enc = detector_tokenizer(orig_text, return_tensors="pt", padding="max_length", max_length=128)
|
36 |
+
device = next(detector_model.parameters()).device
|
37 |
+
orig_input_ids = orig_enc["input_ids"].to(device)
|
38 |
+
orig_attention_mask = orig_enc["attention_mask"].to(device)
|
39 |
+
with torch.no_grad():
|
40 |
+
orig_out = detector_model(input_ids=orig_input_ids, attention_mask=orig_attention_mask)
|
41 |
+
orig_logits = orig_out["sentence_logits"][0]
|
42 |
+
orig_probs = torch.softmax(orig_logits, dim=-1)
|
43 |
+
orig_toxicity = 1.0 - orig_probs[0].item()
|
44 |
+
# Rewritten toxicity score
|
45 |
+
rewritten_enc = detector_tokenizer(rewritten_text, return_tensors="pt", padding="max_length", max_length=128)
|
46 |
+
rewritten_input_ids = rewritten_enc["input_ids"].to(device)
|
47 |
+
rewritten_attention_mask = rewritten_enc["attention_mask"].to(device)
|
48 |
+
with torch.no_grad():
|
49 |
+
rewritten_out = detector_model(input_ids=rewritten_input_ids, attention_mask=rewritten_attention_mask)
|
50 |
+
rewritten_logits = rewritten_out["sentence_logits"][0]
|
51 |
+
rewritten_probs = torch.softmax(rewritten_logits, dim=-1)
|
52 |
+
rewritten_toxicity = 1.0 - rewritten_probs[0].item()
|
53 |
+
delta = orig_toxicity - rewritten_toxicity
|
54 |
+
return round(delta, 3)
|
55 |
+
except Exception as e:
|
56 |
+
print(f"Toxicity reduction calculation error: {e}")
|
57 |
+
return 0.0
|
58 |
+
|
59 |
+
class HateSpeechDetector(nn.Module):
|
60 |
+
def __init__(self, model_name="beomi/KcELECTRA-base", num_sentence_labels=4, num_bio_labels=5, num_targets=9):
|
61 |
+
super().__init__()
|
62 |
+
self.config = AutoConfig.from_pretrained(model_name)
|
63 |
+
self.encoder = AutoModel.from_pretrained(model_name, config=self.config)
|
64 |
+
hidden_size = self.config.hidden_size
|
65 |
+
self.dropout = nn.Dropout(0.1)
|
66 |
+
self.classifier = nn.Linear(hidden_size, num_sentence_labels) # Sentence classification
|
67 |
+
self.bio_linear = nn.Linear(hidden_size, num_bio_labels) # BIO tagging
|
68 |
+
self.crf = CRF(num_bio_labels)
|
69 |
+
self.target_head = nn.Linear(hidden_size, num_targets) # Target classification
|
70 |
+
|
71 |
+
def forward(self, input_ids, attention_mask, bio_tags=None, sentence_labels=None, targets=None):
|
72 |
+
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
73 |
+
sequence_output = outputs.last_hidden_state
|
74 |
+
pooled_output = sequence_output[:, 0, :]
|
75 |
+
dropped = self.dropout(pooled_output)
|
76 |
+
sentence_logits = self.classifier(dropped)
|
77 |
+
bio_feats = self.bio_linear(sequence_output)
|
78 |
+
bio_loss = None
|
79 |
+
if bio_tags is not None:
|
80 |
+
mask = bio_tags != -100
|
81 |
+
log_likelihood = self.crf.forward(bio_feats, bio_tags, mask=mask)
|
82 |
+
bio_loss = -log_likelihood
|
83 |
+
tgt_dropped = self.dropout(pooled_output)
|
84 |
+
target_logits = self.target_head(tgt_dropped)
|
85 |
+
loss = 0.0
|
86 |
+
if sentence_labels is not None:
|
87 |
+
cls_loss = nn.CrossEntropyLoss()(sentence_logits, sentence_labels)
|
88 |
+
loss += cls_loss
|
89 |
+
if bio_loss is not None:
|
90 |
+
loss += bio_loss.sum()
|
91 |
+
if targets is not None:
|
92 |
+
bce_loss = nn.BCEWithLogitsLoss()(target_logits, targets)
|
93 |
+
loss += 2.0 * bce_loss
|
94 |
+
# CRF decode
|
95 |
+
if bio_tags is not None:
|
96 |
+
decode_mask = bio_tags != -100
|
97 |
+
else:
|
98 |
+
decode_mask = attention_mask.bool()
|
99 |
+
print("[DEBUG] bio_tags:", bio_tags)
|
100 |
+
print("[DEBUG] attention_mask.shape:", attention_mask.shape)
|
101 |
+
print("[DEBUG] decode_mask.shape:", decode_mask.shape)
|
102 |
+
print("[DEBUG] decode_mask[:, 0]:", decode_mask[:, 0] if decode_mask.dim() > 1 else decode_mask[0])
|
103 |
+
print("[DEBUG] bio_feats.shape:", bio_feats.shape)
|
104 |
+
bio_preds = self.crf.viterbi_decode(bio_feats, mask=decode_mask)
|
105 |
+
return {
|
106 |
+
'loss': loss,
|
107 |
+
'sentence_logits': sentence_logits,
|
108 |
+
'bio_logits': bio_feats,
|
109 |
+
'bio_preds': bio_preds,
|
110 |
+
'target_logits': target_logits
|
111 |
+
}
|
112 |
+
|
113 |
+
class HateSpeechDetectorService:
|
114 |
+
def __init__(self):
|
115 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
116 |
+
self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base")
|
117 |
+
self.model = HateSpeechDetector()
|
118 |
+
|
119 |
+
# Model loading
|
120 |
+
MODEL_CKPT_PATH = hf_hub_download(repo_id="alohaboy/hate_detector_ko", filename="best_model.pt")
|
121 |
+
checkpoint = torch.load(MODEL_CKPT_PATH, map_location=self.device)
|
122 |
+
# state_dict key conversion
|
123 |
+
key_map = {
|
124 |
+
'sentence_classifier.weight': 'classifier.weight',
|
125 |
+
'sentence_classifier.bias': 'classifier.bias',
|
126 |
+
'bio_classifier.weight': 'bio_linear.weight',
|
127 |
+
'bio_classifier.bias': 'bio_linear.bias',
|
128 |
+
# CRF related keys (reverse)
|
129 |
+
'crf.transitions': 'crf.trans_matrix',
|
130 |
+
'crf.start_transitions': 'crf.start_trans',
|
131 |
+
'crf.end_transitions': 'crf.end_trans',
|
132 |
+
}
|
133 |
+
new_state_dict = {}
|
134 |
+
# If checkpoint is a dict and model_state_dict key exists, load from it
|
135 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
136 |
+
state_dict = checkpoint['model_state_dict']
|
137 |
+
else:
|
138 |
+
state_dict = checkpoint
|
139 |
+
for k, v in state_dict.items():
|
140 |
+
new_key = key_map.get(k, k)
|
141 |
+
new_state_dict[new_key] = v
|
142 |
+
self.model.load_state_dict(new_state_dict, strict=True)
|
143 |
+
self.model.to(self.device)
|
144 |
+
self.model.eval()
|
145 |
+
|
146 |
+
# Blossom LLM loading
|
147 |
+
print("Blossom LLM loading...")
|
148 |
+
self.llm_model_name = "Bllossom/llama-3.2-Korean-Bllossom-3B"
|
149 |
+
self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
|
150 |
+
self.llm_model = AutoModelForCausalLM.from_pretrained(
|
151 |
+
self.llm_model_name,
|
152 |
+
torch_dtype=torch.bfloat16,
|
153 |
+
device_map="auto"
|
154 |
+
)
|
155 |
+
print("LLM loading complete!")
|
156 |
+
|
157 |
+
self.label_names = ["normal", "offensive", "L1_hate", "L2_hate"]
|
158 |
+
self.bio_names = {0: "O", 1: "B-SOFT", 2: "I-SOFT", 3: "B-HARD", 4: "I-HARD"}
|
159 |
+
|
160 |
+
val_acc = checkpoint['val_acc'] if 'val_acc' in checkpoint else None
|
161 |
+
if val_acc is not None:
|
162 |
+
print(f"Model loaded - Validation accuracy: {val_acc:.2f}%")
|
163 |
+
else:
|
164 |
+
print("Model loaded - Validation accuracy: N/A")
|
165 |
+
|
166 |
+
def detect_hate_speech(self, text, strategy="Detection Only"):
|
167 |
+
"""Hate Speech Detection and Mitigation"""
|
168 |
+
if not text.strip():
|
169 |
+
return "Please enter text", ""
|
170 |
+
if len(text.strip()) < 2:
|
171 |
+
return "Input text is too short. Please enter at least 2 characters.", ""
|
172 |
+
|
173 |
+
if strategy == "Detection Only":
|
174 |
+
result_msg, mitigation, debug_info = self._detection_only(text)
|
175 |
+
print("[DEBUG] Input text:", text)
|
176 |
+
print("[DEBUG] sentence_logits:", debug_info.get('sentence_logits'))
|
177 |
+
print("[DEBUG] sentence_probs:", debug_info.get('sentence_probs'))
|
178 |
+
print("[DEBUG] sentence_pred:", debug_info.get('sentence_pred'))
|
179 |
+
print("[DEBUG] label:", debug_info.get('label'))
|
180 |
+
print("[DEBUG] confidence:", debug_info.get('confidence'))
|
181 |
+
return result_msg, mitigation
|
182 |
+
elif strategy == "Guided":
|
183 |
+
return self._guided_mitigation(text)
|
184 |
+
elif strategy == "Guided+Reflect":
|
185 |
+
return self._guided_reflect_mitigation(text)
|
186 |
+
elif strategy == "Unguided":
|
187 |
+
return self._unguided_mitigation(text)
|
188 |
+
else:
|
189 |
+
return "Invalid strategy", ""
|
190 |
+
|
191 |
+
def _detection_only(self, text):
|
192 |
+
"""Perform only detection (existing logic)"""
|
193 |
+
# Tokenization
|
194 |
+
encoding = self.tokenizer(
|
195 |
+
text,
|
196 |
+
truncation=True,
|
197 |
+
padding="max_length",
|
198 |
+
max_length=128,
|
199 |
+
return_attention_mask=True,
|
200 |
+
return_tensors="pt"
|
201 |
+
)
|
202 |
+
|
203 |
+
input_ids = encoding["input_ids"].to(self.device)
|
204 |
+
attention_mask = encoding["attention_mask"].to(self.device)
|
205 |
+
print("[DEBUG] attention_mask[:, 0] =", attention_mask[:, 0])
|
206 |
+
|
207 |
+
# Prediction
|
208 |
+
with torch.no_grad():
|
209 |
+
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
210 |
+
sentence_logits = outputs["sentence_logits"]
|
211 |
+
bio_logits = outputs["bio_logits"]
|
212 |
+
|
213 |
+
# Sentence classification result
|
214 |
+
sentence_probs = torch.softmax(sentence_logits, dim=1)
|
215 |
+
sentence_pred = torch.argmax(sentence_logits, dim=1).item()
|
216 |
+
sentence_prob = sentence_probs[0][sentence_pred].item()
|
217 |
+
|
218 |
+
# BIO tagging result
|
219 |
+
bio_preds = torch.argmax(bio_logits, dim=2)[0]
|
220 |
+
|
221 |
+
# Find hate/aggressive tokens
|
222 |
+
hate_tokens = []
|
223 |
+
tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
|
224 |
+
|
225 |
+
# Tokenize original text to get offset mapping
|
226 |
+
tokenized = self.tokenizer(
|
227 |
+
text,
|
228 |
+
truncation=True,
|
229 |
+
padding="max_length",
|
230 |
+
max_length=128,
|
231 |
+
return_offsets_mapping=True
|
232 |
+
)
|
233 |
+
offset_mapping = tokenized["offset_mapping"]
|
234 |
+
|
235 |
+
for j, (token, pred) in enumerate(zip(tokens, bio_preds)):
|
236 |
+
if pred.item() != 0: # Not O
|
237 |
+
# Extract the corresponding part from the original text using offset mapping
|
238 |
+
if j < len(offset_mapping):
|
239 |
+
start, end = offset_mapping[j]
|
240 |
+
if start != end: # Token mapped to actual text
|
241 |
+
original_text = text[start:end]
|
242 |
+
hate_tokens.append((j, original_text, self.bio_names[pred.item()]))
|
243 |
+
else:
|
244 |
+
# Special token handling
|
245 |
+
if token.startswith('Ġ'):
|
246 |
+
decoded_token = token[1:] # Remove Ġ
|
247 |
+
elif token in ['[CLS]', '[SEP]', '[PAD]', '[UNK]']:
|
248 |
+
decoded_token = token
|
249 |
+
else:
|
250 |
+
decoded_token = token
|
251 |
+
hate_tokens.append((j, decoded_token, self.bio_names[pred.item()]))
|
252 |
+
else:
|
253 |
+
# Fallback
|
254 |
+
if token.startswith('Ġ'):
|
255 |
+
decoded_token = token[1:]
|
256 |
+
elif token in ['[CLS]', '[SEP]', '[PAD]', '[UNK]']:
|
257 |
+
decoded_token = token
|
258 |
+
else:
|
259 |
+
decoded_token = token
|
260 |
+
hate_tokens.append((j, decoded_token, self.bio_names[pred.item()]))
|
261 |
+
|
262 |
+
# Determine label
|
263 |
+
label = self.label_names[sentence_pred]
|
264 |
+
# If hate_tokens contain B-HARD, I-HARD, increase label to L2_hate
|
265 |
+
if any(bio_label in ["B-HARD", "I-HARD"] for _, _, bio_label in hate_tokens):
|
266 |
+
label = "L2_hate"
|
267 |
+
# Construct result message
|
268 |
+
result_msg = f"Detection result: {label}\nConfidence: {sentence_prob:.2f}"
|
269 |
+
if hate_tokens:
|
270 |
+
result_msg += f"\nIdentified hate/aggressive expressions: {hate_tokens}"
|
271 |
+
mitigation = "Performed only detection."
|
272 |
+
debug_info = {
|
273 |
+
'sentence_logits': sentence_logits,
|
274 |
+
'sentence_probs': sentence_probs,
|
275 |
+
'sentence_pred': sentence_pred,
|
276 |
+
'label': label,
|
277 |
+
'confidence': sentence_prob,
|
278 |
+
'hate_tokens': hate_tokens
|
279 |
+
}
|
280 |
+
return result_msg, mitigation, debug_info
|
281 |
+
|
282 |
+
def _unguided_mitigation(self, text):
|
283 |
+
"""Unguided Mode: Only Using Generation"""
|
284 |
+
try:
|
285 |
+
# Blossom LLM prompt
|
286 |
+
prompt = f"""Please remove hate speech or aggressive expressions from the following sentence, while maintaining the original intent (criticism, complaint, opinion, etc.).
|
287 |
+
|
288 |
+
Original: {text}
|
289 |
+
|
290 |
+
Mitigated sentence:"""
|
291 |
+
|
292 |
+
# LLM inference
|
293 |
+
inputs = self.llm_tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
|
294 |
+
|
295 |
+
with torch.no_grad():
|
296 |
+
outputs = self.llm_model.generate(
|
297 |
+
**inputs,
|
298 |
+
do_sample=True,
|
299 |
+
top_k=50,
|
300 |
+
top_p=0.9,
|
301 |
+
max_new_tokens=300,
|
302 |
+
pad_token_id=self.llm_tokenizer.pad_token_id,
|
303 |
+
eos_token_id=self.llm_tokenizer.eos_token_id
|
304 |
+
)
|
305 |
+
|
306 |
+
# Decode result
|
307 |
+
full_response = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
308 |
+
|
309 |
+
# Remove prompt part and extract mitigated sentence
|
310 |
+
mitigated_text = full_response.replace(prompt, "").strip()
|
311 |
+
|
312 |
+
# Handle truncated sentences
|
313 |
+
if len(mitigated_text) < 10: # Too short, use original response
|
314 |
+
mitigated_text = full_response
|
315 |
+
|
316 |
+
# Prevent repetitive output: extract only the first mitigated sentence
|
317 |
+
if "Mitigated sentence:" in mitigated_text:
|
318 |
+
mitigated_text = mitigated_text.split("Mitigated sentence:")[-1].strip()
|
319 |
+
|
320 |
+
# Use only the first meaningful line if multiple lines
|
321 |
+
lines = mitigated_text.split('\n')
|
322 |
+
clean_lines = []
|
323 |
+
for line in lines:
|
324 |
+
line = line.strip()
|
325 |
+
if line and not line.startswith('**') and not line.startswith('Original:'):
|
326 |
+
clean_lines.append(line)
|
327 |
+
|
328 |
+
if clean_lines:
|
329 |
+
mitigated_text = clean_lines[0]
|
330 |
+
|
331 |
+
# Result message
|
332 |
+
result_msg = f"🤖 **Blossom LLM Mitigation Result**\n\n"
|
333 |
+
result_msg += f"**Original:** {text}\n\n"
|
334 |
+
result_msg += f"**Mitigated Sentence:** {mitigated_text}"
|
335 |
+
|
336 |
+
# Mitigation info
|
337 |
+
mitigation = "**Unguided Mode:** Blossom LLM detected and mitigated harmful expressions autonomously."
|
338 |
+
|
339 |
+
return result_msg, mitigation
|
340 |
+
|
341 |
+
except Exception as e:
|
342 |
+
error_msg = f"❌ **Blossom LLM Error**\n\nError occurred: {str(e)}"
|
343 |
+
return error_msg, "An error occurred during LLM processing."
|
344 |
+
|
345 |
+
def _guided_mitigation(self, text):
|
346 |
+
"""Guided Mode: Mitigate based on KcELECTRA detection result using Blossom LLM"""
|
347 |
+
try:
|
348 |
+
# First, perform detection with KcELECTRA
|
349 |
+
detection_result, _, debug_info = self._detection_only(text)
|
350 |
+
label = debug_info.get('label', 'normal')
|
351 |
+
hate_tokens = debug_info.get('hate_tokens', [])
|
352 |
+
|
353 |
+
# Construct Blossom LLM prompt
|
354 |
+
if label == "normal":
|
355 |
+
prompt = f"""The following sentence is classified as a normal sentence. Please improve it by expressing it more politely and respectfully, while maintaining the original intent.\n\nOriginal: {text}\n\nImproved sentence:"""
|
356 |
+
else:
|
357 |
+
label_desc = {
|
358 |
+
"offensive": "Aggressive",
|
359 |
+
"L1_hate": "Mild Hate",
|
360 |
+
"L2_hate": "Severe Hate"
|
361 |
+
}
|
362 |
+
hate_tokens_str = ""
|
363 |
+
if hate_tokens:
|
364 |
+
hate_tokens_str = "\nExpressions causing issues:\n" + "\n".join([f"• {token} ({bio_label})" for _, token, bio_label in hate_tokens[:5]])
|
365 |
+
prompt = f"""The following sentence is classified as {label_desc.get(label, "harmful")} expression. \nPlease remove hate speech or aggressive expressions, while maintaining the original intent (criticism, complaint, opinion, etc.).\n\nOriginal: {text}\nClassification: {label_desc.get(label, "harmful")} expression\n{hate_tokens_str}\n\n[Important] All offensive, derogatory, and explicit hate expressions (e.g., 씨발, 좆, 병신) must be deleted.\n\nMitigated sentence:"""
|
366 |
+
# LLM inference
|
367 |
+
inputs = self.llm_tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
|
368 |
+
with torch.no_grad():
|
369 |
+
outputs = self.llm_model.generate(
|
370 |
+
**inputs,
|
371 |
+
do_sample=True,
|
372 |
+
top_k=50,
|
373 |
+
top_p=0.9,
|
374 |
+
max_new_tokens=300,
|
375 |
+
pad_token_id=self.llm_tokenizer.pad_token_id,
|
376 |
+
eos_token_id=self.llm_tokenizer.eos_token_id
|
377 |
+
)
|
378 |
+
full_response = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
379 |
+
mitigated_text = full_response.replace(prompt, "").strip()
|
380 |
+
if len(mitigated_text) < 10:
|
381 |
+
mitigated_text = full_response
|
382 |
+
if "Mitigated sentence:" in mitigated_text:
|
383 |
+
mitigated_text = mitigated_text.split("Mitigated sentence:")[-1].strip()
|
384 |
+
lines = mitigated_text.split('\n')
|
385 |
+
clean_lines = []
|
386 |
+
for line in lines:
|
387 |
+
line = line.strip()
|
388 |
+
if line and not line.startswith('**') and not line.startswith('Original:') and not line.startswith('Classification:'):
|
389 |
+
clean_lines.append(line)
|
390 |
+
if clean_lines:
|
391 |
+
mitigated_text = clean_lines[0]
|
392 |
+
result_msg = f"🎯 **Guided Mitigation Result**\n\n"
|
393 |
+
result_msg += f"**KcELECTRA Detection Result:**\n{detection_result}\n\n"
|
394 |
+
result_msg += f"**Blossom LLM Mitigation Result:**\n{mitigated_text}"
|
395 |
+
mitigation = "**Guided Mode:** Blossom LLM performed specific mitigation based on KcELECTRA's detection information."
|
396 |
+
return result_msg, mitigation
|
397 |
+
except Exception as e:
|
398 |
+
error_msg = f"❌ **Guided Mitigation Error**\n\nError occurred: {str(e)}"
|
399 |
+
return error_msg, "An error occurred during guided mitigation processing."
|
400 |
+
|
401 |
+
def _guided_reflect_mitigation(self, text):
|
402 |
+
"""Guided+Reflect Mode: iterative refinement + critic evaluation"""
|
403 |
+
try:
|
404 |
+
detection_result, _, debug_info = self._detection_only(text)
|
405 |
+
label = debug_info.get('label', 'normal')
|
406 |
+
hate_tokens = debug_info.get('hate_tokens', [])
|
407 |
+
# Step 1: Initial mitigation
|
408 |
+
if label == "normal":
|
409 |
+
initial_prompt = f"""The following sentence is classified as a normal sentence. Please improve it by expressing it more politely and respectfully, while maintaining the original intent.\n\nOriginal: {text}\n\nImproved sentence:"""
|
410 |
+
else:
|
411 |
+
label_desc = {
|
412 |
+
"offensive": "Aggressive",
|
413 |
+
"L1_hate": "Mild Hate",
|
414 |
+
"L2_hate": "Severe Hate"
|
415 |
+
}
|
416 |
+
hate_tokens_str = ""
|
417 |
+
if hate_tokens:
|
418 |
+
hate_tokens_str = "\nExpressions causing issues:\n" + "\n".join([f"• {token} ({bio_label})" for _, token, bio_label in hate_tokens[:5]])
|
419 |
+
initial_prompt = f"""The following sentence is classified as {label_desc.get(label, "harmful")} expression. \nExpressions containing offensive words (e.g., 좃, 씨발, 병신) must be deleted.\nOther aggressive or inappropriate expressions should be mitigated by expressing them more politely and inclusively.\n\nOriginal: {text}\nClassification: {label_desc.get(label, "harmful")} expression\n{hate_tokens_str}\n\nMitigated sentence:"""
|
420 |
+
# Iterative mitigation and evaluation
|
421 |
+
max_iter = 5
|
422 |
+
metrics_history = []
|
423 |
+
best_candidate = None
|
424 |
+
best_score = -float('inf')
|
425 |
+
current_input = text
|
426 |
+
for i in range(max_iter):
|
427 |
+
# Generate candidate
|
428 |
+
inputs = self.llm_tokenizer(initial_prompt, return_tensors="pt").to(self.llm_model.device)
|
429 |
+
with torch.no_grad():
|
430 |
+
outputs = self.llm_model.generate(
|
431 |
+
**inputs,
|
432 |
+
do_sample=True,
|
433 |
+
top_k=50,
|
434 |
+
top_p=0.9,
|
435 |
+
max_new_tokens=300,
|
436 |
+
pad_token_id=self.llm_tokenizer.pad_token_id,
|
437 |
+
eos_token_id=self.llm_tokenizer.eos_token_id
|
438 |
+
)
|
439 |
+
candidate = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
440 |
+
mitigated_text = candidate.replace(initial_prompt, "").strip()
|
441 |
+
if len(mitigated_text) < 10:
|
442 |
+
mitigated_text = candidate
|
443 |
+
if "Mitigated sentence:" in mitigated_text:
|
444 |
+
mitigated_text = mitigated_text.split("Mitigated sentence:")[-1].strip()
|
445 |
+
lines = mitigated_text.split('\n')
|
446 |
+
clean_lines = []
|
447 |
+
for line in lines:
|
448 |
+
line = line.strip()
|
449 |
+
if line and not line.startswith('**') and not line.startswith('Original:') and not line.startswith('Classification:'):
|
450 |
+
clean_lines.append(line)
|
451 |
+
if clean_lines:
|
452 |
+
mitigated_text = clean_lines[0]
|
453 |
+
# Exclude candidates containing offensive words
|
454 |
+
if contains_badword(mitigated_text):
|
455 |
+
continue
|
456 |
+
# Evaluation
|
457 |
+
toxicity = calc_toxicity_reduction(text, mitigated_text, self.model, self.tokenizer)
|
458 |
+
bertscore = calc_bertscore(text, mitigated_text)
|
459 |
+
ppl = calc_ppl(mitigated_text)
|
460 |
+
metrics_history.append({'iteration': i+1, 'candidate': mitigated_text, 'toxicity': toxicity, 'bertscore': bertscore, 'ppl': ppl})
|
461 |
+
# Simple combined score (weight adjustment possible)
|
462 |
+
total_score = toxicity + bertscore - ppl * 0.01
|
463 |
+
if total_score > best_score:
|
464 |
+
best_score = total_score
|
465 |
+
best_candidate = mitigated_text
|
466 |
+
# Early termination criteria (e.g., toxicity>0.3, bertscore>0.7, ppl<100)
|
467 |
+
if toxicity > 0.3 and bertscore > 0.7 and ppl < 100:
|
468 |
+
break
|
469 |
+
# Log output
|
470 |
+
iter_log_str = ""
|
471 |
+
for log in metrics_history:
|
472 |
+
iter_log_str += f"\nIteration {log['iteration']}:\n- Candidate: {log['candidate']}\n- Toxicity reduction: {log['toxicity']}, bertscore: {log['bertscore']}, ppl: {log['ppl']}"
|
473 |
+
# Result message
|
474 |
+
result_msg = f"🔄 **Guided+Reflect Mitigation Result**\n\n"
|
475 |
+
result_msg += f"**Detection Result:**\n{detection_result}\n\n"
|
476 |
+
result_msg += f"**Iterative Mitigation Log:**{iter_log_str}\n\n"
|
477 |
+
result_msg += f"**Best Mitigation:** {best_candidate}"
|
478 |
+
mitigation = "**Guided+Reflect Mode:** Selected the optimal candidate after iterative mitigation and evaluation (maximum 5 iterations)."
|
479 |
+
return result_msg, mitigation
|
480 |
+
except Exception as e:
|
481 |
+
error_msg = f"❌ **Guided+Reflect Mitigation Error**\n\nError occurred: {str(e)}"
|
482 |
+
return error_msg, "An error occurred during guided+reflect mitigation processing."
|
483 |
+
|
484 |
+
def _suggest_mitigation(self, label, confidence, hate_tokens):
|
485 |
+
"""Suggest mitigation for hate speech expressions"""
|
486 |
+
if label == "normal":
|
487 |
+
return "✅ **Mitigation Suggestion**: This sentence does not require correction."
|
488 |
+
|
489 |
+
mitigation = f"**🔧 Mitigation Suggestion for Hate Speech:**\n\n"
|
490 |
+
|
491 |
+
if label == "offensive":
|
492 |
+
mitigation += "**Aggressive Expression Mitigation Options:**\n"
|
493 |
+
mitigation += "• Try to change aggressive expressions to more polite expressions\n"
|
494 |
+
mitigation += "• Use objective expressions instead of emotional expressions\n"
|
495 |
+
mitigation += "• Reconstruct with a mind to be considerate\n"
|
496 |
+
mitigation += "• When criticizing, provide specific and constructive feedback"
|
497 |
+
elif label == "L1_hate":
|
498 |
+
mitigation += "**Implicit Hate Expression Mitigation Options:**\n"
|
499 |
+
mitigation += "• Remove expressions that discriminate or show prejudice\n"
|
500 |
+
mitigation += "• Avoid generalizing about specific groups\n"
|
501 |
+
mitigation += "• Use more inclusive and respectful expressions\n"
|
502 |
+
mitigation += "• Change to expressions that acknowledge diversity"
|
503 |
+
else: # L2_hate
|
504 |
+
mitigation += "**Explicit Hate Expression Mitigation Options:**\n"
|
505 |
+
mitigation += "• Completely remove severe hate expressions\n"
|
506 |
+
mitigation += "• Do not use violent or threatening expressions\n"
|
507 |
+
mitigation += "• Use expressions that respect everyone's dignity\n"
|
508 |
+
mitigation += "• Change to expressions that discriminate or promote hate\n"
|
509 |
+
mitigation += "• If necessary, seek professional help"
|
510 |
+
|
511 |
+
return mitigation
|
512 |
+
|
513 |
+
def contains_badword(text):
|
514 |
+
badwords = ["좃", "씨발", "병신", "개새끼", "염병", "좆", "ㅅㅂ", "ㅄ", "ㅂㅅ", "ㅗ", "ㅉ"]
|
515 |
+
return any(bad in text for bad in badwords)
|
516 |
+
|
517 |
+
# Service initialization
|
518 |
+
service = HateSpeechDetectorService()
|
519 |
+
|
520 |
+
# Gradio interface
|
521 |
+
def create_demo():
|
522 |
+
with gr.Blocks(
|
523 |
+
title="Korean Hate Speech Detection and Mitigation System",
|
524 |
+
theme=gr.themes.Soft(),
|
525 |
+
css="""
|
526 |
+
.gradio-container {
|
527 |
+
max-width: 800px;
|
528 |
+
margin: 0 auto;
|
529 |
+
}
|
530 |
+
.result-box {
|
531 |
+
border-radius: 10px;
|
532 |
+
padding: 15px;
|
533 |
+
margin: 10px 0;
|
534 |
+
}
|
535 |
+
.normal { background-color: #d4edda; border: 1px solid #c3e6cb; }
|
536 |
+
.offensive { background-color: #fff3cd; border: 1px solid #ffeaa7; }
|
537 |
+
.hate { background-color: #f8d7da; border: 1px solid #f5c6cb; }
|
538 |
+
"""
|
539 |
+
) as demo:
|
540 |
+
gr.Markdown("""
|
541 |
+
# Korean Hate Speech Detection and Mitigation System
|
542 |
+
|
543 |
+
This system detects hate speech in Korean text and provides mitigation suggestions.
|
544 |
+
|
545 |
+
|
546 |
+
**🟢 Normal**:
|
547 |
+
- It is a normal sentence.
|
548 |
+
|
549 |
+
**🟡 Offensive**
|
550 |
+
|
551 |
+
- For example: "Don't say such a stupid thing", "How can you do such a stupid thing"
|
552 |
+
|
553 |
+
**🟠 L1_hate (Implicit Hate)**: Mild hate expression
|
554 |
+
- **Implicit hate expression** for protected attribute groups
|
555 |
+
- For example: "Those people are all the same", "Prejudicial expression towards a specific group"
|
556 |
+
|
557 |
+
**🔴 L2_hate (Explicit Hate)**: Severe hate expression
|
558 |
+
- **Explicit hate expression** for protected attribute groups
|
559 |
+
|
560 |
+
**🤖 Mitigation Mode:**
|
561 |
+
- 🔍 **Detection Only**: Hate Speech Detection Only
|
562 |
+
- 🎯 **Guided**: Guided Mitigation
|
563 |
+
- 🔄 **Guided+Reflect**: After Guided Mitigation, Iterative Refinement
|
564 |
+
- 🤖 **Unguided**: LLM generates text without any guidance
|
565 |
+
""")
|
566 |
+
|
567 |
+
with gr.Row():
|
568 |
+
with gr.Column(scale=2):
|
569 |
+
input_text = gr.Textbox(
|
570 |
+
label="Enter text to detect hate speech & mitigate",
|
571 |
+
lines=3
|
572 |
+
)
|
573 |
+
|
574 |
+
strategy = gr.Radio(
|
575 |
+
["Detection Only", "Guided", "Guided+Reflect", "Unguided"],
|
576 |
+
value="Detection Only",
|
577 |
+
label="Select Mitigation Mode",
|
578 |
+
container=True
|
579 |
+
)
|
580 |
+
|
581 |
+
analyze_btn = gr.Button("🔍 Detect & Mitigate", variant="primary", size="lg")
|
582 |
+
|
583 |
+
# with gr.Column(scale=1):
|
584 |
+
# gr.Markdown("""
|
585 |
+
# **🧪 Test Examples:**
|
586 |
+
|
587 |
+
# **🟢 Normal:**
|
588 |
+
# - "Hello, today's weather is nice."
|
589 |
+
# - "This movie was really fun."
|
590 |
+
|
591 |
+
# **🟡 Offensive:**
|
592 |
+
# - "How can you do such a stupid thing"
|
593 |
+
# - "Don't say such a stupid thing"
|
594 |
+
|
595 |
+
# **🟠 L1_hate (Implicit):**
|
596 |
+
# - "Those people are all the same"
|
597 |
+
# - "Prejudicial expression towards a specific group"
|
598 |
+
|
599 |
+
# **🔴 L2_hate (Explicit):**
|
600 |
+
# - "All women are useless"
|
601 |
+
# - "People with disabilities are a burden to society"
|
602 |
+
# """)
|
603 |
+
|
604 |
+
with gr.Row():
|
605 |
+
with gr.Column():
|
606 |
+
result_output = gr.Markdown(
|
607 |
+
label="Mitigation Button",
|
608 |
+
value="Input text and click the above button."
|
609 |
+
)
|
610 |
+
|
611 |
+
with gr.Column():
|
612 |
+
mitigation_output = gr.Markdown(
|
613 |
+
label="Mitigation Suggestion",
|
614 |
+
value="Based on the analysis result, mitigation suggestions will be provided."
|
615 |
+
)
|
616 |
+
|
617 |
+
# Event handlers
|
618 |
+
analyze_btn.click(
|
619 |
+
fn=service.detect_hate_speech,
|
620 |
+
inputs=[input_text, strategy],
|
621 |
+
outputs=[result_output, mitigation_output]
|
622 |
+
)
|
623 |
+
|
624 |
+
# Allow analysis via Enter key
|
625 |
+
input_text.submit(
|
626 |
+
fn=service.detect_hate_speech,
|
627 |
+
inputs=[input_text, strategy],
|
628 |
+
outputs=[result_output, mitigation_output]
|
629 |
+
)
|
630 |
+
|
631 |
+
# gr.Markdown("""
|
632 |
+
# ---
|
633 |
+
# **Model Information:**
|
634 |
+
# - Detection Model: KcELECTRA-base (Validation Accuracy: 67.67%)
|
635 |
+
# - Mitigation Model: Blossom LLM (llama-3.2-Korean-Bllossom-3B)
|
636 |
+
# - Training Data: K-HATERS Dataset
|
637 |
+
# """)
|
638 |
+
|
639 |
+
return demo
|
640 |
+
|
641 |
+
if __name__ == "__main__":
|
642 |
+
demo = create_demo()
|
643 |
+
demo.launch(
|
644 |
+
server_name="0.0.0.0",
|
645 |
+
server_port=7863,
|
646 |
+
share=True,
|
647 |
+
show_error=True
|
648 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio>=4.44.0
|
2 |
+
torch>=2.0.0
|
3 |
+
transformers>=4.30.0
|
4 |
+
bert-score>=0.3.13
|
5 |
+
numpy>=1.21.0
|
6 |
+
scikit-learn>=1.0.0
|
7 |
+
accelerate>=0.20.0
|
8 |
+
TorchCRF==1.1.0
|
9 |
+
|