import gradio as gr
import pandas as pd
import json
import random
import tempfile
from collections import defaultdict

from sentence_transformers import SentenceTransformer
import hdbscan
from sklearn.metrics import silhouette_score, davies_bouldin_score
import numpy as np
import umap
from sklearn.preprocessing import MinMaxScaler
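
# Note: besides gradio and pandas, this script depends on sentence-transformers,
# hdbscan, umap-learn, and scikit-learn:
#   pip install gradio pandas sentence-transformers hdbscan umap-learn scikit-learn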

# Load the model once at module level so it is not reloaded on every request
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def color_for_label(label):
    """Return a deterministic RGB color for a cluster label; gray for noise."""
    try:
        label_int = int(label)
    except (ValueError, TypeError):
        label_int = -1
    if label_int < 0:
        return "rgb(150,150,150)"  # noise points (HDBSCAN labels them -1)
    # Use a local RNG seeded by the label so colors are stable across calls
    # without disturbing the global random state
    rng = random.Random(label_int + 1000)
    return f"rgb({rng.randint(50,200)}, {rng.randint(50,200)}, {rng.randint(50,200)})"

def cluster_sentences(sentences):
    embeddings = model.encode(sentences)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=2, metric='euclidean')
    labels = clusterer.fit_predict(embeddings)

    # Silhouette and Davies-Bouldin are only defined for two or more clusters,
    # so compute them over the non-noise points and fall back to -1 otherwise
    valid_idxs = labels != -1
    valid_labels = labels[valid_idxs]
    if np.sum(valid_idxs) > 1 and np.unique(valid_labels).size > 1:
        silhouette = silhouette_score(embeddings[valid_idxs], valid_labels)
        db = davies_bouldin_score(embeddings[valid_idxs], valid_labels)
    else:
        silhouette, db = -1, -1

    return labels, embeddings, {"silhouette": silhouette, "db": db}
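
# A minimal usage sketch (hypothetical sentences), assuming the model above:
#   labels, embeddings, scores = cluster_sentences(["nice weather today", "the weather is good", "I like programming"])
#   labels has shape (3,); embeddings has shape (3, 384), since all-MiniLM-L6-v2
#   produces 384-dimensional sentence vectors; scores holds both metrics.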

def generate_force_graph(sentences, labels):
    nodes = []
    links = []
    label_map = defaultdict(list)

    for i, (s, l) in enumerate(zip(sentences, labels)):
        color = color_for_label(l)
        # Noise points (-1) are folded into category 0; the gray color from
        # color_for_label still distinguishes them visually
        nodes.append({"name": s, "symbolSize": 10, "category": int(l) if l >= 0 else 0, "itemStyle": {"color": color}})
        label_map[l].append(i)

    # Fully connecting each cluster is O(n^2), so cap the number of edges
    # that any single node initiates
    max_edges_per_node = 10
    for group in label_map.values():
        for i in group:
            connected = 0
            for j in group:
                if i < j:
                    links.append({"source": sentences[i], "target": sentences[j]})
                    connected += 1
                    if connected >= max_edges_per_node:
                        break
    return {"type": "force", "nodes": nodes, "links": links}

def generate_bubble_chart(sentences, labels):
    # Bubble size reflects how many sentences each cluster contains
    counts = defaultdict(int)
    for l in labels:
        counts[l] += 1
    data = [{"name": f"Cluster {l}" if l >= 0 else "Noise", "value": v, "itemStyle": {"color": color_for_label(l)}} for l, v in counts.items()]
    return {"type": "bubble", "series": [{"type": "scatter", "data": data}]}

def generate_umap_plot(embeddings, labels):
    # Project the embeddings to 2D with UMAP and rescale to [0, 1] for plotting
    reducer = umap.UMAP(n_components=2, random_state=42)
    umap_emb = reducer.fit_transform(embeddings)
    scaled = MinMaxScaler().fit_transform(umap_emb)
    data = [{"x": float(x), "y": float(y), "label": int(l), "itemStyle": {"color": color_for_label(l)}} for (x, y), l in zip(scaled, labels)]
    return {"type": "scatter", "series": [{"data": data}]}
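
# The three generate_* helpers emit ECharts-style option fragments. The app
# below only displays them as JSON; rendering them in an actual chart widget
# is left to the consumer of that JSON.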

def process(text_input, file_obj):
    # Collect sentences from both the uploaded file and the textbox
    sentences = []

    # Read the uploaded .txt file, one sentence per line
    if file_obj is not None:
        try:
            # file_obj is a tempfile.NamedTemporaryFile; open it via file_obj.name
            with open(file_obj.name, "r", encoding="utf-8") as f:
                content = f.read()
            lines = content.strip().splitlines()
            sentences.extend([line.strip() for line in lines if line.strip()])
        except Exception as e:
            return f"❌ Failed to read file: {str(e)}", None, None, None, None, None

    # Handle textbox input
    if text_input:
        lines = text_input.strip().splitlines()
        sentences.extend([line.strip() for line in lines if line.strip()])

    # Deduplicate while preserving input order
    sentences = list(dict.fromkeys(sentences))

    if len(sentences) < 2:
        return "⚠️ Please enter at least two valid sentences to cluster", None, None, None, None, None

    # Cluster
    labels, embeddings, scores = cluster_sentences(sentences)

    # Build the result table
    df = pd.DataFrame({"Sentence": sentences, "Cluster ID": labels})

    force_json = generate_force_graph(sentences, labels)
    bubble_json = generate_bubble_chart(sentences, labels)
    umap_json = generate_umap_plot(embeddings, labels)

    # gr.File expects a file path, so write the CSV to a temporary file;
    # utf-8-sig keeps Excel happy with non-ASCII text
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, encoding="utf-8-sig") as f:
        df.to_csv(f, index=False)
        csv_path = f.name

    return (
        f"✅ Silhouette: {scores['silhouette']:.4f}, DB: {scores['db']:.4f}",
        df,
        json.dumps(force_json, ensure_ascii=False, indent=2),
        json.dumps(bubble_json, ensure_ascii=False, indent=2),
        json.dumps(umap_json, ensure_ascii=False, indent=2),
        csv_path
    )

with gr.Blocks() as demo:
    gr.Markdown("# Chinese Sentence Semantic Clustering Demo")

    with gr.Row():
        text_input = gr.Textbox(label="Enter sentences (one per line)", lines=8)
        file_input = gr.File(label="Upload a text file (.txt)", file_types=['.txt'])

    btn = gr.Button("Run clustering")

    output_score = gr.Textbox(label="Clustering metrics", interactive=False)
    output_table = gr.Dataframe(headers=["Sentence", "Cluster ID"], interactive=False)
    output_force = gr.JSON(label="Force-directed graph data")
    output_bubble = gr.JSON(label="Bubble chart data")
    output_umap = gr.JSON(label="UMAP 2-D data")
    output_csv = gr.File(label="Export CSV")

    btn.click(
        fn=process,
        inputs=[text_input, file_input],
        outputs=[output_score, output_table, output_force, output_bubble, output_umap, output_csv]
    )

demo.launch()
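
# Running this file with `python app.py` starts a local Gradio server; the
# first run downloads the all-MiniLM-L6-v2 weights from the Hugging Face Hub,
# so expect some startup delay.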