File size: 9,081 Bytes
092d6a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c9cf2b
092d6a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3fc99b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import os
import re
import numpy as np
from typing import List, Dict, Any, Optional
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
from fastapi import FastAPI, Query, HTTPException
from pydantic import BaseModel
import google.generativeai as genai
from dotenv import load_dotenv

# Load environment variables
# Runs first so that the os.getenv lookup below can see values supplied
# through a .env file as well as the real process environment.
load_dotenv()

# Configure Google Gemini API
# The key is mandatory: importing this module raises ValueError when the
# variable is unset or empty, so a misconfigured deployment fails at startup
# instead of on the first request.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set")
genai.configure(api_key=GEMINI_API_KEY)

# Initialize FastAPI app
# `app` is the ASGI application object; the route decorators further down
# register their handlers on it.
app = FastAPI(
    title="SHL Assessment Recommendation API",
    description="API for recommending SHL assessments based on job descriptions or queries",
    version="1.0.0"
)

# Path to the data file
# ROOT_DIR is the directory containing this module; DATA_DIR is
# <ROOT_DIR>/data/processed.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(ROOT_DIR, "data", "processed")
# Earlier locations of the catalogue, kept for reference (both disabled):
# ASSESSMENTS_PATH = os.path.join(DATA_DIR, "shl_test_solutions.csv")
# ASSESSMENTS_PATH = os.path.join(ROOT_DIR, "data", "processed", "shl_test_solutions.csv")

# NOTE: this is a relative path, so the CSV is looked up in the process's
# current working directory -- not in ROOT_DIR or DATA_DIR.
ASSESSMENTS_PATH = r"shl_test_solutions.csv"

