Spaces:
Sleeping
Sleeping
File size: 8,112 Bytes
3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 4affd42 3d9242d 950dc1a 3d9242d 6a568bb 3d9242d 6a568bb 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 6a568bb 3d9242d e4ec800 950dc1a e4ec800 950dc1a e4ec800 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d 950dc1a 3d9242d e4ec800 3d9242d 950dc1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import json
import os
import re
import tempfile
from collections import defaultdict

import torch
from pyvis.network import Network
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)
# -------------------------------
# Model configuration
# -------------------------------
# Model names come from environment variables so deployments (e.g. on
# Hugging Face Spaces) can swap models without code changes.
NER_MODEL_NAME = os.environ.get("NER_MODEL_NAME", "ckiplab/bert-base-chinese-ner")
REL_MODEL_NAME = os.environ.get("REL_MODEL_NAME", "hfl/chinese-roberta-wwm-ext")
# -------------------------------
# Named-entity recognition (NER) model
# -------------------------------
try:
    ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
    # BUG FIX: a "ner" pipeline needs a token-classification head.
    # AutoModelForSequenceClassification yields one label per sentence and
    # breaks per-token entity aggregation.
    ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_NAME)
    ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
except Exception as e:
    print(f"NER模型加载失败: {e}")
    # Fallback so later references fail gracefully instead of raising NameError.
    ner_pipeline = None
# -------------------------------
# Person-relation classification model (RoBERTa)
# -------------------------------
try:
    rel_tokenizer = AutoTokenizer.from_pretrained(REL_MODEL_NAME)
    rel_model = AutoModelForSequenceClassification.from_pretrained(
        REL_MODEL_NAME,
        num_labels=6,  # must match the label maps below
        id2label={0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"},
        label2id={"夫妻": 0, "父子": 1, "朋友": 2, "师生": 3, "同事": 4, "其他": 5},
    )
    rel_model.eval()  # inference only — disable dropout etc.
except Exception as e:
    print(f"关系分类模型加载失败: {e}")
    # Fallbacks so the classification helpers hit their own except-handlers
    # (TypeError on None) instead of crashing with NameError.
    rel_tokenizer = None
    rel_model = None
# Label map for relation classification (mirrors the model's id2label above).
relation_id2label = {
    0: "夫妻", 1: "父子", 2: "朋友", 3: "师生", 4: "同事", 5: "其他"
}
# Label map for legal-risk analysis.
# NOTE(review): rel_model is configured with the six *relationship* labels
# above, yet classify_illegal_behavior maps its prediction index through this
# table. Unless REL_MODEL_NAME points at a model fine-tuned for legal-risk
# detection, those outputs are not meaningful — confirm model choice.
legal_id2label = {
    0: "无违法", 1: "赌博", 2: "毒品", 3: "色情", 4: "诈骗", 5: "暴力"
}
# -------------------------------
# Chat input parsing
# -------------------------------
def parse_input_file(file):
    """Parse an uploaded chat file into a list of {"sender", "message"} dicts.

    Two formats are supported, chosen by file extension (case-insensitive):
    - ``.json``: returned exactly as parsed by ``json.load``.
    - ``.txt`` : lines like ``2024-01-01 10:00 Alice: hello``; lines that do
      not match are silently skipped.

    Raises ValueError for unsupported extensions; parse errors are logged
    and re-raised.
    """
    try:
        # Case-insensitive extension check so ".TXT" / ".Json" also work.
        filename = file.name.lower()
        if filename.endswith(".json"):
            return json.load(file)
        elif filename.endswith(".txt"):
            content = file.read().decode("utf-8")
            # BUG FIX: the previous pattern's lazy timestamp group let the
            # time-of-day leak into the sender (e.g. sender "10:00 Alice").
            # "[^:]+" forbids colons in the sender, so the timestamp (which
            # contains ":") is consumed by the first group instead.
            line_re = re.compile(r"(\d{4}-\d{2}-\d{2}.*?)\s+([^:]+): (.*)")
            chat_data = []
            for line in content.strip().splitlines():
                match = line_re.match(line)
                if match:
                    _, sender, message = match.groups()
                    chat_data.append({"sender": sender, "message": message})
            return chat_data
        else:
            raise ValueError("不支持的文件格式,请上传JSON或TXT文件")
    except Exception as e:
        print(f"文件解析错误: {e}")
        raise
# -------------------------------
# Entity extraction
# -------------------------------
def extract_entities(text):
    """Return the unique person-entity mentions ("PER") found in *text*.

    Runs the module-level NER pipeline; on any failure an empty list is
    returned and the error is logged.
    """
    try:
        hits = ner_pipeline(text)
        unique_people = {hit["word"] for hit in hits if hit["entity_group"] == "PER"}
        return list(unique_people)
    except Exception as err:
        print(f"实体提取错误: {err}")
        return []
# -------------------------------
# Relation extraction (co-occurrence + BERT classification)
# -------------------------------
def extract_relations(chat_data, entities):
    """Build labelled, weighted edges between entities that co-occur in messages.

    Returns a list of ``(entity1, entity2, count, label)`` tuples, one per
    unordered entity pair, where *count* is the number of co-occurring
    messages and *label* comes from classify_relation_bert.
    """
    pair_stats = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []}))
    for entry in chat_data:
        message = entry["message"]
        present = [name for name in entities if name in message]
        # Record every co-occurring pair symmetrically.
        for idx, first in enumerate(present):
            for second in present[idx + 1:]:
                for a, b in ((first, second), (second, first)):
                    pair_stats[a][b]["count"] += 1
                    pair_stats[a][b]["contexts"].append(message)
    edges = []
    max_context_length = 500  # cap context fed to the classifier
    for a in pair_stats:
        for b in pair_stats[a]:
            if a >= b:
                continue  # emit each unordered pair exactly once
            joined = " ".join(pair_stats[a][b]["contexts"])
            if len(joined) > max_context_length:
                joined = joined[:max_context_length] + "..."
            label = classify_relation_bert(a, b, joined)
            edges.append((a, b, pair_stats[a][b]["count"], label))
    return edges
