File size: 4,279 Bytes
e03d275
 
 
 
 
 
 
 
 
efa9136
 
e03d275
efa9136
 
 
e64ca65
efa9136
 
 
 
a5ee064
efa9136
e03d275
a5ee064
e03d275
efa9136
 
 
 
 
 
 
 
 
 
 
e03d275
efa9136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e03d275
efa9136
 
 
 
 
 
 
9ce5589
efa9136
 
 
 
 
 
e03d275
efa9136
 
 
e03d275
56bc649
efa9136
 
 
 
e03d275
 
efa9136
911f78e
efa9136
 
 
 
 
 
 
 
 
 
 
 
 
 
5deb357
efa9136
 
 
5deb357
efa9136
 
 
 
5deb357
efa9136
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap

try:
    from google import genai
    from google.genai import types as genai_types
except ImportError:
    # `from google import genai` is shipped by the `google-genai` distribution.
    # The older `google-generativeai` package exposes `google.generativeai`
    # instead, so installing it would not make the imports above succeed.
    print("Google Generative AI library not found. Please install it: pip install google-genai")
    # Dummy classes defined here for development/debugging
    ...  # KEEP YOUR EXISTING DUMMY DEFINITIONS

# Configuration
# API key is read from the GEMINI_API_KEY environment variable; an empty
# string is used when the variable is not set.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "") 
# Model name passed to generate_content when answering user queries.
LLM_MODEL_NAME = "gemini-2.0-flash"
# Model name handed to AdvancedRAGSystem for document and query embeddings.
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"

# Module-level client shared by every Gemini call in this file. It is created
# at import time, so `genai` must be bound (real SDK or dummy) before this line.
client = genai.Client(api_key=GEMINI_API_KEY)

class AdvancedRAGSystem:
    """Embedding-based retriever over the ``text`` column of a DataFrame.

    Every entry of ``documents_df['text']`` is embedded once, at construction
    time, through the module-level Gemini ``client``. Queries are embedded on
    demand and ranked against the stored vectors by dot product.
    """

    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        """Store the documents and embed them immediately.

        Args:
            documents_df: DataFrame whose ``text`` column holds the documents.
            embedding_model_name: Model name passed to ``embed_content``.
        """
        self.documents_df = documents_df
        self.embedding_model_name = embedding_model_name
        self.embeddings = self._embed_documents()

    def _embed_text(self, text: str) -> np.ndarray:
        """Return the embedding vector of a single piece of text."""
        response = client.models.embed_content(
            model=self.embedding_model_name,
            contents=text,
            config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
        )
        # `embeddings` is a list with one entry per input content. A single
        # string was sent, so the vector lives on the first (only) entry.
        return np.array(response.embeddings[0].values)

    def _embed_documents(self) -> np.ndarray:
        """Embed every document and stack the vectors into one 2-D matrix.

        Returns:
            Array with one row per document, in DataFrame order. An empty
            ``(0, 0)`` array is returned when there are no documents.
        """
        vectors = [self._embed_text(text) for text in self.documents_df['text']]
        if not vectors:
            # np.vstack raises ValueError on an empty sequence.
            return np.empty((0, 0))
        return np.vstack(vectors)

    def retrieve_relevant_info(self, query: str, top_k=3) -> str:
        """Return the ``top_k`` best-matching documents joined by blank lines.

        Args:
            query: Text to search for.
            top_k: Maximum number of documents to return. Values larger than
                the number of documents return all of them; values <= 0
                return nothing.

        Returns:
            The selected document texts, best match first, separated by
            ``"\\n\\n"``. An empty string when there is nothing to return.
        """
        if len(self.embeddings) == 0 or top_k <= 0:
            # Nothing to rank; skip the embedding request entirely.
            # (A ``[-0:]`` slice would otherwise select *every* document.)
            return ""

        query_vector = self._embed_text(query)

        # Plain dot-product score for each document row.
        scores = np.dot(self.embeddings, query_vector)
        # Highest score first, then keep at most top_k of them.
        top_indices = np.argsort(scores)[::-1][:top_k]
        context = "\n\n".join(self.documents_df.iloc[i]['text'] for i in top_indices)
        return context

class EmployerBrandingAgent:
    """Chat agent that answers employer-branding questions with Gemini.

    Each prompt is built from a textual description of the available
    DataFrames, context retrieved by an ``AdvancedRAGSystem`` and the user's
    query. User and assistant messages are recorded in ``chat_history``.
    """

    def __init__(self, all_dataframes: dict, rag_documents_df: pd.DataFrame):
        """Set up schemas, chat history and the retrieval system.

        Args:
            all_dataframes: Mapping of name -> DataFrame; each is described in
                the prompt as ``df_<name>``.
            rag_documents_df: Documents handed to ``AdvancedRAGSystem``.
        """
        self.all_dataframes = all_dataframes
        self.schemas_representation = self._get_all_schemas_representation()
        self.chat_history = []
        self.rag_system = AdvancedRAGSystem(rag_documents_df, GEMINI_EMBEDDING_MODEL_NAME)
        logging.info("EmployerBrandingAgent initialized with Gemini")

    def _get_all_schemas_representation(self):
        """Return a text block listing every DataFrame and its column names."""
        schema_descriptions = [
            # Column labels are not always strings (e.g. a default RangeIndex),
            # so convert them before joining.
            f"DataFrame: df_{key}\nColumns: {', '.join(map(str, df.columns))}\n"
            for key, df in self.all_dataframes.items()
        ]
        return "\n".join(schema_descriptions)

    def _build_prompt(self, user_query: str) -> str:
        """Assemble the prompt: instructions, schemas, RAG context, query.

        This calls ``rag_system.retrieve_relevant_info`` to fetch the context.
        """
        prompt = "You are an expert Employer Branding Analyst. Analyze the query based on the following DataFrames.\n"
        prompt += self.schemas_representation

        rag_context = self.rag_system.retrieve_relevant_info(user_query)
        if rag_context:
            prompt += f"\n\nAdditional Context:\n{rag_context}"

        prompt += f"\n\nUser Query:\n{user_query}"
        return prompt

    async def process_query(self, user_query: str) -> str:
        """Answer ``user_query`` with the LLM and record the exchange.

        The synchronous SDK calls are run in worker threads so the event loop
        is not blocked while waiting for the network.

        Args:
            user_query: The user's question.

        Returns:
            The model's answer with surrounding whitespace removed, or an
            empty string when the response carried no text.
        """
        self.chat_history.append({"role": "user", "content": user_query})
        # Prompt building performs a blocking retrieval call.
        prompt = await asyncio.to_thread(self._build_prompt, user_query)

        generation_config = genai_types.GenerateContentConfig(
            safety_settings=[
                genai_types.SafetySetting(
                    category=genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                    threshold=genai_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
                )
            ]
        )
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=LLM_MODEL_NAME,
            contents=[prompt],
            config=generation_config,
        )

        # `text` can be None when the response has no text parts
        # (for example when the candidate was blocked).
        answer = (response.text or "").strip()
        if not answer:
            logging.warning("Gemini response contained no text.")
        self.chat_history.append({"role": "assistant", "content": answer})
        return answer

    def update_dataframes(self, new_dataframes: dict):
        """Replace the DataFrames and regenerate the schema description."""
        self.all_dataframes = new_dataframes
        self.schemas_representation = self._get_all_schemas_representation()
        logging.info("EmployerBrandingAgent DataFrames updated.")

    def clear_chat_history(self):
        """Discard all recorded chat messages."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared.")