|
|
|
from rag_utils import RAGSystem |
|
import argparse |
|
import os |
|
import logging |
|
import shutil |
|
|
|
|
|
# Message layout requested for the root logger: timestamp, logger name,
# severity, then the message text.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# Ask for INFO-level root logging with the layout above.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
def _ask_existing_index_action(output_dir):
    """Ask what to do about an index directory that already exists.

    Args:
        output_dir: Path of the existing output directory (used in the
            warning message only).

    Returns:
        The user's answer, whitespace-stripped and lower-cased. When standard
        input is closed and nothing can be read, ``'c'`` (cancel) is returned
        so that nothing is deleted without an explicit answer.
    """
    logger.warning(f"Output directory {output_dir} already exists.")
    try:
        answer = input("Do you want to (o)verwrite, (s)kip to testing, or (c)ancel? [o/s/c]: ")
    except EOFError:
        # Non-interactive run (e.g. piped/closed stdin): never guess "overwrite".
        logger.error("No input available on stdin. Re-run with --force to overwrite.")
        return 'c'
    return answer.strip().lower()


def main():
    """Command-line entry point: build a RAG index from a PDF and a summary file.

    Steps, in order:
      1. Parse ``--pdf``, ``--summary``, ``--output``, ``--test`` and ``--force``.
      2. If the output path already exists, either honour ``--force`` or ask
         the user whether to overwrite, skip straight to testing, or cancel.
      3. Check that both input files exist.
      4. Only then remove the old output directory (when overwriting).
      5. Add both documents to a fresh ``RAGSystem``, verify that chunks and
         an index were produced, and save the index to ``--output``.
      6. Optionally run the sample queries via ``test_index``.

    Every failure is reported through the module logger and ends the function
    with a plain ``return``; nothing is returned on success either.
    """
    parser = argparse.ArgumentParser(description="Train a RAG system on your documents")
    parser.add_argument("--pdf", type=str, help="Path to LinkedIn PDF file", default="me/linkedin.pdf")
    parser.add_argument("--summary", type=str, help="Path to summary text file", default="me/summary.txt")
    parser.add_argument("--output", type=str, help="Directory to save the RAG index", default="me/rag_index")
    parser.add_argument("--test", action="store_true", help="Run a test query after training")
    parser.add_argument("--force", action="store_true", help="Force rebuild even if index exists")

    args = parser.parse_args()

    # Decide what happens to an existing index, but do NOT delete it yet: the
    # removal is deferred until the input files have been validated, so a
    # mistyped --pdf/--summary path cannot destroy a working index.
    removal_message = None
    if os.path.exists(args.output):
        if args.force:
            removal_message = f"Force flag set. Removing existing directory {args.output}..."
        else:
            choice = _ask_existing_index_action(args.output)
            if choice == 'c':
                logger.info("Operation cancelled.")
                return
            if choice == 's':
                logger.info("Skipping build, using existing index.")
                if args.test:
                    try:
                        test_index(args.output)
                    except Exception as e:
                        logger.error(f"Error testing index: {e}")
                return
            if choice != 'o':
                logger.error("Invalid choice. Exiting.")
                return
            removal_message = f"Removing existing directory {args.output}..."

    # Validate the inputs before anything destructive happens.
    if not os.path.exists(args.pdf):
        logger.error(f"Error: PDF file not found at {args.pdf}")
        return

    if not os.path.exists(args.summary):
        logger.error(f"Error: Summary file not found at {args.summary}")
        return

    # Inputs are present; it is now safe to drop the old index. This still
    # happens before RAGSystem() is created, matching the original ordering.
    if removal_message is not None:
        logger.info(removal_message)
        shutil.rmtree(args.output)

    logger.info("Initializing RAG system...")
    rag = RAGSystem()

    logger.info(f"Processing LinkedIn profile from {args.pdf}...")
    try:
        linkedin_count = rag.add_document(args.pdf, "LinkedIn Profile")
        logger.info(f"Added {linkedin_count} chunks from LinkedIn profile")
    except Exception as e:
        logger.error(f"Error processing LinkedIn PDF: {e}")
        return

    logger.info(f"Processing professional summary from {args.summary}...")
    try:
        summary_count = rag.add_document(args.summary, "Professional Summary")
        logger.info(f"Added {summary_count} chunks from professional summary")
    except Exception as e:
        logger.error(f"Error processing summary file: {e}")
        return

    # Refuse to save an empty result.
    if len(rag.chunks) == 0:
        logger.error("No chunks were created. Check your input files.")
        return

    if rag.index is None or rag.index.ntotal == 0:
        logger.error("No index was created. Check your input files.")
        return

    logger.info(f"Saving RAG index to {args.output}...")
    try:
        rag.save_index(args.output)
        logger.info("RAG index saved successfully!")
    except Exception as e:
        logger.error(f"Error saving index: {e}")
        return

    if args.test:
        # test_index re-raises on failure; handle it the same way as the
        # "skip to testing" branch above. The index is already saved, so a
        # failed smoke test is reported but does not hide the completion info.
        try:
            test_index(args.output)
        except Exception as e:
            logger.error(f"Error testing index: {e}")

    logger.info("\nRAG system training complete!")
    logger.info(f"To use this RAG system in your application, load it from: {args.output}")
|
|
|
# Sample questions used when the caller does not supply its own.
_DEFAULT_TEST_QUERIES = (
    "What are Sagarnil's technical skills?",
    "What is Sagarnil's work experience?",
    "What educational background does Sagarnil have?",
)


def test_index(index_dir, queries=None):
    """Load a saved index and log the retrieved context for sample queries.

    Args:
        index_dir: Directory passed to ``RAGSystem.load_index``.
        queries: Optional iterable of query strings to run. When ``None``
            (the default) the built-in sample questions are used.

    Raises:
        Exception: Any error raised while loading or querying the index is
            logged and then re-raised unchanged.
    """
    if queries is None:
        queries = _DEFAULT_TEST_QUERIES

    try:
        logger.info("Loading index for testing...")
        rag = RAGSystem()
        rag.load_index(index_dir)

        logger.info(f"Loaded index with {len(rag.chunks)} chunks")

        logger.info("\nTesting RAG system with sample queries:")
        for query in queries:
            logger.info(f"\nQUERY: {query}")
            context = rag.get_context_for_query(query)
            logger.info(context)

        logger.info("Testing complete!")
    except Exception as e:
        # Log here for context, then let the caller decide how to react.
        logger.error(f"Error during testing: {e}")
        raise
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()