Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,240 +1,240 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import numpy as np
|
4 |
-
from typing import List, Dict, Any, Optional
|
5 |
-
import pandas as pd
|
6 |
-
from sentence_transformers import SentenceTransformer
|
7 |
-
import faiss
|
8 |
-
from fastapi import FastAPI, Query, HTTPException
|
9 |
-
from pydantic import BaseModel
|
10 |
-
import google.generativeai as genai
|
11 |
-
from dotenv import load_dotenv
|
12 |
-
|
13 |
-
# Load environment variables
|
14 |
-
load_dotenv()
|
15 |
-
|
16 |
-
# Configure Google Gemini API
|
17 |
-
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
18 |
-
if not GEMINI_API_KEY:
|
19 |
-
raise ValueError("GEMINI_API_KEY environment variable not set")
|
20 |
-
genai.configure(api_key=GEMINI_API_KEY)
|
21 |
-
|
22 |
-
# Initialize FastAPI app
|
23 |
-
app = FastAPI(
|
24 |
-
title="SHL Assessment Recommendation API",
|
25 |
-
description="API for recommending SHL assessments based on job descriptions or queries",
|
26 |
-
version="1.0.0"
|
27 |
-
)
|
28 |
-
|
29 |
-
# Path to the data file
|
30 |
-
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
31 |
-
DATA_DIR = os.path.join(ROOT_DIR, "data", "processed")
|
32 |
-
# ASSESSMENTS_PATH = os.path.join(DATA_DIR, "shl_test_solutions.csv")
|
33 |
-
ASSESSMENTS_PATH = os.path.join(ROOT_DIR, "data", "processed", "shl_test_solutions.csv")
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
# Ensure data directory exists
|
38 |
-
os.makedirs(DATA_DIR, exist_ok=True)
|
39 |
-
# Load and prepare data
|
40 |
-
class RecommendationSystem:
|
41 |
-
def __init__(self, data_path: str):
|
42 |
-
self.df = pd.read_csv(data_path)
|
43 |
-
self.model = SentenceTransformer('all-MiniLM-L6-v2')
|
44 |
-
|
45 |
-
# Clean and prepare data
|
46 |
-
self.prepare_data()
|
47 |
-
|
48 |
-
# Create embeddings
|
49 |
-
self.create_embeddings()
|
50 |
-
|
51 |
-
# Initialize Gemini model for query enhancement
|
52 |
-
self.gemini_model = genai.GenerativeModel('gemini-1.5-pro')
|
53 |
-
|
54 |
-
def prepare_data(self):
|
55 |
-
"""Clean and prepare the assessment data"""
|
56 |
-
# Ensure all text columns are strings
|
57 |
-
text_cols = ['name', 'description', 'job_levels', 'test_types_expanded']
|
58 |
-
for col in text_cols:
|
59 |
-
if col in self.df.columns:
|
60 |
-
self.df[col] = self.df[col].fillna('').astype(str)
|
61 |
-
|
62 |
-
# Extract duration in minutes as numeric value
|
63 |
-
self.df['duration_minutes'] = self.df['duration'].apply(
|
64 |
-
lambda x: int(re.search(r'(\d+)', str(x)).group(1))
|
65 |
-
if isinstance(x, str) and re.search(r'(\d+)', str(x))
|
66 |
-
else 60 # Default value
|
67 |
-
)
|
68 |
-
|
69 |
-
def create_embeddings(self):
|
70 |
-
"""Create embeddings for assessments"""
|
71 |
-
# Create rich text representation for each assessment
|
72 |
-
self.df['combined_text'] = self.df.apply(
|
73 |
-
lambda row: f"Assessment: {row['name']}. "
|
74 |
-
f"Description: {row['description']}. "
|
75 |
-
f"Job Levels: {row['job_levels']}. "
|
76 |
-
f"Test Types: {row['test_types_expanded']}. "
|
77 |
-
f"Duration: {row['duration']}.",
|
78 |
-
axis=1
|
79 |
-
)
|
80 |
-
|
81 |
-
# Generate embeddings
|
82 |
-
print("Generating embeddings for assessments...")
|
83 |
-
self.embeddings = self.model.encode(self.df['combined_text'].tolist())
|
84 |
-
|
85 |
-
# Create FAISS index for fast similarity search
|
86 |
-
self.dimension = self.embeddings.shape[1]
|
87 |
-
self.index = faiss.IndexFlatL2(self.dimension)
|
88 |
-
self.index.add(np.array(self.embeddings).astype('float32'))
|
89 |
-
print(f"Created FAISS index with {len(self.df)} assessments")
|
90 |
-
|
91 |
-
def enhance_query(self, query: str) -> str:
|
92 |
-
"""Use Gemini to enhance the query with assessment-relevant terms"""
|
93 |
-
prompt = f"""
|
94 |
-
I need to find SHL assessments based on this query: "{query}"
|
95 |
-
|
96 |
-
Please reformulate this query to include specific skills, job roles, and assessment criteria
|
97 |
-
that would help in finding relevant technical assessments. Focus on keywords like programming
|
98 |
-
languages, technical skills, job levels, and any time constraints mentioned.
|
99 |
-
|
100 |
-
Return only the reformulated query without any explanations or additional text.
|
101 |
-
"""
|
102 |
-
|
103 |
-
try:
|
104 |
-
response = self.gemini_model.generate_content(prompt)
|
105 |
-
enhanced_query = response.text.strip()
|
106 |
-
print(f"Original query: {query}")
|
107 |
-
print(f"Enhanced query: {enhanced_query}")
|
108 |
-
return enhanced_query
|
109 |
-
except Exception as e:
|
110 |
-
print(f"Error enhancing query with Gemini: {e}")
|
111 |
-
return query # Return original query if enhancement fails
|
112 |
-
|
113 |
-
def parse_duration_constraint(self, query: str) -> Optional[int]:
|
114 |
-
"""Extract duration constraint from query"""
|
115 |
-
# Look for patterns like "within 45 minutes", "less than 30 minutes", etc.
|
116 |
-
patterns = [
|
117 |
-
r"(?:within|in|under|less than|no more than)\s+(\d+)\s+(?:min|mins|minutes)",
|
118 |
-
r"(\d+)\s+(?:min|mins|minutes)(?:\s+(?:or less|max|maximum|limit))",
|
119 |
-
r"(?:max|maximum|limit)(?:\s+(?:of|is))?\s+(\d+)\s+(?:min|mins|minutes)",
|
120 |
-
r"(?:time limit|duration)(?:\s+(?:of|is))?\s+(\d+)\s+(?:min|mins|minutes)",
|
121 |
-
r"(?:completed in|takes|duration of)\s+(\d+)\s+(?:min|mins|minutes)"
|
122 |
-
]
|
123 |
-
|
124 |
-
for pattern in patterns:
|
125 |
-
match = re.search(pattern, query, re.IGNORECASE)
|
126 |
-
if match:
|
127 |
-
return int(match.group(1))
|
128 |
-
|
129 |
-
return None
|
130 |
-
|
131 |
-
def recommend(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
|
132 |
-
"""Recommend assessments based on query"""
|
133 |
-
# Enhance query using Gemini
|
134 |
-
enhanced_query = self.enhance_query(query)
|
135 |
-
|
136 |
-
# Extract duration constraint if any
|
137 |
-
duration_limit = self.parse_duration_constraint(query)
|
138 |
-
|
139 |
-
# Generate embedding for the query
|
140 |
-
query_embedding = self.model.encode([enhanced_query])
|
141 |
-
|
142 |
-
# Search for similar assessments
|
143 |
-
D, I = self.index.search(np.array(query_embedding).astype('float32'), len(self.df))
|
144 |
-
|
145 |
-
# Get the indices of the most similar assessments
|
146 |
-
indices = I[0]
|
147 |
-
|
148 |
-
# Apply duration filter if specified
|
149 |
-
if duration_limit:
|
150 |
-
filtered_indices = [
|
151 |
-
idx for idx in indices
|
152 |
-
if self.df.iloc[idx]['duration_minutes'] <= duration_limit
|
153 |
-
]
|
154 |
-
indices = filtered_indices if filtered_indices else indices
|
155 |
-
|
156 |
-
# Prepare results, limiting to max_results
|
157 |
-
results = []
|
158 |
-
for idx in indices[:max_results]:
|
159 |
-
assessment = self.df.iloc[idx]
|
160 |
-
results.append({
|
161 |
-
"name": assessment["name"],
|
162 |
-
"url": assessment["url"],
|
163 |
-
"remote_testing": assessment["remote_testing"],
|
164 |
-
"adaptive_irt": assessment["adaptive_irt"],
|
165 |
-
"duration": assessment["duration"],
|
166 |
-
"test_types": assessment["test_types"],
|
167 |
-
"test_types_expanded": assessment["test_types_expanded"],
|
168 |
-
"description": assessment["description"],
|
169 |
-
"job_levels": assessment["job_levels"],
|
170 |
-
"similarity_score": float(1.0 - (D[0][list(indices).index(idx)] / 100)) # Normalize to 0-1
|
171 |
-
})
|
172 |
-
|
173 |
-
return results
|
174 |
-
|
175 |
-
# Initialize the recommendation system
|
176 |
-
try:
|
177 |
-
recommender = RecommendationSystem(ASSESSMENTS_PATH)
|
178 |
-
print("Recommendation system initialized successfully")
|
179 |
-
except Exception as e:
|
180 |
-
print(f"Error initializing recommendation system: {e}")
|
181 |
-
recommender = None
|
182 |
-
|
183 |
-
# Define API response model
|
184 |
-
class AssessmentRecommendation(BaseModel):
|
185 |
-
name: str
|
186 |
-
url: str
|
187 |
-
remote_testing: str
|
188 |
-
adaptive_irt: str
|
189 |
-
duration: str
|
190 |
-
test_types: str
|
191 |
-
test_types_expanded: str
|
192 |
-
description: str
|
193 |
-
job_levels: str
|
194 |
-
similarity_score: float
|
195 |
-
|
196 |
-
class RecommendationResponse(BaseModel):
|
197 |
-
query: str
|
198 |
-
enhanced_query: str
|
199 |
-
recommendations: List[AssessmentRecommendation]
|
200 |
-
|
201 |
-
# Define API endpoints
|
202 |
-
@app.get("/", response_model=dict)
|
203 |
-
def root():
|
204 |
-
"""Root endpoint that returns API information"""
|
205 |
-
return {
|
206 |
-
"name": "SHL Assessment Recommendation API",
|
207 |
-
"version": "1.0.0",
|
208 |
-
"endpoints": {
|
209 |
-
"/recommend": "GET endpoint for assessment recommendations"
|
210 |
-
}
|
211 |
-
}
|
212 |
-
|
213 |
-
@app.get("/recommend", response_model=RecommendationResponse)
|
214 |
-
def recommend(
|
215 |
-
query: str = Query(..., description="Natural language query or job description text"),
|
216 |
-
max_results: int = Query(10, ge=1, le=10, description="Maximum number of results to return")
|
217 |
-
):
|
218 |
-
"""Recommend SHL assessments based on query"""
|
219 |
-
if not recommender:
|
220 |
-
raise HTTPException(
|
221 |
-
status_code=500,
|
222 |
-
detail="Recommendation system not initialized properly"
|
223 |
-
)
|
224 |
-
|
225 |
-
# Get enhanced query for transparency
|
226 |
-
enhanced_query = recommender.enhance_query(query)
|
227 |
-
|
228 |
-
# Get recommendations
|
229 |
-
recommendations = recommender.recommend(query, max_results=max_results)
|
230 |
-
|
231 |
-
return {
|
232 |
-
"query": query,
|
233 |
-
"enhanced_query": enhanced_query,
|
234 |
-
"recommendations": recommendations
|
235 |
-
}
|
236 |
-
|
237 |
-
# Run the application
|
238 |
-
if __name__ == "__main__":
|
239 |
-
import uvicorn
|
240 |
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import numpy as np
|
4 |
+
from typing import List, Dict, Any, Optional
|
5 |
+
import pandas as pd
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
import faiss
|
8 |
+
from fastapi import FastAPI, Query, HTTPException
|
9 |
+
from pydantic import BaseModel
|
10 |
+
import google.generativeai as genai
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
|
13 |
+
# Load environment variables
|
14 |
+
load_dotenv()

# The Gemini client is configured once, at import time. A missing key is
# treated as a fatal configuration error rather than discovered per request.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    raise ValueError("GEMINI_API_KEY environment variable not set")

# FastAPI application object; the metadata below is what the generated
# OpenAPI documentation displays.
app = FastAPI(
    title="SHL Assessment Recommendation API",
    description="API for recommending SHL assessments based on job descriptions or queries",
    version="1.0.0",
)
|
28 |
+
|
29 |
+
# Path to the data file
|
30 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(ROOT_DIR, "data", "processed")

# Derive the CSV location from DATA_DIR so it is absolute and uses the
# platform's path separator. A hard-coded r"data\processed\..." literal only
# resolves on Windows, and only when the process starts in the project root;
# on POSIX the backslashes are part of the file name and the load fails.
ASSESSMENTS_PATH = os.path.join(DATA_DIR, "shl_test_solutions.csv")

# Ensure data directory exists
os.makedirs(DATA_DIR, exist_ok=True)
|
39 |
+
# Load and prepare data
|
40 |
+
class RecommendationSystem:
    """Embedding-based recommender over a CSV catalogue of assessments.

    On construction the catalogue is read with pandas, cleaned, embedded with
    the ``all-MiniLM-L6-v2`` SentenceTransformer model and stored in a FAISS
    ``IndexFlatL2``. Queries are rewritten by Gemini (best effort) before
    being embedded and searched against that index.
    """

    # Duration assumed for catalogue rows whose duration contains no digits.
    DEFAULT_DURATION_MINUTES = 60

    # Unit suffix shared by every duration pattern. The trailing word boundary
    # keeps words such as "minimum" from being read as "min".
    _MINUTES = r"(?:minutes?|mins?)\b"

    # Compiled once at class creation instead of on every query. Leading word
    # boundaries stop short keywords matching inside longer words (for
    # example the "in" at the end of "admin").
    _DURATION_PATTERNS = (
        re.compile(rf"\b(?:within|in|under|less than|no more than)\s+(\d+)\s+{_MINUTES}", re.IGNORECASE),
        re.compile(rf"(\d+)\s+{_MINUTES}(?:\s+(?:or less|max|maximum|limit))", re.IGNORECASE),
        re.compile(rf"\b(?:max|maximum|limit)(?:\s+(?:of|is))?\s+(\d+)\s+{_MINUTES}", re.IGNORECASE),
        re.compile(rf"\b(?:time limit|duration)(?:\s+(?:of|is))?\s+(\d+)\s+{_MINUTES}", re.IGNORECASE),
        re.compile(rf"\b(?:completed in|takes|duration of)\s+(\d+)\s+{_MINUTES}", re.IGNORECASE),
    )

    def __init__(self, data_path: str):
        """Load the catalogue at *data_path* and build the search index.

        Args:
            data_path: Path of the assessments CSV, passed to ``pd.read_csv``.
        """
        self.df = pd.read_csv(data_path)
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

        # Clean and prepare data
        self.prepare_data()

        # Create embeddings
        self.create_embeddings()

        # Initialize Gemini model for query enhancement
        self.gemini_model = genai.GenerativeModel('gemini-1.5-pro')

    @classmethod
    def _duration_to_minutes(cls, value) -> int:
        """Return the first integer found in *value*, or the default.

        The value is stringified first so numeric cells (which pandas yields
        for a numeric ``duration`` column) are parsed instead of silently
        falling back to the default. Missing values contain no digits and
        therefore map to ``DEFAULT_DURATION_MINUTES``.
        """
        digits = re.search(r'\d+', str(value))
        return int(digits.group()) if digits else cls.DEFAULT_DURATION_MINUTES

    def prepare_data(self):
        """Normalise text columns and derive ``duration_minutes``.

        Text columns that exist are filled with ``''`` and cast to ``str``;
        ``duration_minutes`` is added as an integer column parsed from
        ``duration``.

        Raises:
            KeyError: If the catalogue has no ``duration`` column.
        """
        for col in ('name', 'description', 'job_levels', 'test_types_expanded'):
            if col in self.df.columns:
                self.df[col] = self.df[col].fillna('').astype(str)

        self.df['duration_minutes'] = self.df['duration'].apply(
            self._duration_to_minutes
        )

    def create_embeddings(self):
        """Embed every assessment and index the vectors with FAISS.

        Sets ``combined_text`` on the frame, plus ``self.embeddings``,
        ``self.dimension`` and ``self.index``.
        """
        def describe(row) -> str:
            # One sentence per field so the encoder sees labelled context.
            return (
                f"Assessment: {row['name']}. "
                f"Description: {row['description']}. "
                f"Job Levels: {row['job_levels']}. "
                f"Test Types: {row['test_types_expanded']}. "
                f"Duration: {row['duration']}."
            )

        self.df['combined_text'] = self.df.apply(describe, axis=1)

        # Generate embeddings
        print("Generating embeddings for assessments...")
        self.embeddings = self.model.encode(self.df['combined_text'].tolist())

        # Create FAISS index for fast similarity search
        self.dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dimension)
        self.index.add(np.array(self.embeddings).astype('float32'))
        print(f"Created FAISS index with {len(self.df)} assessments")

    def enhance_query(self, query: str) -> str:
        """Ask Gemini to rewrite *query* with assessment-relevant terms.

        This is best effort: if the Gemini call raises, or the rewrite comes
        back empty, the original query is returned unchanged.

        Args:
            query: The user's free-text query.

        Returns:
            The stripped rewrite, or *query* when no usable rewrite exists.
        """
        prompt = f"""
        I need to find SHL assessments based on this query: "{query}"

        Please reformulate this query to include specific skills, job roles, and assessment criteria
        that would help in finding relevant technical assessments. Focus on keywords like programming
        languages, technical skills, job levels, and any time constraints mentioned.

        Return only the reformulated query without any explanations or additional text.
        """

        try:
            response = self.gemini_model.generate_content(prompt)
            enhanced_query = response.text.strip()
        except Exception as e:
            # Deliberately broad: any failure of the optional enhancement
            # step must not break the recommendation itself.
            print(f"Error enhancing query with Gemini: {e}")
            return query  # Return original query if enhancement fails

        if not enhanced_query:
            # An empty rewrite would make the search embed an empty string.
            return query

        print(f"Original query: {query}")
        print(f"Enhanced query: {enhanced_query}")
        return enhanced_query

    def parse_duration_constraint(self, query: str) -> Optional[int]:
        """Extract a duration limit in minutes from *query*.

        Recognises phrasings such as "within 45 minutes", "30 mins max",
        "maximum of 20 min" and "time limit is 60 minutes".

        Returns:
            The number of minutes from the first matching pattern, or
            ``None`` when the query states no duration.
        """
        for pattern in self._DURATION_PATTERNS:
            match = pattern.search(query)
            if match:
                return int(match.group(1))

        return None

    def recommend(
        self,
        query: str,
        max_results: int = 10,
        enhanced_query: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Recommend assessments for *query*, most similar first.

        Args:
            query: The user's free-text query. The duration constraint is
                parsed from this text, not from the enhanced one.
            max_results: Maximum number of recommendations to return.
            enhanced_query: An already-enhanced query to search with. When
                omitted, ``enhance_query(query)`` is called.

        Returns:
            One dict per recommendation holding the catalogue fields plus
            ``similarity_score``, computed as ``1 - distance / 100`` from
            the FAISS distance of that same assessment.
        """
        if enhanced_query is None:
            enhanced_query = self.enhance_query(query)

        # Extract duration constraint if any
        duration_limit = self.parse_duration_constraint(query)

        # Generate embedding for the query
        query_embedding = self.model.encode([enhanced_query])

        # Rank the whole catalogue against the query.
        distances, positions = self.index.search(
            np.array(query_embedding).astype('float32'), len(self.df)
        )

        # Keep each row position paired with its own distance so that
        # filtering cannot misalign the two.
        ranked = list(zip(positions[0].tolist(), distances[0].tolist()))

        if duration_limit is not None:
            minutes = self.df['duration_minutes'].to_numpy()
            within_limit = [
                (pos, dist) for pos, dist in ranked
                if minutes[pos] <= duration_limit
            ]
            # Fall back to the unfiltered ranking when nothing fits the limit.
            if within_limit:
                ranked = within_limit

        # Prepare results, limiting to max_results
        results = []
        for pos, dist in ranked[:max_results]:
            assessment = self.df.iloc[pos]
            results.append({
                "name": assessment["name"],
                "url": assessment["url"],
                "remote_testing": assessment["remote_testing"],
                "adaptive_irt": assessment["adaptive_irt"],
                "duration": assessment["duration"],
                "test_types": assessment["test_types"],
                "test_types_expanded": assessment["test_types_expanded"],
                "description": assessment["description"],
                "job_levels": assessment["job_levels"],
                "similarity_score": 1.0 - dist / 100,
            })

        return results
|
174 |
+
|
175 |
+
# Initialize the recommendation system
|
176 |
+
# Build the single shared recommender at import time. A failure is reported
# and leaves ``recommender`` as None so the application can still start; the
# /recommend endpoint checks for that case.
try:
    recommender = RecommendationSystem(ASSESSMENTS_PATH)
except Exception as e:
    print(f"Error initializing recommendation system: {e}")
    recommender = None
else:
    print("Recommendation system initialized successfully")
|
182 |
+
|
183 |
+
# Define API response model
|
184 |
+
class AssessmentRecommendation(BaseModel):
    """Schema of one recommended assessment in the /recommend response.

    The fields mirror the keys of the dicts built by
    ``RecommendationSystem.recommend``.
    """

    # Catalogue fields, copied from the matching row.
    name: str
    url: str
    remote_testing: str
    adaptive_irt: str
    duration: str
    test_types: str
    test_types_expanded: str
    description: str
    job_levels: str

    # Score derived from the FAISS distance between query and assessment.
    similarity_score: float
|
195 |
+
|
196 |
+
class RecommendationResponse(BaseModel):
    """Schema of the full /recommend response body."""

    # The query text exactly as supplied by the caller.
    query: str
    # The query text returned by the enhancement step.
    enhanced_query: str
    # Recommended assessments, in the order the recommender returned them.
    recommendations: List[AssessmentRecommendation]
|
200 |
+
|
201 |
+
# Define API endpoints
|
202 |
+
@app.get("/", response_model=dict)
def root():
    """Return the API's name, version and a map of its endpoints."""
    endpoints = {
        "/recommend": "GET endpoint for assessment recommendations",
    }
    return {
        "name": "SHL Assessment Recommendation API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
|
212 |
+
|
213 |
+
@app.get("/recommend", response_model=RecommendationResponse)
def recommend(
    query: str = Query(..., description="Natural language query or job description text"),
    max_results: int = Query(10, ge=1, le=10, description="Maximum number of results to return")
):
    """Return assessment recommendations for a free-text query.

    Args:
        query: Natural language query or job description text.
        max_results: Number of recommendations to return (1 to 10).

    Raises:
        HTTPException: With status 500 when the module-level recommender
            failed to initialise.
    """
    if not recommender:
        raise HTTPException(
            status_code=500,
            detail="Recommendation system not initialized properly",
        )

    # The enhanced query is included in the response for transparency.
    # NOTE(review): recommender.recommend(query, ...) runs its own
    # enhancement of the same query, so the enhancement step executes twice
    # per request and the text reported here is not guaranteed to be the one
    # the search used.
    payload = {
        "query": query,
        "enhanced_query": recommender.enhance_query(query),
    }
    payload["recommendations"] = recommender.recommend(query, max_results=max_results)
    return payload
|
236 |
+
|
237 |
+
# Run the application
|
238 |
+
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn

    # Serve this module's ``app`` object on all interfaces, port 8000, with
    # auto-reload enabled.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
|