File size: 7,375 Bytes
de15097
 
 
09bc280
d2de44b
09bc280
d2de44b
 
 
 
 
09bc280
 
c231c9f
 
d2de44b
09bc280
 
 
 
 
 
d2de44b
09bc280
 
 
 
 
 
 
c231c9f
09bc280
 
 
 
 
 
 
 
 
 
 
 
 
d2de44b
 
c231c9f
09bc280
d2de44b
de15097
 
09bc280
de15097
 
09bc280
c231c9f
 
 
09bc280
 
 
 
 
 
 
 
 
de15097
09bc280
de15097
 
d2de44b
c231c9f
d2de44b
09bc280
d2de44b
 
de15097
 
 
09bc280
 
 
c231c9f
 
09bc280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c231c9f
d2de44b
09bc280
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# chatbot_handler.py
import logging
import json
from google import genai # Assuming this is the correct SDK
import os
import asyncio # Added for asyncio.to_thread

# Gemini API key configuration. Read once at import time; an empty string
# means "not configured" and leaves `client` as None.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', '')

# Set below to either a `genai.Client` (google-genai SDK) or a
# `genai.GenerativeModel` (legacy google-generativeai SDK). Stays None when
# the API key is missing or initialization fails.
client = None
model_name = "gemini-1.5-flash-latest"  # Flash model used for every request
# NOTE(review): this name is not read by the code visible here; the list that
# is actually sent to the SDK is `common_safety_settings` below. Kept in case
# other modules import it.
safety_settings = []


# Sampling parameters shared by both SDK code paths.
generation_config = {
    "temperature": 0.7,
    "top_p": 1,
    "top_k": 1,
    "max_output_tokens": 2048,
}

# Safety thresholds applied to every request, whichever SDK client is in use.
common_safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]

try:
    if GEMINI_API_KEY:
        if hasattr(genai, 'Client'):
            # google-genai SDK (`from google import genai`): the Client is not
            # bound to a model, so model name, sampling and safety settings are
            # supplied with each request instead of here.
            client = genai.Client(api_key=GEMINI_API_KEY)
            logging.info(f"Gemini client (genai.Client) initialized; requests will use model '{model_name}'.")
        else:
            # Legacy google-generativeai SDK: the key is configured globally and
            # the model object carries its safety/generation settings.
            genai.configure(api_key=GEMINI_API_KEY)
            client = genai.GenerativeModel(
                model_name=model_name,
                safety_settings=common_safety_settings,
                generation_config=generation_config,
            )
            logging.info(f"Gemini client (genai.GenerativeModel) initialized with model '{model_name}'")
    else:
        logging.error("Gemini API Key is not set.")
except Exception as e:
    # Initialization is best-effort: failures are logged and `client` stays
    # None, which generate_llm_response reports as a configuration error.
    logging.error(f"Failed to initialize Gemini client/model: {e}", exc_info=True)


def format_history_for_gemini(gradio_chat_history: list) -> list:
    """Convert Gradio "messages"-style chat history into Gemini ``contents``.

    Each entry is expected to be a dict with ``role`` and ``content`` keys.
    ``role == "user"`` maps to Gemini's ``"user"``; every other role
    (e.g. ``"assistant"``) maps to ``"model"``.

    ``content`` may be:
      * a string -> a single text part;
      * a list of typed part dicts (the first item has a ``"type"`` key) ->
        only ``{"type": "text", ...}`` items are kept, as text parts.

    Anything else is skipped with a warning instead of raising. This covers
    non-dict entries, non-dict items inside a part list, and a part list with
    no text items.

    Args:
        gradio_chat_history: The chat history to convert; ``None`` is treated
            as empty.

    Returns:
        A list of ``{"role": ..., "parts": [{"text": ...}, ...]}`` dicts in
        input order, omitting skipped entries.
    """
    gemini_contents = []
    for msg in gradio_chat_history or []:
        if not isinstance(msg, dict):
            # e.g. tuple-style [user, bot] pairs, if a caller passes them:
            # there is no role information to map, so skip rather than crash.
            logging.warning(f"Skipping non-dict entry in chat history: {msg}")
            continue
        role = "user" if msg.get("role") == "user" else "model"
        content = msg.get("content")
        if isinstance(content, str):
            gemini_contents.append({"role": role, "parts": [{"text": content}]})
        elif isinstance(content, list) and content and isinstance(content[0], dict) and "type" in content[0]:
            # Non-text items (images, files, stray non-dicts) are dropped.
            parts = [
                {"text": item.get("text", "")}
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            ]
            if parts:
                gemini_contents.append({"role": role, "parts": parts})
            else:
                logging.warning(f"Skipping complex but empty content part in chat history: {content}")
        else:
            logging.warning(f"Skipping non-string/non-standard content in chat history: {content}")
    return gemini_contents


