# chat-with-avd-doc / loader.py
# Uploaded to the Hugging Face Hub by rogerscuall via the
# huggingface_hub upload-folder workflow (commit 890d952, verified).
# /// script
# dependencies = [
#     "langchain",
#     "langchain_community",
#     "langchain_core",
#     "sentence-transformers",
#     "faiss-cpu",
# ]
# ///
"""
Enhanced loader script for creating FAISS vector database from Markdown documentation
with improved header metadata extraction.
"""
import os
import re
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
DOCS_DIR = "documentation"
DEVICE_DOCS_PATH = os.path.join(DOCS_DIR, "devices")
FABRIC_DOCS_PATH = os.path.join(DOCS_DIR, "fabric")
FAISS_INDEX_PATH = "faiss_index"
def extract_header_context(content, chunk_start_pos):
    """
    Return the markdown header hierarchy in effect at a given position.

    Walks every line of ``content`` before ``chunk_start_pos`` and tracks
    ATX headers (``#`` .. ``#####``), keeping keys ``header1``..``header5``.
    Encountering a header at level N discards any previously-seen headers
    at levels deeper than N, so the result is the breadcrumb for that spot.
    """
    headers = {}
    for raw_line in content[:chunk_start_pos].split('\n'):
        stripped = raw_line.strip()
        # Only ATX-style headers count; '#!' lines (shebangs) are skipped.
        if not stripped.startswith('#') or stripped.startswith('#!'):
            continue
        depth = len(stripped) - len(stripped.lstrip('#'))
        if not 1 <= depth <= 5:  # ignore level-6 headers and deeper
            continue
        headers[f'header{depth}'] = stripped.lstrip('#').strip()
        # A new header invalidates everything nested beneath it.
        for deeper in range(depth + 1, 6):
            headers.pop(f'header{deeper}', None)
    return headers
def enhance_chunk_metadata(chunk, original_content, chunk_position, file_metadata):
    """
    Build chunk-level metadata from file metadata plus header context.

    Copies ``file_metadata``, merges in the header hierarchy in effect at
    ``chunk_position`` within ``original_content``, and, when any headers
    exist, derives ``header_path`` (levels joined with " > ") and
    ``section_title`` (the most specific header). ``chunk`` itself is not
    inspected; it is accepted to keep the call signature uniform.
    """
    metadata = dict(file_metadata)  # never mutate the caller's dict
    metadata.update(extract_header_context(original_content, chunk_position))
    # Breadcrumb ordered from most general (header1) to most specific.
    breadcrumb = [
        metadata[f'header{level}']
        for level in range(1, 6)
        if f'header{level}' in metadata
    ]
    if breadcrumb:
        metadata['header_path'] = " > ".join(breadcrumb)
        metadata['section_title'] = breadcrumb[-1]
    return metadata
def load_markdown_documents_with_headers(file_paths):
    """
    Load markdown files and split them into chunks with header metadata.

    Args:
        file_paths: Iterable of paths to markdown files.

    Returns:
        list[Document]: One Document per chunk. Metadata always carries
        'source' (basename); device files ('DCX-' in the name) also get
        'device_name'; header context adds 'headerN', 'header_path' and
        'section_title' keys when headers precede the chunk.
    """
    # The splitter is loop-invariant: build it once, not once per file.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=200,
        separators=["\n## ", "\n### ", "\n#### ", "\n##### ", "\n\n", "\n", " ", ""]
    )
    all_documents = []
    for file_path in file_paths:
        print(f"Processing: {os.path.basename(file_path)}")
        # Read the raw markdown content
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        # Base metadata shared by every chunk cut from this file
        file_metadata = {
            'source': os.path.basename(file_path)
        }
        # Device docs are named like 'DCX-<device>.md'; record the device name
        if 'DCX-' in os.path.basename(file_path):
            file_metadata['device_name'] = os.path.basename(file_path).replace('.md', '')
        for chunk in text_splitter.split_text(content):
            # Locate the chunk in the original content so the header
            # hierarchy preceding it can be recovered.
            chunk_position = content.find(chunk)
            if chunk_position == -1:
                # Exact match can fail after separator/overlap trimming;
                # fall back to matching a short prefix of the chunk.
                chunk_position = content.find(chunk[:min(100, len(chunk))])
                if chunk_position == -1:
                    chunk_position = 0  # last resort: start of file
            enhanced_metadata = enhance_chunk_metadata(chunk, content, chunk_position, file_metadata)
            # Prefix device context so the embedded text carries the device
            # name. BUG FIX: the original wrote literal '\\n\\n' (escaped
            # backslashes) instead of a real blank line between the prefix
            # and the chunk.
            final_content = chunk
            if 'device_name' in enhanced_metadata:
                device_name = enhanced_metadata['device_name']
                if not chunk.strip().startswith(f"Device: {device_name}"):
                    final_content = f"Device: {device_name}\n\n{chunk}"
            all_documents.append(Document(page_content=final_content, metadata=enhanced_metadata))
    return all_documents
def create_vector_db():
    """
    Scan the documentation folders, load markdown files with enhanced
    header metadata, embed the chunks, and save a FAISS index to disk.

    Side effects: prints progress/debug output and writes the index to
    FAISS_INDEX_PATH. Returns early (None) when no files or chunks exist.
    """
    # Collect all markdown files; one loop over both documentation trees
    # replaces the original duplicated os.walk blocks.
    markdown_files = []
    for docs_path in (DEVICE_DOCS_PATH, FABRIC_DOCS_PATH):
        for root, _, files in os.walk(docs_path):
            for file in files:
                if file.endswith(".md"):
                    markdown_files.append(os.path.join(root, file))
    if not markdown_files:
        print("No markdown files found in the specified directories.")
        return
    print(f"Found {len(markdown_files)} markdown files to process.")
    # Load documents with enhanced header metadata
    documents = load_markdown_documents_with_headers(markdown_files)
    print(f"Created {len(documents)} document chunks with header metadata.")
    # Debug: show sample metadata from the first few chunks.
    # BUG FIX: the original printed literal '\\n' sequences (escaped
    # backslashes) instead of actual newlines.
    print("\nSample metadata from first 3 chunks:")
    for i, doc in enumerate(documents[:3]):
        print(f"\nChunk {i+1}:")
        print(f"  Source: {doc.metadata.get('source', 'Unknown')}")
        print(f"  Device: {doc.metadata.get('device_name', 'N/A')}")
        print(f"  Header Path: {doc.metadata.get('header_path', 'No headers')}")
        print(f"  Section Title: {doc.metadata.get('section_title', 'No section')}")
        print(f"  Content Preview: {doc.page_content[:100]}...")
    print("\nCreating FAISS vector database...")
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    print("Embeddings model loaded.")
    # Guard against an empty chunk list before building the index.
    if not documents:
        print("No documents to process for FAISS index.")
        return
    print("Creating FAISS index...")
    vector_db = FAISS.from_documents(documents, embeddings)
    print("FAISS index created.")
    # Persist the index for later retrieval use.
    vector_db.save_local(FAISS_INDEX_PATH)
    print(f"FAISS index saved to {FAISS_INDEX_PATH}")
    print(f"Total chunks in database: {len(documents)}")
# Build and persist the FAISS index when executed as a script.
if __name__ == "__main__":
    create_vector_db()