# rag-medical / semantic_chunking.py
# (Hugging Face upload metadata: baderanas, "Upload 12 files", commit cdf244e)
import numpy as np
from sentence_transformers import SentenceTransformer
# Load the SBERT model once at import time; reused by every get_embedding() call.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
def hybrid_split(text: str, max_len: int = 1024) -> list[str]:
    """
    Split text into chunks, respecting sentence boundaries when possible.

    Sentences longer than ``max_len`` are hard-split into ``max_len``-sized
    pieces, so no returned chunk ever exceeds the limit.

    Args:
        text: The text to split.
        max_len: Maximum length (in characters) of each chunk.

    Returns:
        List of non-empty text chunks (empty list for empty/whitespace input).
    """
    import re

    # Normalize: drop carriage returns, flatten newlines into spaces.
    text = text.replace("\r", "").replace("\n", " ").strip()
    if not text:
        # re.split on "" would yield [""] and produce a bogus empty chunk.
        return []

    # Split after sentence-ending punctuation followed by whitespace.
    sentences = re.split(r"(?<=[.!?])\s+", text)

    chunks: list[str] = []
    current_chunk = ""
    for sentence in sentences:
        if len(sentence) > max_len:
            # Flush the in-progress chunk first (the original skipped this,
            # interleaving an oversized sentence ahead of buffered text).
            if current_chunk:
                chunks.append(current_chunk)
                current_chunk = ""
            # Hard-split the oversized sentence so no chunk exceeds max_len
            # (the original emitted it whole, violating the limit).
            for start in range(0, len(sentence), max_len):
                chunks.append(sentence[start:start + max_len])
        elif len(current_chunk) + len(sentence) + 1 > max_len:
            # Chunk is full: emit it and START THE NEXT ONE WITH THIS SENTENCE.
            # (The original reset current_chunk to "" here, silently dropping
            # the sentence that triggered the overflow — data loss.)
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence
        else:
            # Sentence fits: append it to the current chunk.
            if current_chunk:
                current_chunk += " " + sentence
            else:
                current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Args:
        vec1: First vector (array-like).
        vec2: Second vector (array-like, same length).

    Returns:
        Cosine similarity in [-1, 1]; 0.0 when either vector has zero
        magnitude (the original divided by zero there, producing NaN and
        a NumPy RuntimeWarning).
    """
    norm_vec1 = np.linalg.norm(vec1)
    norm_vec2 = np.linalg.norm(vec2)
    if norm_vec1 == 0.0 or norm_vec2 == 0.0:
        # Zero vector has no direction; define similarity as 0 rather than NaN.
        return 0.0
    return np.dot(vec1, vec2) / (norm_vec1 * norm_vec2)
def get_embedding(text):
    """Embed *text* with the module-level SBERT model.

    Returns the sentence embedding as a NumPy array
    (``convert_to_numpy=True``).
    """
    vector = embedding_model.encode(text, convert_to_numpy=True)
    return vector
def semantic_chunking(text, threshold=0.75, max_chunk_size=8191):
    """
    Split text into semantic chunks based on adjacent-sentence similarity.

    Sentences are embedded with SBERT and merged into the current chunk while
    consecutive embeddings remain similar; a new chunk starts when similarity
    drops below *threshold* or the chunk would exceed *max_chunk_size*.

    Args:
        text: The input text.
        threshold: Similarity cutoff — lower = more splits, higher = fewer.
        max_chunk_size: Maximum size of each chunk in characters.

    Returns:
        List of chunk strings; empty list for empty/whitespace-only input
        (the original raised IndexError on ``sentences[0]`` in that case).
    """
    text = text.replace("\n", " ").replace("\r", " ").strip()
    sentences = hybrid_split(text)
    if not sentences:
        # Guard the sentences[0] access below against empty input.
        return []
    embeddings = [get_embedding(sent) for sent in sentences]
    chunks = []
    current_chunk = [sentences[0]]
    for i in range(1, len(sentences)):
        sim = cosine_similarity(embeddings[i - 1], embeddings[i])
        # Start a new chunk on a semantic break or when the size cap is hit.
        would_overflow = len(" ".join(current_chunk + [sentences[i]])) > max_chunk_size
        if sim < threshold or would_overflow:
            chunks.append(" ".join(current_chunk))
            current_chunk = [sentences[i]]
        else:
            current_chunk.append(sentences[i])
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks