# Spaces: Running
import pandas as pd | |
import json | |
import os | |
import asyncio | |
import logging | |
import numpy as np | |
import textwrap | |
try:
    from google import genai
    from google.genai import types as genai_types
except ImportError:
    # `from google import genai` is shipped by the `google-genai` distribution.
    # `google-generativeai` installs a different module (`google.generativeai`),
    # so the install hint must name `google-genai`.
    print("Google Generative AI library not found. Please install it: pip install google-genai")
    # Dummy classes defined here for development/debugging
    ...  # KEEP YOUR EXISTING DUMMY DEFINITIONS
# Configuration
# The API key is read from the environment; it is an empty string when unset.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")

# Model used for text generation and model used for embeddings, respectively.
LLM_MODEL_NAME = "gemini-2.0-flash"
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"

# Module-wide client, created at import time with the key read above.
client = genai.Client(api_key=GEMINI_API_KEY)
class AdvancedRAGSystem:
    """Embedding-based retriever over the ``text`` column of a DataFrame.

    Every document is embedded once, at construction time, through the
    module-level ``client``. Queries are embedded on demand and ranked
    against the stored embedding matrix by dot product.
    """

    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        """Store the documents and eagerly compute their embeddings.

        Args:
            documents_df: Frame whose ``text`` column holds the documents.
            embedding_model_name: Model name passed to ``embed_content``.
        """
        self.documents_df = documents_df
        self.embedding_model_name = embedding_model_name
        self.embeddings = self._embed_documents()

    def _embed_text(self, text: str) -> np.ndarray:
        """Return the embedding vector of a single piece of text."""
        response = client.models.embed_content(
            model=self.embedding_model_name,
            contents=text,
            config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
        )
        # `embeddings` is a list with one entry per input content; a single
        # string yields one entry, and its `values` attribute is the vector.
        return np.array(response.embeddings[0].values)

    def _embed_documents(self) -> np.ndarray:
        """Embed every document and stack the vectors row-wise.

        Returns:
            One row per document, in the order of ``documents_df['text']``.
            An array with zero rows when there are no documents.
        """
        vectors = [self._embed_text(text) for text in self.documents_df['text']]
        if not vectors:
            # np.vstack raises ValueError on an empty sequence.
            return np.empty((0, 0))
        return np.vstack(vectors)

    def retrieve_relevant_info(self, query: str, top_k=3) -> str:
        """Return the ``top_k`` best-scoring documents joined by blank lines.

        Args:
            query: Text to embed and compare against the documents.
            top_k: Maximum number of documents to return.

        Returns:
            The selected texts, highest score first, separated by ``"\\n\\n"``.
            An empty string when ``top_k`` is not positive or when no
            documents were embedded.
        """
        # Guard before embedding the query: `[-0:]` would select *every*
        # document, and an empty matrix has nothing to rank.
        if top_k <= 0 or self.embeddings.shape[0] == 0:
            return ""
        query_vector = self._embed_text(query)
        scores = np.dot(self.embeddings, query_vector)
        # argsort is ascending: take the tail and reverse it for best-first.
        top_indices = np.argsort(scores)[-top_k:][::-1]
        return "\n\n".join(self.documents_df.iloc[i]['text'] for i in top_indices)
class EmployerBrandingAgent:
    """Chat agent that answers queries with the ``LLM_MODEL_NAME`` model.

    Each prompt combines a textual summary of the registered DataFrames'
    columns, context retrieved by an ``AdvancedRAGSystem``, and the user's
    query. Exchanged messages are recorded in ``chat_history``.
    """

    def __init__(self, all_dataframes: dict, rag_documents_df: pd.DataFrame):
        """Register the DataFrames and build the retrieval system.

        Args:
            all_dataframes: Mapping of a name to a DataFrame; the name is
                shown in the prompt as ``df_<name>``.
            rag_documents_df: Documents handed to ``AdvancedRAGSystem``.
        """
        self.all_dataframes = all_dataframes
        self.schemas_representation = self._get_all_schemas_representation()
        self.chat_history = []
        self.rag_system = AdvancedRAGSystem(rag_documents_df, GEMINI_EMBEDDING_MODEL_NAME)
        logging.info("EmployerBrandingAgent initialized with Gemini")

    def _get_all_schemas_representation(self):
        """Return one ``DataFrame: .../Columns: ...`` entry per DataFrame."""
        schema_descriptions = []
        for key, df in self.all_dataframes.items():
            # Column labels are not necessarily strings (e.g. integer
            # labels); str.join would raise TypeError on those.
            column_list = ", ".join(map(str, df.columns))
            schema_descriptions.append(f"DataFrame: df_{key}\nColumns: {column_list}\n")
        return "\n".join(schema_descriptions)

    def _build_prompt(self, user_query: str) -> str:
        """Assemble the prompt: instructions, schemas, RAG context, query."""
        sections = [
            "You are an expert Employer Branding Analyst. Analyze the query based on the following DataFrames.\n",
            self.schemas_representation,
        ]
        rag_context = self.rag_system.retrieve_relevant_info(user_query)
        if rag_context:
            sections.append(f"\n\nAdditional Context:\n{rag_context}")
        sections.append(f"\n\nUser Query:\n{user_query}")
        return "".join(sections)

    async def process_query(self, user_query: str) -> str:
        """Send the query to the model and return its stripped text answer.

        The user message and the answer are both appended to
        ``chat_history``.

        Args:
            user_query: The question to answer.

        Returns:
            The model's text with surrounding whitespace removed; an empty
            string when the response carries no text.
        """
        self.chat_history.append({"role": "user", "content": user_query})

        # Building the prompt embeds the query and generating the answer
        # calls the model; both are blocking SDK calls, so they run in worker
        # threads instead of stalling the event loop.
        prompt = await asyncio.to_thread(self._build_prompt, user_query)
        generation_config = genai_types.GenerateContentConfig(
            safety_settings=[
                genai_types.SafetySetting(
                    category=genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                    threshold=genai_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                )
            ]
        )
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=LLM_MODEL_NAME,
            contents=[prompt],
            config=generation_config,
        )

        # `response.text` may be None when the response has no text parts;
        # calling .strip() on it directly would raise AttributeError.
        answer = (response.text or "").strip()
        if not answer:
            logging.warning("EmployerBrandingAgent received a response without text.")
        self.chat_history.append({"role": "assistant", "content": answer})
        return answer

    def update_dataframes(self, new_dataframes: dict):
        """Replace the registered DataFrames and refresh the schema text."""
        self.all_dataframes = new_dataframes
        self.schemas_representation = self._get_all_schemas_representation()
        logging.info("EmployerBrandingAgent DataFrames updated.")

    def clear_chat_history(self):
        """Discard all recorded chat messages."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared.")