# NOTE(review): stray "Spaces: / Sleeping" page text removed here — it was an
# artifact of copying this file out of a hosted web UI and is not valid Python.
# /// script
# dependencies = [
# "langchain",
# "langchain_community",
# "langchain_core",
# ]
# ///
""" | |
Enhanced loader script for creating FAISS vector database from Markdown documentation | |
with improved header metadata extraction. | |
""" | |
import os | |
import re | |
from langchain_community.document_loaders import UnstructuredMarkdownLoader | |
from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.documents import Document | |
# Root folder that holds all markdown documentation.
DOCS_DIR = "documentation"
# Per-device documentation (device files are named like "DCX-*.md" — see load step).
DEVICE_DOCS_PATH = os.path.join(DOCS_DIR, "devices")
# Fabric-level documentation.
FABRIC_DOCS_PATH = os.path.join(DOCS_DIR, "fabric")
# Directory where the built FAISS index is persisted.
FAISS_INDEX_PATH = "faiss_index"
def extract_header_context(content, chunk_start_pos):
    """
    Walk the markdown text preceding ``chunk_start_pos`` and return the
    ATX-header hierarchy in effect at that position.

    Returns a dict mapping ``'header1'`` .. ``'header5'`` to the most recent
    header text seen at each level; whenever a header is encountered, any
    deeper levels recorded earlier are discarded.
    """
    context = {}
    for raw_line in content[:chunk_start_pos].split('\n'):
        stripped = raw_line.strip()
        # Only ATX headers; a "#!" line (shebang) is deliberately skipped.
        if not stripped.startswith('#') or stripped.startswith('#!'):
            continue
        # Depth = number of leading '#' characters.
        depth = 0
        while depth < len(stripped) and stripped[depth] == '#':
            depth += 1
        if not 1 <= depth <= 5:  # ignore h6 and deeper
            continue
        context[f'header{depth}'] = stripped[depth:].strip()
        # A new header invalidates everything nested beneath it.
        for deeper in range(depth + 1, 6):
            context.pop(f'header{deeper}', None)
    return context
def enhance_chunk_metadata(chunk, original_content, chunk_position, file_metadata):
    """
    Build per-chunk metadata: file-level metadata plus the header hierarchy
    in effect at ``chunk_position``, with a derived ``header_path``
    (levels joined by " > ") and ``section_title`` (deepest header).

    Note: ``chunk`` itself is accepted for interface compatibility but the
    metadata is derived from its position, not its text.
    """
    metadata = dict(file_metadata)
    metadata.update(extract_header_context(original_content, chunk_position))
    # Collect header1..header5 in order to form the breadcrumb trail.
    trail = []
    for level in range(1, 6):
        key = f'header{level}'
        if key in metadata:
            trail.append(metadata[key])
    if trail:
        metadata['header_path'] = " > ".join(trail)
        metadata['section_title'] = trail[-1]  # deepest (most specific) header
    return metadata