# -------------------------------
# Relation classification & legal-risk analysis
# -------------------------------
def classify_relation_bert(e1, e2, context):
    """Classify the relationship between *e1* and *e2* from their shared context.

    Returns a human-readable string "<label>(置信度 <p>)"; falls back to the
    "其他" label with zero confidence on any error.
    """
    try:
        text = f"{e1}和{e2}的关系是?{context}"
        encoded = rel_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            scores = rel_model(**encoded).logits
        predicted = torch.argmax(scores, dim=1).item()
        confidence = torch.softmax(scores, dim=1)[0, predicted].item()
        return f"{relation_id2label[predicted]}(置信度 {confidence:.2f})"
    except Exception as err:
        print(f"关系分类错误: {err}")
        return "其他(置信度 0.00)"
def classify_illegal_behavior(chat_context):
    """Classify legal risk (gambling, drugs, porn, fraud, violence) in chat text.

    Returns a human-readable verdict string; on any error a "未知" (unknown)
    verdict with zero confidence is returned.

    NOTE(review): this reuses rel_model, which is configured above with six
    *relationship* labels, yet the prediction index is mapped through
    legal_id2label here. Unless REL_MODEL_NAME points at a model fine-tuned
    for legal-risk detection, the output is not meaningful — confirm.
    """
    try:
        # Prompt-style framing; the encoder classifier does not "answer" the
        # question — it simply classifies the concatenated text.
        prompt = f"请分析以下聊天记录,判断是否涉及以下违法行为:赌博、毒品、色情、诈骗、暴力行为。\n聊天内容:{chat_context}\n请回答:"
        inputs = rel_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = rel_model(**inputs).logits
        # argmax over the six logits -> predicted class id; softmax for confidence
        pred = torch.argmax(logits, dim=1).item()
        probs = torch.nn.functional.softmax(logits, dim=1)
        confidence = probs[0, pred].item()
        return f"违法行为判断结果:{legal_id2label.get(pred, '未知')}(置信度 {confidence:.2f})"
    except Exception as e:
        print(f"法律风险分析错误: {e}")
        return "违法行为判断结果:未知(置信度 0.00)"
# -------------------------------
# Graph rendering
# -------------------------------
def draw_graph(entities, relations):
    """Render the entity-relation graph with pyvis and return it as HTML.

    *relations* is the edge list produced by extract_relations. On failure a
    small error HTML snippet is returned instead of raising.
    """
    try:
        g = Network(height="600px", width="100%", notebook=False)
        g.barnes_hut()  # force-directed physics layout
        for ent in entities:
            g.add_node(ent, label=ent)
        for e1, e2, weight, label in relations:
            g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)
        # BUG FIX: the temp file was previously left open while pyvis rewrote
        # the same path (fails on Windows) and was never deleted (disk leak).
        # Close the handle first, write explicitly, read back, then remove.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
        tmp_file.close()
        try:
            g.write_html(tmp_file.name)
            with open(tmp_file.name, 'r', encoding='utf-8') as f:
                return f.read()
        finally:
            os.unlink(tmp_file.name)
    except Exception as e:
        print(f"图谱绘制错误: {e}")
        return "<h3>图谱生成失败</h3><p>请检查输入数据是否有效</p>"
# -------------------------------
# Main pipeline
# -------------------------------
def analyze_chat(file):
    """End-to-end analysis: parse chat file, extract entities/relations, render graph.

    Returns a 4-tuple of strings: (entities, relations, graph HTML,
    per-message legal-risk verdicts). On early failure the first element
    carries the error text and the rest are empty.
    """
    if file is None:
        return "请上传聊天文件", "", "", ""
    try:
        content = parse_input_file(file)
    except Exception as e:
        return f"读取文件失败: {e}", "", "", ""
    full_text = "\n".join(entry["sender"] + ": " + entry["message"] for entry in content)
    entities = extract_entities(full_text)
    if not entities:
        return "未识别到任何人物实体", "", "", ""
    relations = extract_relations(content, entities)
    if not relations:
        return "未发现人物之间的关系", "", "", ""
    # Per-message legal-risk screening
    illegal_behavior_results = []
    for msg in content:
        illegal_behavior_results.append(classify_illegal_behavior(msg["message"]))
    graph_html = draw_graph(entities, relations)
    return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)