File size: 4,901 Bytes
3c18172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
from rag_utils import RAGSystem
import argparse
import os
import logging
import shutil

# Configure process-wide logging at import time: INFO and above, with each
# record carrying a timestamp, the emitting logger's name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger shared by every function in this script.
logger = logging.getLogger(__name__)

def _prompt_existing_index_action(output_dir):
    """Ask the user what to do about an already existing output path.

    Returns 'o' (overwrite), 's' (skip to testing) or 'c' (cancel). Returns
    None, after logging the reason, when the answer is not one of those or
    when stdin cannot be read.
    """
    logger.warning(f"Output directory {output_dir} already exists.")
    try:
        answer = input("Do you want to (o)verwrite, (s)kip to testing, or (c)ancel? [o/s/c]: ")
    except EOFError:
        # Non-interactive run (piped/closed stdin): there is nobody to ask.
        logger.error("No interactive input available. Re-run with --force to overwrite.")
        return None

    answer = answer.strip().lower()
    if answer in ("o", "s", "c"):
        return answer

    logger.error("Invalid choice. Exiting.")
    return None


def _remove_existing_output(path):
    """Delete *path* whether it is a directory tree, a plain file or a symlink."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        # shutil.rmtree refuses files and symlinks; unlink those directly.
        os.remove(path)


def _run_index_test(index_dir):
    """Run test_index and log, rather than propagate, any failure."""
    try:
        test_index(index_dir)
    except Exception as e:
        logger.error(f"Error testing index: {e}")


def main():
    """Command-line entry point: build a RAG index from a PDF and a summary file.

    Flags:
        --pdf      path to the LinkedIn PDF (default "me/linkedin.pdf")
        --summary  path to the summary text file (default "me/summary.txt")
        --output   directory the index is saved to (default "me/rag_index")
        --test     run the sample queries after training
        --force    overwrite an existing output path without prompting

    Every failure is logged and the function returns None; nothing is raised
    to the caller apart from argparse's own SystemExit on bad arguments.
    """
    parser = argparse.ArgumentParser(description="Train a RAG system on your documents")
    parser.add_argument("--pdf", type=str, help="Path to LinkedIn PDF file", default="me/linkedin.pdf")
    parser.add_argument("--summary", type=str, help="Path to summary text file", default="me/summary.txt")
    parser.add_argument("--output", type=str, help="Directory to save the RAG index", default="me/rag_index")
    parser.add_argument("--test", action="store_true", help="Run a test query after training")
    parser.add_argument("--force", action="store_true", help="Force rebuild even if index exists")

    args = parser.parse_args()

    # Decide what to do with an existing output path, but do NOT delete it
    # yet: the old index must survive until a replacement is ready to save.
    removal_notice = None
    if os.path.exists(args.output):
        if args.force:
            removal_notice = f"Force flag set. Removing existing directory {args.output}..."
        else:
            choice = _prompt_existing_index_action(args.output)
            if choice == 'c':
                logger.info("Operation cancelled.")
                return
            if choice == 's':
                logger.info("Skipping build, using existing index.")
                if args.test:
                    _run_index_test(args.output)
                return
            if choice != 'o':
                # Invalid answer or no stdin; the helper already logged why.
                return
            removal_notice = f"Removing existing directory {args.output}..."

    # Verify input files exist
    if not os.path.exists(args.pdf):
        logger.error(f"Error: PDF file not found at {args.pdf}")
        return

    if not os.path.exists(args.summary):
        logger.error(f"Error: Summary file not found at {args.summary}")
        return

    logger.info("Initializing RAG system...")
    rag = RAGSystem()

    # Process documents
    logger.info(f"Processing LinkedIn profile from {args.pdf}...")
    try:
        linkedin_count = rag.add_document(args.pdf, "LinkedIn Profile")
        logger.info(f"Added {linkedin_count} chunks from LinkedIn profile")
    except Exception as e:
        logger.error(f"Error processing LinkedIn PDF: {e}")
        return

    logger.info(f"Processing professional summary from {args.summary}...")
    try:
        summary_count = rag.add_document(args.summary, "Professional Summary")
        logger.info(f"Added {summary_count} chunks from professional summary")
    except Exception as e:
        logger.error(f"Error processing summary file: {e}")
        return

    # Validate that we have chunks and a FAISS index
    if len(rag.chunks) == 0:
        logger.error("No chunks were created. Check your input files.")
        return

    if rag.index is None or rag.index.ntotal == 0:
        logger.error("No index was created. Check your input files.")
        return

    # The new index is built and validated; only now drop the old one.
    if removal_notice is not None and os.path.lexists(args.output):
        logger.info(removal_notice)
        try:
            _remove_existing_output(args.output)
        except OSError as e:
            logger.error(f"Error removing existing index: {e}")
            return

    # Save index
    logger.info(f"Saving RAG index to {args.output}...")
    try:
        rag.save_index(args.output)
        logger.info("RAG index saved successfully!")
    except Exception as e:
        logger.error(f"Error saving index: {e}")
        return

    # Test query if requested; a failing test must not crash the run after
    # the index has already been saved.
    if args.test:
        _run_index_test(args.output)

    logger.info("\nRAG system training complete!")
    logger.info(f"To use this RAG system in your application, load it from: {args.output}")

def test_index(index_dir, queries=None):
    """Load a saved index and log the retrieved context for sample queries.

    Args:
        index_dir: Location passed to ``RAGSystem.load_index``.
        queries: Optional iterable of query strings. A single string is
            treated as one query. When omitted, the built-in sample
            queries are used.

    Raises:
        Exception: Any error raised while loading or querying is logged and
            then re-raised unchanged.
    """
    if queries is None:
        queries = [
            "What are Sagarnil's technical skills?",
            "What is Sagarnil's work experience?",
            "What educational background does Sagarnil have?"
        ]
    elif isinstance(queries, str):
        # A bare string would otherwise be iterated character by character.
        queries = [queries]

    try:
        logger.info("Loading index for testing...")
        rag = RAGSystem()
        rag.load_index(index_dir)

        logger.info(f"Loaded index with {len(rag.chunks)} chunks")

        logger.info("\nTesting RAG system with sample queries:")
        for query in queries:
            logger.info(f"\nQUERY: {query}")
            context = rag.get_context_for_query(query)
            logger.info(context)

        logger.info("Testing complete!")
    except Exception as e:
        logger.error(f"Error during testing: {e}")
        raise

# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()