def load_markdown_documents_with_headers(file_paths):
    """
    Load markdown files, split each into overlapping chunks, and return
    ``Document`` objects whose metadata carries the header context
    (header1..header5, header_path, section_title) of each chunk.

    Fixes vs. previous version:
    - "Device: ..." prefix now uses real newlines ("\n\n"); the old
      double-escaped "\\n\\n" injected literal backslash-n text.
    - Chunk positions are searched from a monotonically advancing offset,
      so repeated chunk text no longer maps every duplicate to the first
      occurrence in the file.
    - The splitter is configured once instead of once per file.
    """
    all_documents = []
    # Splitter configuration is file-invariant; build it a single time.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=200,
        separators=["\n## ", "\n### ", "\n#### ", "\n##### ", "\n\n", "\n", " ", ""],
    )
    for file_path in file_paths:
        base_name = os.path.basename(file_path)
        print(f"Processing: {base_name}")
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        # File-level metadata shared by every chunk of this file.
        file_metadata = {'source': base_name}
        # Files named like "DCX-*.md" are device documents.
        if 'DCX-' in base_name:
            file_metadata['device_name'] = base_name.replace('.md', '')
        chunks = text_splitter.split_text(content)
        # Chunks come back in document order; advancing this offset keeps
        # duplicate chunk text from all resolving to the first occurrence.
        search_from = 0
        for chunk in chunks:
            chunk_position = content.find(chunk, search_from)
            if chunk_position == -1:
                # Overlap can make a chunk start slightly before search_from.
                chunk_position = content.find(chunk)
            if chunk_position == -1:
                # Splitter may have normalized whitespace; try a short prefix.
                chunk_position = content.find(chunk[:min(100, len(chunk))])
            if chunk_position == -1:
                chunk_position = 0  # last resort: attribute to file start
            else:
                search_from = max(search_from, chunk_position)
            # Attach header-hierarchy metadata for this position.
            enhanced_metadata = enhance_chunk_metadata(chunk, content, chunk_position, file_metadata)
            # Prefix device chunks with their device name so the embedding
            # carries the device context even when the chunk text doesn't.
            final_content = chunk
            if 'device_name' in enhanced_metadata:
                device_name = enhanced_metadata['device_name']
                if not chunk.strip().startswith(f"Device: {device_name}"):
                    final_content = f"Device: {device_name}\n\n{chunk}"
            all_documents.append(Document(page_content=final_content, metadata=enhanced_metadata))
    return all_documents
def create_vector_db():
    """
    Scan the device and fabric documentation folders, load all markdown
    files with header-aware chunk metadata, embed the chunks, and save a
    FAISS index to ``FAISS_INDEX_PATH``.

    Fixes vs. previous version:
    - Debug prints used double-escaped "\\n" and emitted a literal
      backslash-n; they now print real blank lines.
    - The empty-documents guard runs before the (slow) embedding model is
      loaded instead of after.
    - The two identical directory walks are folded into one loop.
    """
    markdown_files = []
    # Collect all markdown files from both documentation trees.
    for docs_path in (DEVICE_DOCS_PATH, FABRIC_DOCS_PATH):
        for root, _, files in os.walk(docs_path):
            for file in files:
                if file.endswith(".md"):
                    markdown_files.append(os.path.join(root, file))
    if not markdown_files:
        print("No markdown files found in the specified directories.")
        return
    print(f"Found {len(markdown_files)} markdown files to process.")
    # Load documents with enhanced header metadata.
    documents = load_markdown_documents_with_headers(markdown_files)
    print(f"Created {len(documents)} document chunks with header metadata.")
    if not documents:
        # Bail out before paying the cost of loading the embedding model.
        print("No documents to process for FAISS index.")
        return
    # Debug: show the metadata produced for the first few chunks.
    print("\nSample metadata from first 3 chunks:")
    for i, doc in enumerate(documents[:3]):
        print(f"\nChunk {i+1}:")
        print(f"  Source: {doc.metadata.get('source', 'Unknown')}")
        print(f"  Device: {doc.metadata.get('device_name', 'N/A')}")
        print(f"  Header Path: {doc.metadata.get('header_path', 'No headers')}")
        print(f"  Section Title: {doc.metadata.get('section_title', 'No section')}")
        print(f"  Content Preview: {doc.page_content[:100]}...")
    print("\nCreating FAISS vector database...")
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    print("Embeddings model loaded.")
    print("Creating FAISS index...")
    vector_db = FAISS.from_documents(documents, embeddings)
    print("FAISS index created.")
    # Persist the index to disk for later retrieval use.
    vector_db.save_local(FAISS_INDEX_PATH)
    print(f"FAISS index saved to {FAISS_INDEX_PATH}")
    print(f"Total chunks in database: {len(documents)}")
# Build the FAISS index only when executed as a script, not on import.
if __name__ == "__main__":
    create_vector_db()