career_conversation / train_rag.py
sagarnildass's picture
Upload folder using huggingface_hub
3c18172 verified
#!/usr/bin/env python3
from rag_utils import RAGSystem
import argparse
import os
import logging
import shutil
# Configure logging once at import time: INFO and above, with every record
# prefixed by its timestamp, logger name and severity.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def _remove_existing_output(path):
    """Delete a previously saved index at *path*.

    Real directories are removed recursively; a plain file or a symlink is
    unlinked instead, because ``shutil.rmtree`` refuses both.
    """
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)


def _run_index_test(index_dir):
    """Run ``test_index`` on *index_dir*, logging (not propagating) failures."""
    try:
        test_index(index_dir)
    except Exception as e:
        logger.error(f"Error testing index: {e}")


def main():
    """Build a RAG index from a LinkedIn PDF and a summary text file.

    Command-line options:
        --pdf      Path to the LinkedIn PDF (default ``me/linkedin.pdf``).
        --summary  Path to the summary text file (default ``me/summary.txt``).
        --output   Directory to save the index (default ``me/rag_index``).
        --test     Run the sample queries against the index afterwards.
        --force    Replace an existing output without prompting.

    If the output path already exists and ``--force`` is not given, the user
    is asked whether to overwrite it, skip straight to testing, or cancel.
    Every failure is reported through the module logger and ends the run
    early; the function always returns ``None``.
    """
    parser = argparse.ArgumentParser(description="Train a RAG system on your documents")
    parser.add_argument("--pdf", type=str, help="Path to LinkedIn PDF file", default="me/linkedin.pdf")
    parser.add_argument("--summary", type=str, help="Path to summary text file", default="me/summary.txt")
    parser.add_argument("--output", type=str, help="Directory to save the RAG index", default="me/rag_index")
    parser.add_argument("--test", action="store_true", help="Run a test query after training")
    parser.add_argument("--force", action="store_true", help="Force rebuild even if index exists")
    args = parser.parse_args()

    # Decide what to do with an existing index, but do NOT delete it yet:
    # the inputs are validated first so a typo in --pdf/--summary cannot
    # destroy a working index.
    removal_notice = None
    if os.path.exists(args.output):
        if args.force:
            removal_notice = f"Force flag set. Removing existing directory {args.output}..."
        else:
            logger.warning(f"Output directory {args.output} already exists.")
            try:
                choice = input(
                    "Do you want to (o)verwrite, (s)kip to testing, or (c)ancel? [o/s/c]: "
                ).strip().lower()
            except EOFError:
                # stdin is closed or not interactive (cron, CI, piped input).
                logger.error("No interactive input available. Re-run with --force to overwrite the existing index.")
                return

            if choice == 'c':
                logger.info("Operation cancelled.")
                return
            if choice == 's':
                logger.info("Skipping build, using existing index.")
                if args.test:
                    _run_index_test(args.output)
                return
            if choice != 'o':
                logger.error("Invalid choice. Exiting.")
                return
            removal_notice = f"Removing existing directory {args.output}..."

    # Verify input files exist
    if not os.path.exists(args.pdf):
        logger.error(f"Error: PDF file not found at {args.pdf}")
        return
    if not os.path.exists(args.summary):
        logger.error(f"Error: Summary file not found at {args.summary}")
        return

    # Inputs look usable, so it is now safe to discard the old index. This
    # still happens before the RAG system is constructed, as it did before.
    if removal_notice is not None:
        logger.info(removal_notice)
        try:
            _remove_existing_output(args.output)
        except OSError as e:
            logger.error(f"Error removing existing output {args.output}: {e}")
            return

    logger.info("Initializing RAG system...")
    rag = RAGSystem()

    # Process documents
    logger.info(f"Processing LinkedIn profile from {args.pdf}...")
    try:
        linkedin_count = rag.add_document(args.pdf, "LinkedIn Profile")
        logger.info(f"Added {linkedin_count} chunks from LinkedIn profile")
    except Exception as e:
        logger.error(f"Error processing LinkedIn PDF: {e}")
        return

    logger.info(f"Processing professional summary from {args.summary}...")
    try:
        summary_count = rag.add_document(args.summary, "Professional Summary")
        logger.info(f"Added {summary_count} chunks from professional summary")
    except Exception as e:
        logger.error(f"Error processing summary file: {e}")
        return

    # Validate that we have chunks and a populated index before saving
    if len(rag.chunks) == 0:
        logger.error("No chunks were created. Check your input files.")
        return
    if rag.index is None or rag.index.ntotal == 0:
        logger.error("No index was created. Check your input files.")
        return

    # Save index
    logger.info(f"Saving RAG index to {args.output}...")
    try:
        rag.save_index(args.output)
        logger.info("RAG index saved successfully!")
    except Exception as e:
        logger.error(f"Error saving index: {e}")
        return

    # Test query if requested. A failing test must not crash the script:
    # the index has already been saved successfully at this point.
    if args.test:
        _run_index_test(args.output)

    logger.info("\nRAG system training complete!")
    logger.info(f"To use this RAG system in your application, load it from: {args.output}")
def test_index(index_dir, queries=None):
    """Load a saved index and log the context retrieved for sample queries.

    Args:
        index_dir: Directory handed to ``RAGSystem.load_index``.
        queries: Optional query string or iterable of query strings to run.
            When omitted, three built-in career questions are used.

    Raises:
        Exception: Any error raised while loading or querying the index is
            logged and then re-raised to the caller.
    """
    if queries is None:
        queries = [
            "What are Sagarnil's technical skills?",
            "What is Sagarnil's work experience?",
            "What educational background does Sagarnil have?"
        ]
    elif isinstance(queries, str):
        # A bare string would otherwise be iterated character by character.
        queries = [queries]

    try:
        logger.info("Loading index for testing...")
        rag = RAGSystem()
        rag.load_index(index_dir)
        logger.info(f"Loaded index with {len(rag.chunks)} chunks")

        logger.info("\nTesting RAG system with sample queries:")
        for query in queries:
            logger.info(f"\nQUERY: {query}")
            context = rag.get_context_for_query(query)
            logger.info(context)

        logger.info("Testing complete!")
    except Exception as e:
        logger.error(f"Error during testing: {e}")
        raise
# Script entry point: run the training workflow only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()