# Ensure data directory exists
# NOTE(review): DATA_DIR is created here but the active ASSESSMENTS_PATH above
# does not point into it -- confirm whether the directory is still needed.
os.makedirs(DATA_DIR, exist_ok=True)
# Load and prepare data
class RecommendationSystem:
    """Embedding-based recommender over a CSV catalogue of assessments.

    On construction the CSV at ``data_path`` is loaded, cleaned, embedded with
    the ``all-MiniLM-L6-v2`` sentence-transformer and stored in a FAISS flat
    L2 index. ``recommend`` rewrites an incoming query with Gemini, embeds
    it, and returns the nearest assessments, optionally restricted by a
    duration limit found in the query text.

    The CSV is expected to provide the columns read by this class: ``name``,
    ``url``, ``remote_testing``, ``adaptive_irt``, ``duration``,
    ``test_types``, ``test_types_expanded``, ``description`` and
    ``job_levels``.
    """

    # Duration assumed for an assessment whose ``duration`` holds no number.
    DEFAULT_DURATION_MINUTES = 60

    # Phrasings that express an upper bound on test duration, e.g.
    # "within 45 minutes" or "30 mins max". Compiled once for the class.
    # The leading \b keeps keyword alternatives from matching the tail of an
    # unrelated word ("train 30 minutes" must not match the "in" alternative).
    _DURATION_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"\b(?:within|in|under|less than|no more than)\s+(\d+)\s+(?:min|mins|minutes)",
            r"(\d+)\s+(?:min|mins|minutes)(?:\s+(?:or less|max|maximum|limit))",
            r"\b(?:max|maximum|limit)(?:\s+(?:of|is))?\s+(\d+)\s+(?:min|mins|minutes)",
            r"\b(?:time limit|duration)(?:\s+(?:of|is))?\s+(\d+)\s+(?:min|mins|minutes)",
            r"\b(?:completed in|takes|duration of)\s+(\d+)\s+(?:min|mins|minutes)",
        )
    )

    def __init__(self, data_path: str):
        """Load the catalogue and build everything needed to answer queries.

        Args:
            data_path: Path of the assessments CSV, passed to ``pd.read_csv``.
        """
        self.df = pd.read_csv(data_path)
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

        # Clean and prepare data
        self.prepare_data()

        # Create embeddings
        self.create_embeddings()

        # Initialize Gemini model for query enhancement
        self.gemini_model = genai.GenerativeModel('gemini-1.5-pro')

    def _duration_to_minutes(self, value: Any) -> int:
        """Return the first integer found in ``value``, or the default.

        ``value`` is stringified first so that numeric cells (which
        ``pd.read_csv`` produces for an all-numeric column) are honoured
        instead of silently falling back to the default. Missing values
        stringify to text without digits and therefore yield the default.
        """
        found = re.search(r'(\d+)', str(value))
        return int(found.group(1)) if found else self.DEFAULT_DURATION_MINUTES

    def prepare_data(self):
        """Clean and prepare the assessment data.

        Text columns that are present are made NaN-free strings, and a numeric
        ``duration_minutes`` column is derived from ``duration``.
        """
        # Ensure all text columns are strings
        text_cols = ['name', 'description', 'job_levels', 'test_types_expanded']
        for col in text_cols:
            if col in self.df.columns:
                self.df[col] = self.df[col].fillna('').astype(str)

        # Extract duration in minutes as numeric value
        self.df['duration_minutes'] = self.df['duration'].apply(self._duration_to_minutes)

    def create_embeddings(self):
        """Create embeddings for assessments and index them with FAISS."""
        # Create rich text representation for each assessment
        self.df['combined_text'] = self.df.apply(
            lambda row: f"Assessment: {row['name']}. "
                       f"Description: {row['description']}. "
                       f"Job Levels: {row['job_levels']}. "
                       f"Test Types: {row['test_types_expanded']}. "
                       f"Duration: {row['duration']}.",
            axis=1
        )

        # Generate embeddings
        print("Generating embeddings for assessments...")
        self.embeddings = self.model.encode(self.df['combined_text'].tolist())

        # Create FAISS index for fast similarity search; row i of the index
        # corresponds to positional row i of self.df.
        self.dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dimension)
        self.index.add(np.array(self.embeddings).astype('float32'))
        print(f"Created FAISS index with {len(self.df)} assessments")

    def enhance_query(self, query: str) -> str:
        """Use Gemini to enhance the query with assessment-relevant terms.

        Args:
            query: The user's original query text.

        Returns:
            The reformulated query, or ``query`` unchanged when the Gemini
            call fails or returns no text (enhancement is best-effort).
        """
        prompt = f"""
        I need to find SHL assessments based on this query: "{query}"
        
        Please reformulate this query to include specific skills, job roles, and assessment criteria 
        that would help in finding relevant technical assessments. Focus on keywords like programming 
        languages, technical skills, job levels, and any time constraints mentioned.
        
        Return only the reformulated query without any explanations or additional text.
        """

        try:
            response = self.gemini_model.generate_content(prompt)
            enhanced_query = response.text.strip()
        except Exception as e:
            # Deliberately broad: any failure of the remote call must degrade
            # to the un-enhanced query rather than fail the request.
            print(f"Error enhancing query with Gemini: {e}")
            return query  # Return original query if enhancement fails

        if not enhanced_query:
            # An empty reply would otherwise be embedded as an empty string.
            return query

        print(f"Original query: {query}")
        print(f"Enhanced query: {enhanced_query}")
        return enhanced_query

    def parse_duration_constraint(self, query: str) -> Optional[int]:
        """Extract a duration constraint (in minutes) from the query.

        Recognises phrasings such as "within 45 minutes", "less than 30
        minutes", "40 mins max" or "time limit of 60 minutes".

        Returns:
            The number of minutes from the first matching pattern, or ``None``
            when the query states no recognisable limit.
        """
        for pattern in self._DURATION_PATTERNS:
            match = pattern.search(query)
            if match:
                return int(match.group(1))

        return None

    def recommend(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """Recommend assessments based on query.

        Args:
            query: Natural-language query or job description.
            max_results: Maximum number of recommendations to return.

        Returns:
            A list of dicts, nearest first, each carrying the assessment's
            catalogue fields plus a ``similarity_score``. When the query
            states a duration limit, only assessments within it are returned;
            if none qualify the unfiltered ranking is used instead.
        """
        # Enhance query using Gemini
        enhanced_query = self.enhance_query(query)

        # Extract duration constraint if any (from the user's own wording,
        # not from the rewritten query).
        duration_limit = self.parse_duration_constraint(query)

        # Generate embedding for the query
        query_embedding = self.model.encode([enhanced_query])

        # Rank the whole catalogue by distance to the query
        distances, neighbours = self.index.search(
            np.array(query_embedding).astype('float32'), len(self.df)
        )

        # Keep every row index paired with its own distance so that filtering
        # below cannot shift scores onto the wrong assessment. FAISS reports
        # -1 for result slots it could not fill; those are skipped.
        ranked = [
            (int(idx), float(dist))
            for idx, dist in zip(neighbours[0], distances[0])
            if idx >= 0
        ]

        # Apply duration filter if specified
        if duration_limit is not None:
            minutes = self.df['duration_minutes'].to_numpy()
            within_limit = [hit for hit in ranked if minutes[hit[0]] <= duration_limit]
            if within_limit:
                ranked = within_limit

        # Prepare results, limiting to max_results
        results = []
        for idx, dist in ranked[:max_results]:
            assessment = self.df.iloc[idx]
            results.append({
                "name": assessment["name"],
                "url": assessment["url"],
                "remote_testing": assessment["remote_testing"],
                "adaptive_irt": assessment["adaptive_irt"],
                "duration": assessment["duration"],
                "test_types": assessment["test_types"],
                "test_types_expanded": assessment["test_types_expanded"],
                "description": assessment["description"],
                "job_levels": assessment["job_levels"],
                # Heuristic rescaling of the index distance (smaller distance
                # gives a higher score); the /100 scale is kept as-is.
                "similarity_score": float(1.0 - dist / 100),
            })

        return results

# Build the module-wide recommender once, at import time. A failure here is
# reported on stdout and leaves `recommender` as None, so the application
# still starts and the /recommend handler can answer with an error instead.
recommender = None
try:
    recommender = RecommendationSystem(ASSESSMENTS_PATH)
except Exception as e:
    print(f"Error initializing recommendation system: {e}")
else:
    print("Recommendation system initialized successfully")

# Define API response model
class AssessmentRecommendation(BaseModel):
    # Schema of a single entry in the /recommend response. The field names
    # match the keys of the dicts built by RecommendationSystem.recommend.
    # (Documented with comments rather than a docstring so the generated
    # API schema is unchanged.)
    #
    # NOTE(review): all catalogue fields are declared as str, while several of
    # them (url, remote_testing, adaptive_irt, duration, test_types) are taken
    # from the CSV without string coercion -- confirm the data never holds
    # missing or numeric values there.
    name: str
    url: str
    remote_testing: str
    adaptive_irt: str
    duration: str
    test_types: str
    test_types_expanded: str
    description: str
    job_levels: str
    # Derived from the index distance as 1 - distance / 100 (higher = closer).
    similarity_score: float

class RecommendationResponse(BaseModel):
    # Envelope returned by GET /recommend.
    # The query exactly as received from the client.
    query: str
    # The Gemini reformulation reported back for transparency.
    enhanced_query: str
    # Recommended assessments, in the order produced by the recommender.
    recommendations: List[AssessmentRecommendation]

# Define API endpoints
@app.get("/", response_model=dict)
def root():
    """Root endpoint that returns API information"""
    # Static self-description: API name, version and the available routes.
    endpoints = {"/recommend": "GET endpoint for assessment recommendations"}
    return dict(
        name="SHL Assessment Recommendation API",
        version="1.0.0",
        endpoints=endpoints,
    )

@app.get("/recommend", response_model=RecommendationResponse)
def recommend(
    query: str = Query(..., description="Natural language query or job description text"),
    max_results: int = Query(10, ge=1, le=10, description="Maximum number of results to return")
):
    """Recommend SHL assessments based on query"""
    # Guard clause: `recommender` is None when start-up initialization failed.
    if recommender is None:
        raise HTTPException(
            status_code=500,
            detail="Recommendation system not initialized properly",
        )

    payload = {"query": query}

    # Reported to the client for transparency.
    # NOTE(review): RecommendationSystem.recommend enhances the query again
    # internally, so each request makes two Gemini calls and the text reported
    # here may differ from the one actually used for retrieval.
    payload["enhanced_query"] = recommender.enhance_query(query)

    payload["recommendations"] = recommender.recommend(query, max_results=max_results)
    return payload

# Development entry point: serve the API with uvicorn when this file is
# executed directly.
if __name__ == "__main__":
    import uvicorn

    # The import string "app:app" is resolved by uvicorn, which is what allows
    # reload=True to re-import the application on code changes.
    # NOTE(review): it presumes this module is importable as `app` -- confirm
    # against the actual file name.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)