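# LinkedinMonitor / eb_agent_module.py
# (Residue from the hosting site's file-viewer page, kept for provenance; not part of the module.)
# GuglielmoTor's picture
# Update eb_agent_module.py
# efa9136 verified
# raw
# history blame
# 4.28 kB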
import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap
# Import the Gemini SDK; on failure, tell the user how to install it and fall
# back to whatever dummy definitions are kept in the except branch.
try:
    from google import genai
    from google.genai import types as genai_types
except ImportError:
    # `from google import genai` is shipped by the `google-genai` distribution;
    # `google-generativeai` is a different package exposing `google.generativeai`,
    # so the install hint must name `google-genai`.
    print("Google Generative AI library not found. Please install it: pip install google-genai")
    # Dummy classes defined here for development/debugging
    ...  # KEEP YOUR EXISTING DUMMY DEFINITIONS
# Configuration
# API key comes from the environment; an unset variable yields an empty string.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
# Model name used for text generation requests.
LLM_MODEL_NAME = "gemini-2.0-flash"
# Model name used for embedding requests.
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
# Module-level Gemini client shared by the classes in this file.
client = genai.Client(api_key=GEMINI_API_KEY)
class AdvancedRAGSystem:
    """Embedding-based retriever over the ``text`` column of a DataFrame.

    Every document is embedded once at construction time through the
    module-level ``client``; queries are embedded on demand and ranked against
    the stored document vectors by dot product.
    """

    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        """Store the corpus and eagerly embed it.

        Args:
            documents_df: DataFrame with a ``text`` column holding the documents.
            embedding_model_name: Model name passed to
                ``client.models.embed_content``.
        """
        self.documents_df = documents_df
        self.embedding_model_name = embedding_model_name
        self.embeddings = self._embed_documents()

    def _embed_text(self, text: str) -> np.ndarray:
        """Embed one piece of text and return its vector as a 1-D array.

        Raises:
            ValueError: If the response carries no embedding vector.
        """
        response = client.models.embed_content(
            model=self.embedding_model_name,
            contents=text,
            config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
        )
        embedding = response.embeddings
        # The SDK returns a list with one embedding per input, so the vector
        # lives on the first element. A bare object exposing `.values` is still
        # accepted -- presumably what a development stub would return.
        if isinstance(embedding, (list, tuple)):
            embedding = embedding[0] if embedding else None
        values = getattr(embedding, "values", None)
        if values is None:
            raise ValueError("Embedding response did not contain a vector.")
        return np.asarray(values, dtype=float)

    def _embed_documents(self) -> np.ndarray:
        """Embed every document, one request per row.

        Returns:
            A 2-D array with one row per document, or an empty ``(0, 0)``
            array when there are no documents (``np.vstack`` would raise on an
            empty list).
        """
        vectors = [self._embed_text(text) for text in self.documents_df['text']]
        if not vectors:
            return np.empty((0, 0))
        return np.vstack(vectors)

    def retrieve_relevant_info(self, query: str, top_k=3) -> str:
        """Return the ``top_k`` best-matching documents joined by blank lines.

        Args:
            query: Text to match against the corpus.
            top_k: Maximum number of documents to return.

        Returns:
            The selected document texts, highest score first, separated by
            ``"\\n\\n"``; an empty string when the corpus is empty or ``top_k``
            is not positive.
        """
        # Guard first: with top_k == 0 the slice `[-0:]` would select *every*
        # document, and an empty corpus has nothing to rank. Returning early
        # also avoids a pointless embedding request.
        if top_k <= 0 or len(self.embeddings) == 0:
            return ""
        query_vector = self._embed_text(query)
        scores = np.dot(self.embeddings, query_vector)
        # argsort is ascending: take the last top_k entries and reverse them so
        # the best match comes first.
        top_indices = np.argsort(scores)[-top_k:][::-1]
        texts = self.documents_df['text']
        return "\n\n".join(texts.iloc[i] for i in top_indices)
class EmployerBrandingAgent:
    """Answers employer-branding questions with Gemini.

    Each prompt is built from the column names of the supplied DataFrames plus
    context retrieved by an ``AdvancedRAGSystem``. User and assistant messages
    are recorded in ``chat_history``; the history is stored only and is not
    included in the prompt.
    """

    def __init__(self, all_dataframes: dict, rag_documents_df: pd.DataFrame):
        """Set up schemas, chat history and the retrieval system.

        Args:
            all_dataframes: Mapping of a name to a DataFrame; each is described
                to the model as ``df_<name>``.
            rag_documents_df: Documents handed to ``AdvancedRAGSystem``.
        """
        self.all_dataframes = all_dataframes
        self.schemas_representation = self._get_all_schemas_representation()
        self.chat_history = []
        self.rag_system = AdvancedRAGSystem(rag_documents_df, GEMINI_EMBEDDING_MODEL_NAME)
        logging.info("EmployerBrandingAgent initialized with Gemini")

    def _get_all_schemas_representation(self):
        """Return a text listing of every DataFrame and its column names."""
        schema_descriptions = []
        for key, df in self.all_dataframes.items():
            # Column labels are not guaranteed to be strings (ints, tuples, ...);
            # str.join would raise TypeError on them, so convert explicitly.
            columns = ", ".join(map(str, df.columns))
            schema_descriptions.append(f"DataFrame: df_{key}\nColumns: {columns}\n")
        return "\n".join(schema_descriptions)

    def _build_prompt(self, user_query: str) -> str:
        """Assemble the prompt: role line, schemas, retrieved context, query."""
        parts = [
            "You are an expert Employer Branding Analyst. Analyze the query based on the following DataFrames.\n",
            self.schemas_representation,
        ]
        rag_context = self.rag_system.retrieve_relevant_info(user_query)
        if rag_context:
            parts.append(f"\n\nAdditional Context:\n{rag_context}")
        parts.append(f"\n\nUser Query:\n{user_query}")
        return "".join(parts)

    def _generate_answer(self, user_query: str) -> str:
        """Blocking helper: build the prompt, call Gemini, return stripped text."""
        prompt = self._build_prompt(user_query)
        response = client.models.generate_content(
            model=LLM_MODEL_NAME,
            contents=[prompt],
            config=genai_types.GenerateContentConfig(
                safety_settings=[
                    genai_types.SafetySetting(
                        category=genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                        threshold=genai_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                    )
                ]
            ),
        )
        text = response.text
        if text is None:
            # A response without text (e.g. one stopped by the safety settings)
            # previously crashed on `.strip()`; answer with an empty string.
            logging.warning("Gemini returned no text for the query.")
            return ""
        return text.strip()

    async def process_query(self, user_query: str) -> str:
        """Answer ``user_query`` and record both turns in ``chat_history``.

        The embedding and generation requests are synchronous network calls,
        so they run in a worker thread to keep the event loop responsive.

        Args:
            user_query: The question asked by the user.

        Returns:
            The model's answer with surrounding whitespace removed, or an
            empty string when the model returned no text.
        """
        self.chat_history.append({"role": "user", "content": user_query})
        answer = await asyncio.to_thread(self._generate_answer, user_query)
        self.chat_history.append({"role": "assistant", "content": answer})
        return answer

    def update_dataframes(self, new_dataframes: dict):
        """Replace the DataFrames and refresh the cached schema text."""
        self.all_dataframes = new_dataframes
        self.schemas_representation = self._get_all_schemas_representation()
        logging.info("EmployerBrandingAgent DataFrames updated.")

    def clear_chat_history(self):
        """Discard all recorded chat messages."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared.")