Spaces:
Sleeping
Sleeping
import json | |
import tempfile | |
import re | |
import os | |
from collections import defaultdict | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForSequenceClassification, | |
pipeline, | |
) | |
import torch | |
from pyvis.network import Network | |
# ------------------------------- | |
# 模型配置 | |
# ------------------------------- | |
# Model names are read from environment variables so that a deployment
# (e.g. on Hugging Face) can switch checkpoints without editing the code.
NER_MODEL_NAME = os.getenv("NER_MODEL_NAME", "ckiplab/bert-base-chinese-ner")
REL_MODEL_NAME = os.getenv("REL_MODEL_NAME", "hfl/chinese-roberta-wwm-ext")
# ------------------------------- | |
# 实体识别模型(NER) | |
# ------------------------------- | |
# Load the NER tokenizer and pipeline.
#
# The "ner" task needs a model with a token-classification head. The model is
# therefore handed to `pipeline` by name, so transformers picks the matching
# model class for the task instead of the sequence-classification class that
# was used here before.
try:
    ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_NAME)
    ner_pipeline = pipeline(
        "ner",
        model=NER_MODEL_NAME,
        tokenizer=ner_tokenizer,
        aggregation_strategy="simple",
    )
    # Keep the module-level `ner_model` name available for code that refers
    # to the underlying model directly.
    ner_model = ner_pipeline.model
except Exception as e:
    # Loading is best-effort: the failure is reported and the names above stay
    # undefined. A fallback model or stricter error handling can be added here.
    print(f"NER模型加载失败: {e}")
# ------------------------------- | |
# 人物关系分类模型(使用 RoBERTa) | |
# ------------------------------- | |
# Load the tokenizer and the six-way sequence classifier used for relation
# prediction. Both label maps are built from the same ordered label tuple.
try:
    rel_tokenizer = AutoTokenizer.from_pretrained(REL_MODEL_NAME)
    rel_model = AutoModelForSequenceClassification.from_pretrained(
        REL_MODEL_NAME,
        num_labels=6,  # must match the number of labels below
        id2label=dict(enumerate(("夫妻", "父子", "朋友", "师生", "同事", "其他"))),
        label2id=dict(zip(("夫妻", "父子", "朋友", "师生", "同事", "其他"), range(6))),
    )
    # Put the model in evaluation mode; it is only used for prediction here.
    rel_model.eval()
except Exception as err:
    # Loading is best-effort: report the failure and carry on. A fallback or
    # stricter error handling can be added here.
    print(f"关系分类模型加载失败: {err}")
# Label map for relation classification: class index -> relation name.
relation_id2label = dict(
    enumerate(["夫妻", "父子", "朋友", "师生", "同事", "其他"])
)
# Label map for the legal-risk analysis: class index -> category name.
legal_id2label = dict(
    enumerate(["无违法", "赌博", "毒品", "色情", "诈骗", "暴力"])
)
# ------------------------------- | |
# 聊天输入解析 | |
# ------------------------------- | |
# One chat line: "<timestamp> <sender>: <message>".
# The timestamp is a YYYY-MM-DD date, any non-space characters glued to it,
# and an optional space-separated clock time (HH:MM or HH:MM:SS). Matching the
# clock time as part of the timestamp keeps it out of the sender field.
_CHAT_LINE_RE = re.compile(
    r"(\d{4}-\d{2}-\d{2}\S*(?: \d{1,2}:\d{2}(?::\d{2})?)?) (.*?): (.*)"
)


def _read_upload(file):
    """Return the raw content (bytes or str) of an uploaded file.

    Uses ``file.read()`` when the object has one; otherwise treats
    ``file.name`` (or ``file`` itself) as a path and reads it in binary mode.
    """
    reader = getattr(file, "read", None)
    if callable(reader):
        return reader()
    with open(getattr(file, "name", file), "rb") as handle:
        return handle.read()


def parse_input_file(file):
    """Parse an uploaded chat log in JSON or TXT format.

    Args:
        file: A file-like object with a ``name`` attribute (binary or text
            mode), an object that only exposes ``name`` as a path, or a plain
            path string. The format is chosen from the file extension,
            case-insensitively.

    Returns:
        For ``.json``: the decoded JSON document, returned as-is.
        For ``.txt``: a list of ``{"sender": ..., "message": ...}`` dicts, one
        per line that matches "<timestamp> <sender>: <message>"; lines that do
        not match are skipped.

    Raises:
        ValueError: If the extension is neither ``.json`` nor ``.txt``.
        Exception: Any read or decode error is logged and re-raised unchanged.
    """
    try:
        filename = str(getattr(file, "name", file))
        lowered = filename.lower()
        if lowered.endswith(".json"):
            raw = _read_upload(file)
            # json.loads detects the encoding of bytes input itself; for text
            # input a leading BOM has to be removed before decoding.
            return json.loads(raw.lstrip("\ufeff") if isinstance(raw, str) else raw)
        if lowered.endswith(".txt"):
            raw = _read_upload(file)
            if isinstance(raw, (bytes, bytearray)):
                # utf-8-sig also accepts plain UTF-8 and drops a leading BOM,
                # which would otherwise break the match on the first line.
                content = raw.decode("utf-8-sig")
            else:
                content = raw.lstrip("\ufeff")
            chat_data = []
            for line in content.strip().splitlines():
                match = _CHAT_LINE_RE.match(line)
                if match:
                    _, sender, message = match.groups()
                    chat_data.append({"sender": sender, "message": message})
            return chat_data
        raise ValueError("不支持的文件格式,请上传JSON或TXT文件")
    except Exception as e:
        print(f"文件解析错误: {e}")
        raise
