# Source snapshot: revision 3c18172 (original file size: 4,901 bytes).
#!/usr/bin/env python3
from rag_utils import RAGSystem
import argparse
import os
import logging
import shutil
# Configure the root logger (INFO level, timestamped records) and create the
# module-level logger used by every function in this script.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def _remove_existing_output(path):
    """Delete a previous index at *path*, whether it is a directory or a file."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        # Regular file or symlink: shutil.rmtree would raise on these.
        os.remove(path)


def main():
    """Build a RAG index from a LinkedIn PDF and a summary text file.

    Command-line flags:
        --pdf      Path to the LinkedIn PDF (default ``me/linkedin.pdf``).
        --summary  Path to the summary text file (default ``me/summary.txt``).
        --output   Directory to save the index to (default ``me/rag_index``).
        --test     Run sample queries against the index after building.
        --force    Replace an existing index without prompting.

    If the output path already exists and ``--force`` is not given, the user
    is asked whether to overwrite, skip straight to testing, or cancel.

    An existing index is only removed after the new index has been built and
    validated, so a missing or unreadable input file never destroys a
    previously working index.

    Every failure is logged and the function returns ``None``.
    """
    parser = argparse.ArgumentParser(description="Train a RAG system on your documents")
    parser.add_argument("--pdf", type=str, help="Path to LinkedIn PDF file", default="me/linkedin.pdf")
    parser.add_argument("--summary", type=str, help="Path to summary text file", default="me/summary.txt")
    parser.add_argument("--output", type=str, help="Directory to save the RAG index", default="me/rag_index")
    parser.add_argument("--test", action="store_true", help="Run a test query after training")
    parser.add_argument("--force", action="store_true", help="Force rebuild even if index exists")
    args = parser.parse_args()

    # Decide what to do with an existing index, but defer the actual removal
    # until the replacement is ready to be written.
    replace_existing = False
    if os.path.exists(args.output):
        if args.force:
            logger.info(f"Force flag set. Existing directory {args.output} will be replaced.")
            replace_existing = True
        else:
            logger.warning(f"Output directory {args.output} already exists.")
            try:
                choice = input(
                    "Do you want to (o)verwrite, (s)kip to testing, or (c)ancel? [o/s/c]: "
                ).strip().lower()
            except EOFError:
                # stdin is closed or not interactive; never overwrite silently.
                logger.error("No interactive input available. Re-run with --force to overwrite.")
                return

            if choice == 'c':
                logger.info("Operation cancelled.")
                return
            if choice == 's':
                logger.info("Skipping build, using existing index.")
                if args.test:
                    try:
                        test_index(args.output)
                    except Exception as e:
                        logger.error(f"Error testing index: {e}")
                return
            if choice != 'o':
                logger.error("Invalid choice. Exiting.")
                return
            replace_existing = True

    # Verify input files exist
    if not os.path.exists(args.pdf):
        logger.error(f"Error: PDF file not found at {args.pdf}")
        return
    if not os.path.exists(args.summary):
        logger.error(f"Error: Summary file not found at {args.summary}")
        return

    logger.info("Initializing RAG system...")
    rag = RAGSystem()

    # Process documents
    logger.info(f"Processing LinkedIn profile from {args.pdf}...")
    try:
        linkedin_count = rag.add_document(args.pdf, "LinkedIn Profile")
        logger.info(f"Added {linkedin_count} chunks from LinkedIn profile")
    except Exception as e:
        logger.error(f"Error processing LinkedIn PDF: {e}")
        return

    logger.info(f"Processing professional summary from {args.summary}...")
    try:
        summary_count = rag.add_document(args.summary, "Professional Summary")
        logger.info(f"Added {summary_count} chunks from professional summary")
    except Exception as e:
        logger.error(f"Error processing summary file: {e}")
        return

    # Validate that we have chunks and a non-empty index before saving
    if len(rag.chunks) == 0:
        logger.error("No chunks were created. Check your input files.")
        return
    if rag.index is None or rag.index.ntotal == 0:
        logger.error("No index was created. Check your input files.")
        return

    # The new index is ready: only now is it safe to drop the old one.
    if replace_existing:
        logger.info(f"Removing existing directory {args.output}...")
        try:
            _remove_existing_output(args.output)
        except OSError as e:
            logger.error(f"Error removing existing index: {e}")
            return

    # Save index
    logger.info(f"Saving RAG index to {args.output}...")
    try:
        rag.save_index(args.output)
        logger.info("RAG index saved successfully!")
    except Exception as e:
        logger.error(f"Error saving index: {e}")
        return

    # Test query if requested. A failing test must not crash the script after
    # the index was saved; handle it the same way as the skip-to-testing path.
    if args.test:
        try:
            test_index(args.output)
        except Exception as e:
            logger.error(f"Error testing index: {e}")

    logger.info("\nRAG system training complete!")
    logger.info(f"To use this RAG system in your application, load it from: {args.output}")
def test_index(index_dir, queries=None):
    """Load a saved index and log the retrieved context for sample queries.

    Args:
        index_dir: Location handed to ``RAGSystem.load_index``.
        queries: Optional iterable of query strings to run. A single string
            is treated as one query. When omitted, the built-in sample
            questions are used.

    Raises:
        Exception: Any error raised while loading the index or retrieving
            context is logged and then re-raised to the caller.
    """
    if queries is None:
        queries = [
            "What are Sagarnil's technical skills?",
            "What is Sagarnil's work experience?",
            "What educational background does Sagarnil have?"
        ]
    elif isinstance(queries, str):
        # Avoid iterating a lone string character by character.
        queries = [queries]

    try:
        logger.info("Loading index for testing...")
        rag = RAGSystem()
        rag.load_index(index_dir)
        logger.info(f"Loaded index with {len(rag.chunks)} chunks")

        logger.info("\nTesting RAG system with sample queries:")
        for query in queries:
            logger.info(f"\nQUERY: {query}")
            context = rag.get_context_for_query(query)
            logger.info(context)
        logger.info("Testing complete!")
    except Exception as e:
        logger.error(f"Error during testing: {e}")
        raise
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()