async def generate_llm_response(user_message: str, plot_id: str, plot_label: str, chat_history_for_plot: list, plot_data_summary: str = None):
    """Generate the assistant's reply for one plot's chat conversation.

    The request is built solely from ``chat_history_for_plot`` (converted with
    ``format_history_for_gemini``). ``user_message``, ``plot_id`` and
    ``plot_data_summary`` are not read here.
    NOTE(review): presumably the caller has already appended ``user_message``
    to the history. Confirm this; otherwise the latest turn is never sent.

    Args:
        user_message: The latest user message (not used, see note above).
        plot_id: Identifier of the plot the chat belongs to (not used).
        plot_label: Human-readable plot label, used only in error logs.
        chat_history_for_plot: Gradio-style message dicts (``role``/``content``).
        plot_data_summary: Optional data summary (not used).

    Returns:
        The model's text reply, or a user-facing explanation string when the
        client is unavailable, the history has no text, the request is blocked,
        or the model returns no text. Exceptions raised while contacting the
        model are logged and converted into messages rather than propagated.
    """
    if not client:
        logging.error("Gemini client/model not initialized.")
        return "The AI model is not available. Configuration error."

    gemini_formatted_history = format_history_for_gemini(chat_history_for_plot)

    # Do not call the API unless there is some non-blank text to send. This
    # covers both an empty history and one made only of empty/blank parts.
    has_text = any(
        isinstance(part.get("text"), str) and part["text"].strip()
        for message in gemini_formatted_history
        for part in message.get("parts", [])
    )
    if not has_text:
        logging.error("Formatted history for Gemini is empty or contains no text.")
        return "There was an issue processing the conversation history for the AI model (empty text)."

    try:
        # Dispatch on capabilities rather than isinstance(): the google-genai
        # module has no `GenerativeModel` attribute, so an isinstance check
        # against it raises AttributeError on that SDK.
        if hasattr(client, "generate_content_async"):
            # Legacy GenerativeModel: natively async. Its model, sampling and
            # safety settings were bound when it was constructed.
            logging.debug("Using genai.GenerativeModel.generate_content_async")
            response = await client.generate_content_async(
                contents=gemini_formatted_history
            )
        elif hasattr(client, 'models') and hasattr(client.models, 'generate_content'):
            # genai.Client: generate_content is synchronous, so run it in a
            # worker thread to keep the event loop responsive. This SDK takes
            # sampling and safety options through a single `config` argument,
            # not as separate keyword arguments.
            logging.debug("Using genai.Client.models.generate_content (synchronous via asyncio.to_thread)")
            qualified_model_name = model_name if model_name.startswith("models/") else f"models/{model_name}"
            request_config = {**generation_config, "safety_settings": common_safety_settings}
            response = await asyncio.to_thread(
                client.models.generate_content,
                model=qualified_model_name,
                contents=gemini_formatted_history,
                config=request_config,
            )
        else:
            logging.error(f"Gemini client is not a recognized type for generating content. Type: {type(client)}")
            return "AI model interaction error (client type)."

        if hasattr(response, 'prompt_feedback') and response.prompt_feedback and response.prompt_feedback.block_reason:
            reason = response.prompt_feedback.block_reason
            reason_name = getattr(reason, 'name', str(reason))
            logging.warning(f"Blocked by prompt feedback: {reason_name}")
            return f"Blocked due to content policy: {reason_name}."

        candidates = response.candidates
        first_candidate = candidates[0] if candidates else None

        if first_candidate is not None and first_candidate.content and first_candidate.content.parts:
            # Non-text parts may carry text=None (or no text attribute); skip them
            # so the join never sees a non-string.
            texts = (getattr(part, 'text', None) for part in first_candidate.content.parts)
            reply = "".join(text for text in texts if text)
            if reply:
                return reply

        finish_reason = "UNKNOWN"
        if first_candidate is not None and first_candidate.finish_reason:
            finish_reason_val = first_candidate.finish_reason
            finish_reason = getattr(finish_reason_val, 'name', str(finish_reason_val))

        logging.warning(f"No content parts in response. Finish reason: {finish_reason}")
        if finish_reason == "SAFETY":
            return f"Response generation stopped due to safety reasons. Finish reason: {finish_reason}."
        return f"The AI model returned an empty response. Finish reason: {finish_reason}."

    except AttributeError as ae:
        logging.error(f"AttributeError during Gemini call for plot '{plot_label}': {ae}", exc_info=True)
        if "generate_content_async" in str(ae) or "generate_content" in str(ae):
            return f"AI model error: SDK method not found or mismatch. Details: {ae}"
        return f"AI model error (Attribute): {type(ae).__name__} - {ae}."
    except Exception as e:
        logging.error(f"Error generating response for plot '{plot_label}': {e}", exc_info=True)
        if "API key not valid" in str(e):
            return "AI model error: API key is not valid. Please check configuration."
        return f"An unexpected error occurred while contacting the AI model: {type(e).__name__}."