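# NOTE: the words "Spaces: Sleeping" appeared here; they look like a hosting-page
# status banner captured together with the listing and are not part of the program.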
import os | |
import re | |
import numpy as np | |
from typing import List, Dict, Any, Optional | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
from fastapi import FastAPI, Query, HTTPException | |
from pydantic import BaseModel | |
import google.generativeai as genai | |
from dotenv import load_dotenv | |
# Read settings from a .env file (when one is found) into the environment.
load_dotenv()

# The Gemini client is configured once at import time; without a key the
# module refuses to load.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set")
genai.configure(api_key=GEMINI_API_KEY)

# FastAPI application object for this service.
app = FastAPI(
    title="SHL Assessment Recommendation API",
    description="API for recommending SHL assessments based on job descriptions or queries",
    version="1.0.0",
)

# Directory layout, anchored at the folder that contains this module.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(ROOT_DIR, "data", "processed")

# The catalogue CSV is referenced by bare file name, so it is resolved against
# the current working directory (earlier revisions pointed at
# <ROOT_DIR>/data/processed/shl_test_solutions.csv instead).
ASSESSMENTS_PATH = "shl_test_solutions.csv"

# Make sure <ROOT_DIR>/data/processed exists; harmless when already present.
os.makedirs(DATA_DIR, exist_ok=True)
# Load and prepare data | |
class RecommendationSystem: | |
def __init__(self, data_path: str): | |
self.df = pd.read_csv(data_path) | |
self.model = SentenceTransformer('all-MiniLM-L6-v2') | |
# Clean and prepare data | |
self.prepare_data() | |
# Create embeddings | |
self.create_embeddings() | |
# Initialize Gemini model for query enhancement | |
self.gemini_model = genai.GenerativeModel('gemini-1.5-pro') | |
def prepare_data(self): | |
"""Clean and prepare the assessment data""" | |
# Ensure all text columns are strings | |
text_cols = ['name', 'description', 'job_levels', 'test_types_expanded'] | |
for col in text_cols: | |
if col in self.df.columns: | |
self.df[col] = self.df[col].fillna('').astype(str) | |
# Extract duration in minutes as numeric value | |
self.df['duration_minutes'] = self.df['duration'].apply( | |
lambda x: int(re.search(r'(\d+)', str(x)).group(1)) | |
if isinstance(x, str) and re.search(r'(\d+)', str(x)) | |
else 60 # Default value | |
) | |
def create_embeddings(self): | |
"""Create embeddings for assessments""" | |
# Create rich text representation for each assessment | |
self.df['combined_text'] = self.df.apply( | |
lambda row: f"Assessment: {row['name']}. " | |
f"Description: {row['description']}. " | |
f"Job Levels: {row['job_levels']}. " | |
f"Test Types: {row['test_types_expanded']}. " | |
f"Duration: {row['duration']}.", | |
axis=1 | |
) | |
# Generate embeddings | |
print("Generating embeddings for assessments...") | |
self.embeddings = self.model.encode(self.df['combined_text'].tolist()) | |
# Create FAISS index for fast similarity search | |
self.dimension = self.embeddings.shape[1] | |
self.index = faiss.IndexFlatL2(self.dimension) | |
self.index.add(np.array(self.embeddings).astype('float32')) | |
print(f"Created FAISS index with {len(self.df)} assessments") | |
def enhance_query(self, query: str) -> str: | |
"""Use Gemini to enhance the query with assessment-relevant terms""" | |
prompt = f""" | |
I need to find SHL assessments based on this query: "{query}" | |
Please reformulate this query to include specific skills, job roles, and assessment criteria | |
that would help in finding relevant technical assessments. Focus on keywords like programming | |
languages, technical skills, job levels, and any time constraints mentioned. | |
Return only the reformulated query without any explanations or additional text. | |
""" | |
try: | |
response = self.gemini_model.generate_content(prompt) | |
enhanced_query = response.text.strip() | |
print(f"Original query: {query}") | |
print(f"Enhanced query: {enhanced_query}") | |
return enhanced_query | |
except Exception as e: | |
print(f"Error enhancing query with Gemini: {e}") | |
return query # Return original query if enhancement fails | |
def parse_duration_constraint(self, query: str) -> Optional[int]: | |
"""Extract duration constraint from query""" | |
# Look for patterns like "within 45 minutes", "less than 30 minutes", etc. | |
patterns = [ | |
r"(?:within|in|under|less than|no more than)\s+(\d+)\s+(?:min|mins|minutes)", | |
r"(\d+)\s+(?:min|mins|minutes)(?:\s+(?:or less|max|maximum|limit))", | |
r"(?:max|maximum|limit)(?:\s+(?:of|is))?\s+(\d+)\s+(?:min|mins|minutes)", | |
r"(?:time limit|duration)(?:\s+(?:of|is))?\s+(\d+)\s+(?:min|mins|minutes)", | |
r"(?:completed in|takes|duration of)\s+(\d+)\s+(?:min|mins|minutes)" | |
] | |
for pattern in patterns: | |
match = re.search(pattern, query, re.IGNORECASE) | |
if match: | |
return int(match.group(1)) | |
return None | |
def recommend(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]: | |
"""Recommend assessments based on query""" | |
# Enhance query using Gemini | |
enhanced_query = self.enhance_query(query) | |
# Extract duration constraint if any | |
duration_limit = self.parse_duration_constraint(query) | |
# Generate embedding for the query | |
query_embedding = self.model.encode([enhanced_query]) | |
# Search for similar assessments | |
D, I = self.index.search(np.array(query_embedding).astype('float32'), len(self.df)) | |
# Get the indices of the most similar assessments | |
indices = I[0] | |
# Apply duration filter if specified | |
if duration_limit: | |
filtered_indices = [ | |
idx for idx in indices | |
if self.df.iloc[idx]['duration_minutes'] <= duration_limit | |
] | |
indices = filtered_indices if filtered_indices else indices | |
# Prepare results, limiting to max_results | |
results = [] | |
for idx in indices[:max_results]: | |
assessment = self.df.iloc[idx] | |
results.append({ | |
"name": assessment["name"], | |
"url": assessment["url"], | |
"remote_testing": assessment["remote_testing"], | |
"adaptive_irt": assessment["adaptive_irt"], | |
"duration": assessment["duration"], | |
"test_types": assessment["test_types"], | |
"test_types_expanded": assessment["test_types_expanded"], | |
"description": assessment["description"], | |
"job_levels": assessment["job_levels"], | |
"similarity_score": float(1.0 - (D[0][list(indices).index(idx)] / 100)) # Normalize to 0-1 | |
}) | |
return results | |
# Build the module-wide recommender once at import time.  A failure is
# reported on stdout and leaves `recommender` as None instead of aborting the
# import.
recommender: Optional[RecommendationSystem] = None
try:
    recommender = RecommendationSystem(ASSESSMENTS_PATH)
    print("Recommendation system initialized successfully")
except Exception as init_error:
    print(f"Error initializing recommendation system: {init_error}")
    recommender = None
# Response schema for a single recommended assessment.  The field names match
# the keys of the dicts built by RecommendationSystem.recommend.
class AssessmentRecommendation(BaseModel):
    # Identity of the assessment
    name: str
    url: str

    # Delivery characteristics
    remote_testing: str
    adaptive_irt: str
    duration: str

    # Classification and descriptive text
    test_types: str
    test_types_expanded: str
    description: str
    job_levels: str

    # Score derived from the index distance for this assessment
    similarity_score: float
# Envelope for a recommendation request: the query as received, the rewritten
# query, and the list of matching assessments.
class RecommendationResponse(BaseModel):
    # Query text exactly as supplied by the caller
    query: str

    # Query text after the enhancement step
    enhanced_query: str

    # Recommended assessments
    recommendations: List[AssessmentRecommendation]
# Define API endpoints
@app.get("/")
def root():
    """Root endpoint that returns API information.

    Registered as ``GET /`` on the application; without the route decorator
    the function is never reachable over HTTP.

    Returns:
        A dict with the service name, version and a map of available
        endpoints to their descriptions.
    """
    return {
        "name": "SHL Assessment Recommendation API",
        "version": "1.0.0",
        "endpoints": {
            "/recommend": "GET endpoint for assessment recommendations"
        }
    }
@app.get("/recommend", response_model=RecommendationResponse)
def recommend(
    query: str = Query(..., description="Natural language query or job description text"),
    max_results: int = Query(10, ge=1, le=10, description="Maximum number of results to return")
):
    """Recommend SHL assessments based on query.

    Registered as ``GET /recommend`` (the path advertised by the root
    endpoint) and validated against ``RecommendationResponse``, whose fields
    match the dict returned here.

    Args:
        query: Natural-language query or job description text (required).
        max_results: Number of recommendations to return, between 1 and 10.

    Returns:
        A dict with the original ``query``, the ``enhanced_query`` and the
        list of ``recommendations``.

    Raises:
        HTTPException: status 500 when the module-level recommender failed
            to initialise.
    """
    if recommender is None:
        raise HTTPException(
            status_code=500,
            detail="Recommendation system not initialized properly"
        )
    # Get enhanced query for transparency
    enhanced_query = recommender.enhance_query(query)
    # Get recommendations (the recommender enhances the query internally)
    recommendations = recommender.recommend(query, max_results=max_results)
    return {
        "query": query,
        "enhanced_query": enhanced_query,
        "recommendations": recommendations
    }
# Script entry point: serve the application with uvicorn when this file is
# executed directly.
if __name__ == "__main__":
    import uvicorn

    # The app is passed as the import string "app:app" with auto-reload on.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)