Spaces:
Running
Running
File size: 4,279 Bytes
e03d275 efa9136 e03d275 efa9136 e64ca65 efa9136 a5ee064 efa9136 e03d275 a5ee064 e03d275 efa9136 e03d275 efa9136 e03d275 efa9136 9ce5589 efa9136 e03d275 efa9136 e03d275 56bc649 efa9136 e03d275 efa9136 911f78e efa9136 5deb357 efa9136 5deb357 efa9136 5deb357 efa9136 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap
try:
    from google import genai
    from google.genai import types as genai_types
except ImportError:
    # NOTE: `from google import genai` is the *new* unified SDK, shipped as the
    # `google-genai` package — the original message pointed at the legacy
    # `google-generativeai` package, which would not fix this import.
    print("Google Generative AI library not found. Please install it: pip install google-genai")
    # Dummy classes defined here for development/debugging
    ...  # KEEP YOUR EXISTING DUMMY DEFINITIONS
# Configuration
# API key comes from the environment; the "" fallback means a missing key is
# NOTE(review): presumably only detected on the first API request — confirm.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
# Chat model used for answering analyst queries.
LLM_MODEL_NAME = "gemini-2.0-flash"
# Experimental embedding model used by the RAG retriever below.
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
# Module-level client shared by both classes below; created at import time.
client = genai.Client(api_key=GEMINI_API_KEY)
class AdvancedRAGSystem:
    """Minimal RAG retriever: embeds a document corpus once at construction,
    then returns the top-k most similar documents for a query.

    `documents_df` must contain a 'text' column; one embedding is computed
    per row via the module-level `client`.
    """

    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        self.documents_df = documents_df
        self.embedding_model_name = embedding_model_name
        # (n_docs, dim) matrix, one row per corpus document.
        self.embeddings = self._embed_documents()

    def _embed_text(self, text: str) -> np.ndarray:
        """Embed a single string; returns its vector as a 1-D float array."""
        response = client.models.embed_content(
            model=self.embedding_model_name,
            contents=text,
            config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
        )
        # BUG FIX: the google-genai SDK returns a *list* of ContentEmbedding
        # objects in `response.embeddings`; the original
        # `response.embeddings.values` raised AttributeError.
        return np.array(response.embeddings[0].values)

    def _embed_documents(self) -> np.ndarray:
        """Embed every corpus document and stack into a (n_docs, dim) matrix."""
        return np.vstack([self._embed_text(text) for text in self.documents_df['text']])

    def retrieve_relevant_info(self, query: str, top_k: int = 3) -> str:
        """Return the `top_k` most similar documents, joined by blank lines.

        Similarity is cosine: both sides are normalized so that document
        vector magnitude does not skew the ranking (the original raw dot
        product implicitly favored longer vectors).
        """
        query_vector = self._embed_text(query)
        doc_norms = np.linalg.norm(self.embeddings, axis=1)
        query_norm = np.linalg.norm(query_vector)
        # Small epsilon guards against division by zero for degenerate vectors.
        scores = (self.embeddings @ query_vector) / (doc_norms * query_norm + 1e-12)
        top_indices = np.argsort(scores)[-top_k:][::-1]
        return "\n\n".join(self.documents_df.iloc[i]['text'] for i in top_indices)
class EmployerBrandingAgent:
    """Gemini-backed analyst that answers questions about a set of
    employer-branding DataFrames, augmented with RAG context.

    Keeps a simple `chat_history` list of {"role", "content"} dicts; note the
    history is recorded but not currently sent back to the model.
    """

    def __init__(self, all_dataframes: dict, rag_documents_df: pd.DataFrame):
        self.all_dataframes = all_dataframes
        self.schemas_representation = self._get_all_schemas_representation()
        self.chat_history = []
        self.rag_system = AdvancedRAGSystem(rag_documents_df, GEMINI_EMBEDDING_MODEL_NAME)
        logging.info("EmployerBrandingAgent initialized with Gemini")

    def _get_all_schemas_representation(self) -> str:
        """Describe every DataFrame (prompt-facing name + column list)."""
        schema_descriptions = []
        for key, df in self.all_dataframes.items():
            schema_descriptions.append(
                f"DataFrame: df_{key}\nColumns: {', '.join(df.columns)}\n"
            )
        return "\n".join(schema_descriptions)

    def _build_prompt(self, user_query: str) -> str:
        """Assemble the full prompt: role, schemas, optional RAG context, query."""
        prompt = "You are an expert Employer Branding Analyst. Analyze the query based on the following DataFrames.\n"
        prompt += self.schemas_representation
        rag_context = self.rag_system.retrieve_relevant_info(user_query)
        if rag_context:
            prompt += f"\n\nAdditional Context:\n{rag_context}"
        prompt += f"\n\nUser Query:\n{user_query}"
        return prompt

    async def process_query(self, user_query: str) -> str:
        """Answer `user_query` and record both turns in `chat_history`.

        BUG FIX: the original called the blocking sync API
        (`client.models.generate_content`) inside an `async def`, stalling the
        event loop for the whole round trip; the SDK's async surface
        (`client.aio`) is awaited instead.
        """
        self.chat_history.append({"role": "user", "content": user_query})
        prompt = self._build_prompt(user_query)
        response = await client.aio.models.generate_content(
            model=LLM_MODEL_NAME,
            contents=[prompt],
            config=genai_types.GenerateContentConfig(
                safety_settings=[
                    genai_types.SafetySetting(
                        category=genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                        threshold=genai_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
                    )
                ]
            )
        )
        # response.text is None when generation is blocked (e.g. safety
        # filters); the original `.strip()` would raise AttributeError then.
        answer = (response.text or "").strip()
        self.chat_history.append({"role": "assistant", "content": answer})
        return answer

    def update_dataframes(self, new_dataframes: dict):
        """Swap in new DataFrames and rebuild the schema description."""
        self.all_dataframes = new_dataframes
        self.schemas_representation = self._get_all_schemas_representation()
        logging.info("EmployerBrandingAgent DataFrames updated.")

    def clear_chat_history(self):
        """Reset the conversation history to empty."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared.")
|