strongeryongchao committed on
Commit
eb5d45e
·
verified ·
1 Parent(s): c45bf71

Upload viz_utils.py

Browse files
Files changed (1) hide show
  1. viz_utils.py +20 -8
viz_utils.py CHANGED
@@ -5,36 +5,48 @@ from collections import defaultdict
5
  import random
6
 
7
  def color_for_label(label):
8
- random.seed(label + 1000)
 
 
 
 
 
 
9
  return f"rgb({random.randint(50,200)}, {random.randint(50,200)}, {random.randint(50,200)})"
10
 
11
  def generate_force_graph(sentences, labels):
12
  nodes = []
13
  links = []
14
- label_map = {}
 
15
  for i, (s, l) in enumerate(zip(sentences, labels)):
16
  color = color_for_label(l)
17
- nodes.append({"name": s, "symbolSize": 10, "category": int(l), "itemStyle": {"color": color}})
18
- label_map.setdefault(l, []).append(i)
19
 
20
  for group in label_map.values():
 
 
21
  for i in group:
 
22
  for j in group:
23
  if i < j:
24
  links.append({"source": sentences[i], "target": sentences[j]})
 
 
 
25
  return {"type": "force", "nodes": nodes, "links": links}
26
 
27
  def generate_bubble_chart(sentences, labels):
28
  counts = defaultdict(int)
29
  for l in labels:
30
  counts[l] += 1
31
- data = [{"name": f"簇{l}", "value": v, "itemStyle": {"color": color_for_label(l)}} for l, v in counts.items()]
32
  return {"type": "bubble", "series": [{"type": "scatter", "data": data}]}
33
 
34
  def generate_umap_plot(embeddings, labels):
35
- reducer = umap.UMAP(n_components=2)
36
  umap_emb = reducer.fit_transform(embeddings)
37
  scaled = MinMaxScaler().fit_transform(umap_emb)
38
- data = [{"x": float(x), "y": float(y), "label": int(l), "itemStyle": {"color": color_for_label(l)}}
39
- for (x, y), l in zip(scaled, labels)]
40
  return {"type": "scatter", "series": [{"data": data}]}
 
5
  import random
6
 
7
  def color_for_label(label):
8
+ try:
9
+ label_int = int(label)
10
+ except:
11
+ label_int = -1
12
+ if label_int < 0:
13
+ return "rgb(150,150,150)" # 噪声点(-1)用灰色
14
+ random.seed(label_int + 1000)
15
  return f"rgb({random.randint(50,200)}, {random.randint(50,200)}, {random.randint(50,200)})"
16
 
17
  def generate_force_graph(sentences, labels):
18
  nodes = []
19
  links = []
20
+ label_map = defaultdict(list)
21
+
22
  for i, (s, l) in enumerate(zip(sentences, labels)):
23
  color = color_for_label(l)
24
+ nodes.append({"name": s, "symbolSize": 10, "category": int(l) if l >=0 else 0, "itemStyle": {"color": color}})
25
+ label_map[l].append(i)
26
 
27
  for group in label_map.values():
28
+ # 可选:限制边数,避免边太多
29
+ max_edges_per_node = 10
30
  for i in group:
31
+ connected = 0
32
  for j in group:
33
  if i < j:
34
  links.append({"source": sentences[i], "target": sentences[j]})
35
+ connected += 1
36
+ if connected >= max_edges_per_node:
37
+ break
38
  return {"type": "force", "nodes": nodes, "links": links}
39
 
40
  def generate_bubble_chart(sentences, labels):
41
  counts = defaultdict(int)
42
  for l in labels:
43
  counts[l] += 1
44
+ data = [{"name": f"簇{l}" if l >=0 else "噪声", "value": v, "itemStyle": {"color": color_for_label(l)}} for l, v in counts.items()]
45
  return {"type": "bubble", "series": [{"type": "scatter", "data": data}]}
46
 
47
  def generate_umap_plot(embeddings, labels):
48
+ reducer = umap.UMAP(n_components=2, random_state=42)
49
  umap_emb = reducer.fit_transform(embeddings)
50
  scaled = MinMaxScaler().fit_transform(umap_emb)
51
+ data = [{"x": float(x), "y": float(y), "label": int(l), "itemStyle": {"color": color_for_label(l)}} for (x, y), l in zip(scaled, labels)]
 
52
  return {"type": "scatter", "series": [{"data": data}]}