sentcluster / clustering_utils.py
strongeryongchao's picture
Upload clustering_utils.py
8efad3c verified
raw
history blame contribute delete
754 Bytes
from sentence_transformers import SentenceTransformer
import hdbscan
from sklearn.metrics import silhouette_score, davies_bouldin_score
import numpy as np
# Checkpoint identifier handed to SentenceTransformer; kept as a named
# constant so the model choice is visible in one place.
_EMBEDDING_MODEL_NAME = "shibing624/text2vec-bge-large-chinese"

# Shared encoder instance, created once when the module is imported and
# reused by every call in this module.
model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
def cluster_sentences(sentences):
    """Embed sentences, cluster them with HDBSCAN, and score the result.

    Args:
        sentences: Sequence of sentence strings; passed unchanged to
            ``model.encode``.

    Returns:
        A tuple ``(labels, embeddings, metrics)``:

        * ``labels`` -- one cluster label per sentence; ``-1`` marks points
          HDBSCAN treated as noise.
        * ``embeddings`` -- the output of ``model.encode(sentences)``.
        * ``metrics`` -- dict with keys ``"silhouette"`` and ``"db"``
          (Davies-Bouldin), computed over non-noise points only. Both are
          ``-1`` when the scores are undefined, i.e. when fewer than two
          distinct clusters were found.
    """
    embeddings = model.encode(sentences)

    # With min_cluster_size=2 no cluster can exist for fewer than two
    # sentences, so report everything as noise instead of invoking HDBSCAN
    # on an empty or single-sample input.
    if len(sentences) < 2:
        labels = np.full(len(sentences), -1, dtype=int)
        return labels, embeddings, {"silhouette": -1, "db": -1}

    clusterer = hdbscan.HDBSCAN(min_cluster_size=2, metric='euclidean')
    labels = clusterer.fit_predict(embeddings)

    # Score only the points that were assigned to a cluster.
    valid_idxs = labels != -1
    n_valid = int(np.sum(valid_idxs))
    n_clusters = np.unique(labels[valid_idxs]).size

    # Both scores require 2 <= n_labels <= n_samples - 1 and raise
    # ValueError otherwise. Counting non-noise points alone is not enough:
    # a single discovered cluster has many points but only one label.
    if 2 <= n_clusters < n_valid:
        silhouette = silhouette_score(embeddings[valid_idxs], labels[valid_idxs])
        db = davies_bouldin_score(embeddings[valid_idxs], labels[valid_idxs])
    else:
        silhouette, db = -1, -1

    return labels, embeddings, {"silhouette": silhouette, "db": db}