# ------------------------------- | |
# 实体提取函数 | |
# ------------------------------- | |
# Entity-group names treated as "person". Label naming differs between NER
# checkpoints, so both common spellings are accepted.
_PERSON_LABELS = frozenset({"PER", "PERSON"})
# Whitespace sitting between two CJK characters. Aggregated words may come
# back with the characters separated by spaces, which would stop the name from
# being found as a substring of the original message later on.
_CJK_GAP_RE = re.compile(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])")


def extract_entities(text):
    """Extract person names from text with the NER pipeline.

    The text is processed one line at a time, so a single over-long or
    otherwise failing line does not discard the entities found elsewhere.

    Args:
        text: Text to analyse; empty, blank or ``None`` input yields ``[]``.

    Returns:
        A sorted list of unique person names. Lines on which the pipeline
        fails are skipped; the last error is reported once via ``print``.
    """
    if not text or not text.strip():
        return []
    people = set()
    last_error = None
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            for r in ner_pipeline(line):
                if r["entity_group"] not in _PERSON_LABELS:
                    continue
                name = _CJK_GAP_RE.sub("", r["word"]).strip()
                if name:
                    people.add(name)
        except Exception as e:
            # Best-effort: remember the failure and keep going with the
            # remaining lines.
            last_error = e
    if last_error is not None:
        print(f"实体提取错误: {last_error}")
    # Sorted so the result does not depend on set iteration order.
    return sorted(people)
# ------------------------------- | |
# 关系抽取函数(共现 + BERT 分类) | |
# ------------------------------- | |
def extract_relations(chat_data, entities, max_context_length=500, classifier=None):
    """Build relation edges from entity co-occurrence in chat messages.

    Two entities are related when both occur (as substrings) in the same
    message. For every related pair, the messages they share are joined into
    one context string and passed to the classifier to obtain a label.

    Args:
        chat_data: Iterable of chat entries. Entries that are not dicts, or
            whose ``"message"`` is missing or not a string, are skipped.
        entities: Entity names to look for. Empty names and duplicates are
            ignored.
        max_context_length: Maximum number of characters of joined context
            given to the classifier; longer contexts are cut and suffixed
            with ``"..."``.
        classifier: Callable ``(e1, e2, context) -> label``. Defaults to
            ``classify_relation_bert``.

    Returns:
        A list of ``(e1, e2, count, label)`` tuples with ``e1 < e2``, where
        ``count`` is the number of messages mentioning both entities.
    """
    if classifier is None:
        classifier = classify_relation_bert
    # De-duplicate while keeping order: a repeated name would otherwise count
    # the same pair more than once per message.
    names = [name for name in dict.fromkeys(entities) if name]
    relations = defaultdict(lambda: defaultdict(lambda: {"count": 0, "contexts": []}))
    for entry in chat_data:
        msg = entry.get("message") if isinstance(entry, dict) else None
        if not isinstance(msg, str):
            continue
        found = [name for name in names if name in msg]
        for i, first in enumerate(found):
            for second in found[i + 1:]:
                # Record the pair in both directions, as the adjacency map is
                # symmetric.
                for src, dst in ((first, second), (second, first)):
                    stats = relations[src][dst]
                    stats["count"] += 1
                    stats["contexts"].append(msg)
    edges = []
    for e1, neighbours in relations.items():
        for e2, stats in neighbours.items():
            # Each pair is stored twice; emit it once, in sorted order.
            if not e1 < e2:
                continue
            context_text = " ".join(stats["contexts"])
            if len(context_text) > max_context_length:
                context_text = context_text[:max_context_length] + "..."
            label = classifier(e1, e2, context_text)
            edges.append((e1, e2, stats["count"], label))
    return edges
