strongeryongchao committed on
Commit
8efad3c
·
verified ·
1 Parent(s): 733a440

Upload clustering_utils.py

Browse files
Files changed (1) hide show
  1. clustering_utils.py +13 -8
clustering_utils.py CHANGED
@@ -1,16 +1,21 @@
1
  from sentence_transformers import SentenceTransformer
2
  import hdbscan
3
  from sklearn.metrics import silhouette_score, davies_bouldin_score
 
4
 
5
  model = SentenceTransformer("shibing624/text2vec-bge-large-chinese")
6
 
7
  def cluster_sentences(sentences):
  # Pre-commit revision of the function: the lines prefixed with "-" below
  # are the ones this commit removes.
  # It encodes `sentences` with the module-level `model` (asking for
  # normalized embeddings), clusters them with HDBSCAN (min_cluster_size=3)
  # and returns (labels, embeddings, {"silhouette": ..., "db": ...}).
  # Both scores are computed over every label exactly as returned by
  # fit_predict (no filtering), and fall back to -1 when there is a single
  # distinct label or when scoring raises any exception.
8
- embeddings = model.encode(sentences, normalize_embeddings=True)
9
- clusterer = hdbscan.HDBSCAN(min_cluster_size=3, prediction_data=True)
10
  labels = clusterer.fit_predict(embeddings)
11
- try:
12
- sil = silhouette_score(embeddings, labels) if len(set(labels)) > 1 else -1
13
- db = davies_bouldin_score(embeddings, labels) if len(set(labels)) > 1 else -1
14
- except Exception:
15
- sil, db = -1, -1
16
- return labels, embeddings, {"silhouette": sil, "db": db}
 
 
 
 
 
1
  from sentence_transformers import SentenceTransformer
2
  import hdbscan
3
  from sklearn.metrics import silhouette_score, davies_bouldin_score
4
+ import numpy as np
5
 
6
  model = SentenceTransformer("shibing624/text2vec-bge-large-chinese")
7
 
8
  def cluster_sentences(sentences):
9
+ embeddings = model.encode(sentences)
10
+ clusterer = hdbscan.HDBSCAN(min_cluster_size=2, metric='euclidean')
11
  labels = clusterer.fit_predict(embeddings)
12
+
13
+ valid_idxs = labels != -1
14
+ if np.sum(valid_idxs) > 1:
15
+ silhouette = silhouette_score(embeddings[valid_idxs], labels[valid_idxs])
16
+ db = davies_bouldin_score(embeddings[valid_idxs], labels[valid_idxs])
17
+ else:
18
+ silhouette, db = -1, -1
19
+
20
+ return labels, embeddings, {"silhouette": silhouette, "db": db}
21
+