# ------------------------------- | |
# Relation classification (BERT) and legal-risk analysis (pornography, gambling, drugs, etc.) functions
# ------------------------------- | |
def classify_relation_bert(e1, e2, context):
    """Predict the relation between two people with the relation classifier.

    A question naming both people, followed by the context, is tokenized
    (truncated to 512 tokens) and scored by ``rel_model``.

    Returns:
        The predicted label from ``relation_id2label`` with its softmax
        confidence, or the "其他" label with confidence 0.00 if anything fails.
    """
    try:
        question = f"{e1}和{e2}的关系是?{context}"
        encoded = rel_tokenizer(
            question,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        # Prediction only, so no gradients are tracked.
        with torch.no_grad():
            scores = rel_model(**encoded).logits
        best = scores.argmax(dim=1).item()
        confidence = torch.softmax(scores, dim=1)[0, best].item()
        return f"{relation_id2label[best]}(置信度 {confidence:.2f})"
    except Exception as err:
        print(f"关系分类错误: {err}")
        return "其他(置信度 0.00)"
def classify_illegal_behavior(chat_context):
    """Assess the legal risk of a piece of chat content.

    The content is wrapped in an instruction prompt, tokenized (truncated to
    512 tokens) and scored; the top class index is looked up in
    ``legal_id2label``.

    NOTE(review): this reuses ``rel_tokenizer`` / ``rel_model``, the same
    objects the relation classifier uses - confirm that this is intended.

    Returns:
        A result string with the predicted category and its softmax
        confidence, or the "未知" category with confidence 0.00 on failure.
    """
    try:
        instruction = f"请分析以下聊天记录,判断是否涉及以下违法行为:赌博、毒品、色情、诈骗、暴力行为。\n聊天内容:{chat_context}\n请回答:"
        encoded = rel_tokenizer(
            instruction,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        # Prediction only, so no gradients are tracked.
        with torch.no_grad():
            scores = rel_model(**encoded).logits
        best = scores.argmax(dim=1).item()
        confidence = torch.softmax(scores, dim=1)[0, best].item()
        category = legal_id2label.get(best, '未知')
        return f"违法行为判断结果:{category}(置信度 {confidence:.2f})"
    except Exception as err:
        print(f"法律风险分析错误: {err}")
        return "违法行为判断结果:未知(置信度 0.00)"
# ------------------------------- | |
# 图谱绘制 | |
# ------------------------------- | |
def draw_graph(entities, relations):
    """Render the relationship graph and return it as an HTML string.

    Args:
        entities: Node names; each becomes one labelled node.
        relations: Iterable of ``(e1, e2, weight, label)`` edges. ``weight``
            is used as the edge value and shown in the edge tooltip.

    Returns:
        The generated HTML document, or a short error snippet if the graph
        could not be produced.
    """
    tmp_path = None
    try:
        g = Network(height="600px", width="100%", notebook=False)
        g.barnes_hut()
        for ent in entities:
            g.add_node(ent, label=ent)
        for e1, e2, weight, label in relations:
            g.add_edge(e1, e2, value=weight, title=f"{label}(互动{weight}次)", label=label)
        # Only a unique path is needed; close the handle right away so it is
        # not leaked and the file can be rewritten by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmp_file:
            tmp_path = tmp_file.name
        # save_graph only writes the HTML file, which is all that is needed
        # here since the markup is returned to the caller.
        g.save_graph(tmp_path)
        with open(tmp_path, "r", encoding="utf-8") as f:
            return f.read()
    except Exception as e:
        print(f"图谱绘制错误: {e}")
        return "<h3>图谱生成失败</h3><p>请检查输入数据是否有效</p>"
    finally:
        # The HTML has been read (or generation failed); remove the temp file.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Deliberate best-effort cleanup: a leftover temp file must
                # not turn a successful render into a failure.
                pass
# ------------------------------- | |
# 主流程函数 | |
# ------------------------------- | |
def _is_valid_chat(content):
    """Return True if content is a list of dicts with string sender/message."""
    if not isinstance(content, list):
        return False
    return all(
        isinstance(entry, dict)
        and isinstance(entry.get("sender"), str)
        and isinstance(entry.get("message"), str)
        for entry in content
    )


def analyze_chat(file):
    """Run the full analysis on an uploaded chat log.

    Steps: parse the file, extract person entities from the whole transcript,
    derive relations between them, run the legal-risk check on every message
    and draw the relationship graph.

    Args:
        file: The uploaded chat file, or ``None`` if nothing was uploaded.

    Returns:
        A 4-tuple ``(entities, relations, graph_html, legal_report)`` of
        strings. When the analysis cannot proceed, the first element carries
        the explanation and the other three are empty strings.
    """
    if file is None:
        return "请上传聊天文件", "", "", ""
    try:
        content = parse_input_file(file)
    except Exception as e:
        return f"读取文件失败: {e}", "", "", ""
    # A JSON upload is returned by the parser as-is, so its shape has to be
    # checked before entries are indexed below.
    if not _is_valid_chat(content):
        return "聊天记录格式错误: 每条记录需包含 sender 和 message 字段", "", "", ""
    text = "\n".join(f"{entry['sender']}: {entry['message']}" for entry in content)
    entities = extract_entities(text)
    if not entities:
        return "未识别到任何人物实体", "", "", ""
    relations = extract_relations(content, entities)
    if not relations:
        return "未发现人物之间的关系", "", "", ""
    # Legal-risk analysis, one result line per message.
    illegal_behavior_results = [classify_illegal_behavior(msg["message"]) for msg in content]
    graph_html = draw_graph(entities, relations)
    return str(entities), str(relations), graph_html, "\n".join(illegal_behavior_results)