mgbam commited on
Commit
b967045
·
verified ·
1 Parent(s): ded730b

Update core/llm_clients.py

Browse files
Files changed (1) hide show
  1. core/llm_clients.py +106 -60
core/llm_clients.py CHANGED
@@ -2,6 +2,7 @@
2
  import os
3
  import google.generativeai as genai
4
  from huggingface_hub import InferenceClient
 
5
 
6
  # --- Configuration ---
7
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
@@ -11,127 +12,172 @@ GEMINI_API_CONFIGURED = False
11
  HF_API_CONFIGURED = False
12
 
13
  hf_inference_client = None
14
- google_gemini_models = {} # To store initialized Gemini model instances
15
 
16
- # --- Initialization ---
17
- def initialize_clients():
18
- global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client, google_gemini_models
19
 
20
  # Google Gemini
21
  if GOOGLE_API_KEY:
22
  try:
23
  genai.configure(api_key=GOOGLE_API_KEY)
24
  GEMINI_API_CONFIGURED = True
25
- print("INFO: llm_clients.py - Google Gemini API configured.")
26
  except Exception as e:
 
27
  print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
28
  else:
29
- print("WARNING: llm_clients.py - GOOGLE_API_KEY not found.")
30
 
31
  # Hugging Face
32
  if HF_TOKEN:
33
  try:
34
  hf_inference_client = InferenceClient(token=HF_TOKEN)
35
  HF_API_CONFIGURED = True
36
- print("INFO: llm_clients.py - Hugging Face InferenceClient initialized.")
37
  except Exception as e:
 
38
  print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
39
  else:
40
- print("WARNING: llm_clients.py - HF_TOKEN not found.")
41
-
42
- # Call initialize_clients when the module is imported for the first time.
43
- # However, for Gradio apps that might reload, it's often better to call this explicitly from app.py's main scope.
44
- # For now, let's assume it's called once. If you see issues, move the call.
45
- # initialize_clients() # Or call this from app.py
46
-
47
- def get_gemini_model_instance(model_id, system_instruction=None):
48
- """Gets or creates a Gemini model instance."""
49
  if not GEMINI_API_CONFIGURED:
50
- raise ConnectionError("Google Gemini API not configured.")
51
-
52
- instance_key = model_id + ("_sys" if system_instruction else "") # Simple keying
53
- if instance_key not in google_gemini_models:
54
- try:
55
- google_gemini_models[instance_key] = genai.GenerativeModel(
56
- model_name=model_id,
57
- system_instruction=system_instruction
58
- )
59
- print(f"INFO: Initialized Gemini Model Instance: {instance_key}")
60
- except Exception as e:
61
- print(f"ERROR: Failed to initialize Gemini model {model_id}: {e}")
62
- raise # Re-raise the exception to be caught by the caller
63
- return google_gemini_models[instance_key]
64
-
65
 
66
  class LLMResponse:
67
- def __init__(self, text=None, error=None, success=True, raw_response=None):
68
  self.text = text
69
  self.error = error
70
  self.success = success
71
- self.raw_response = raw_response # Store original API response if needed
 
72
 
73
  def __str__(self):
74
  if self.success:
75
  return self.text if self.text is not None else ""
76
- return f"ERROR: {self.error}"
77
-
78
 
79
- def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=350, system_prompt_text=None):
80
  if not HF_API_CONFIGURED or not hf_inference_client:
81
- return LLMResponse(error="Hugging Face API not configured.", success=False)
82
 
83
  full_prompt = prompt_text
84
- if system_prompt_text: # Simple prepend, specific formatting depends on model
85
- full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]" # Llama-style
 
86
 
87
  try:
88
- use_sample = temperature > 0.0
89
  raw_response = hf_inference_client.text_generation(
90
  full_prompt, model=model_id, max_new_tokens=max_new_tokens,
91
- temperature=temperature if use_sample else None,
92
- do_sample=use_sample, stream=False
 
 
93
  )
94
- return LLMResponse(text=raw_response, raw_response=raw_response)
95
  except Exception as e:
96
- error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {e}"
97
  print(f"ERROR: llm_clients.py - {error_msg}")
98
- return LLMResponse(error=error_msg, success=False, raw_response=e)
99
 
100
- def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=400, system_prompt_text=None):
101
  if not GEMINI_API_CONFIGURED:
102
- return LLMResponse(error="Google Gemini API not configured.", success=False)
103
 
104
  try:
105
- model_instance = get_gemini_model_instance(model_id, system_instruction=system_prompt_text)
106
 
107
  generation_config = genai.types.GenerationConfig(
108
  temperature=temperature,
109
  max_output_tokens=max_new_tokens
 
110
  )
 
111
  raw_response = model_instance.generate_content(
112
  prompt_text, # User prompt
113
  generation_config=generation_config,
114
  stream=False
 
 
 
 
115
  )
116
 
117
  if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
118
  reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
119
- error_msg = f"Gemini API: Prompt blocked due to safety. Reason: {reason}"
120
  print(f"WARNING: llm_clients.py - {error_msg}")
121
- return LLMResponse(error=error_msg, success=False, raw_response=raw_response)
122
-
123
- if not raw_response.candidates or not raw_response.candidates[0].content.parts:
124
- finish_reason = raw_response.candidates[0].finish_reason if raw_response.candidates else "Unknown"
125
- if str(finish_reason).upper() == "SAFETY":
126
- error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  else:
128
- error_msg = f"Gemini API: Empty response or no content. Finish Reason: {finish_reason}"
129
  print(f"WARNING: llm_clients.py - {error_msg}")
130
- return LLMResponse(error=error_msg, success=False, raw_response=raw_response)
 
 
 
 
 
 
 
 
131
 
132
- return LLMResponse(text=raw_response.candidates[0].content.parts[0].text, raw_response=raw_response)
133
 
134
  except Exception as e:
135
- error_msg = f"Gemini API Error ({model_id}): {type(e).__name__} - {e}"
 
 
 
 
 
 
 
 
 
 
136
  print(f"ERROR: llm_clients.py - {error_msg}")
137
- return LLMResponse(error=error_msg, success=False, raw_response=e)
 
2
  import os
3
  import google.generativeai as genai
4
  from huggingface_hub import InferenceClient
5
+ import time # For potential retries or delays
6
 
7
  # --- Configuration ---
8
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
 
12
  HF_API_CONFIGURED = False
13
 
14
  hf_inference_client = None
15
+ google_gemini_model_instances = {} # To cache initialized Gemini model instances
16
 
17
+ # --- Initialization Function (to be called from app.py) ---
18
+ def initialize_all_clients():
19
+ global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client
20
 
21
  # Google Gemini
22
  if GOOGLE_API_KEY:
23
  try:
24
  genai.configure(api_key=GOOGLE_API_KEY)
25
  GEMINI_API_CONFIGURED = True
26
+ print("INFO: llm_clients.py - Google Gemini API configured successfully.")
27
  except Exception as e:
28
+ GEMINI_API_CONFIGURED = False # Ensure it's False on error
29
  print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
30
  else:
31
+ print("WARNING: llm_clients.py - GOOGLE_API_KEY not found in environment variables.")
32
 
33
  # Hugging Face
34
  if HF_TOKEN:
35
  try:
36
  hf_inference_client = InferenceClient(token=HF_TOKEN)
37
  HF_API_CONFIGURED = True
38
+ print("INFO: llm_clients.py - Hugging Face InferenceClient initialized successfully.")
39
  except Exception as e:
40
+ HF_API_CONFIGURED = False # Ensure it's False on error
41
  print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
42
  else:
43
+ print("WARNING: llm_clients.py - HF_TOKEN not found in environment variables.")
44
+
45
+ def _get_gemini_model_instance(model_id, system_instruction=None):
46
+ """
47
+ Manages Gemini model instances.
48
+ Gemini's genai.GenerativeModel is fairly lightweight to create,
49
+ but caching can avoid repeated setup if system_instruction is complex or model loading is slow.
50
+ For now, creating a new one each time is fine unless performance becomes an issue.
51
+ """
52
  if not GEMINI_API_CONFIGURED:
53
+ raise ConnectionError("Google Gemini API not configured or configuration failed.")
54
+ try:
55
+ # For gemini-1.5 models, system_instruction is preferred.
56
+ # For older gemini-1.0, system instructions might need to be part of the 'contents'.
57
+ return genai.GenerativeModel(
58
+ model_name=model_id,
59
+ system_instruction=system_instruction
60
+ )
61
+ except Exception as e:
62
+ print(f"ERROR: llm_clients.py - Failed to get Gemini model instance for {model_id}: {e}")
63
+ raise
 
 
 
 
64
 
65
  class LLMResponse:
66
+ def __init__(self, text=None, error=None, success=True, raw_response=None, model_id_used="unknown"):
67
  self.text = text
68
  self.error = error
69
  self.success = success
70
+ self.raw_response = raw_response
71
+ self.model_id_used = model_id_used
72
 
73
  def __str__(self):
74
  if self.success:
75
  return self.text if self.text is not None else ""
76
+ return f"ERROR (Model: {self.model_id_used}): {self.error}"
 
77
 
78
+ def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=512, system_prompt_text=None):
79
  if not HF_API_CONFIGURED or not hf_inference_client:
80
+ return LLMResponse(error="Hugging Face API not configured (HF_TOKEN missing or client init failed).", success=False, model_id_used=model_id)
81
 
82
  full_prompt = prompt_text
83
+ # Llama-style system prompt formatting; adjust if using other HF model families
84
+ if system_prompt_text:
85
+ full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]"
86
 
87
  try:
88
+ use_sample = temperature > 0.001 # API might treat 0 as no sampling
89
  raw_response = hf_inference_client.text_generation(
90
  full_prompt, model=model_id, max_new_tokens=max_new_tokens,
91
+ temperature=temperature if use_sample else None, # None or omit if not sampling
92
+ do_sample=use_sample,
93
+ # top_p=0.9 if use_sample else None, # Optional
94
+ stream=False
95
  )
96
+ return LLMResponse(text=raw_response, raw_response=raw_response, model_id_used=model_id)
97
  except Exception as e:
98
+ error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {str(e)}"
99
  print(f"ERROR: llm_clients.py - {error_msg}")
100
+ return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
101
 
102
+ def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=768, system_prompt_text=None):
103
  if not GEMINI_API_CONFIGURED:
104
+ return LLMResponse(error="Google Gemini API not configured (GOOGLE_API_KEY missing or config failed).", success=False, model_id_used=model_id)
105
 
106
  try:
107
+ model_instance = _get_gemini_model_instance(model_id, system_instruction=system_prompt_text)
108
 
109
  generation_config = genai.types.GenerationConfig(
110
  temperature=temperature,
111
  max_output_tokens=max_new_tokens
112
+ # top_p=0.9 # Optional
113
  )
114
+ # For Gemini, the main prompt goes directly to generate_content if system_instruction is used.
115
  raw_response = model_instance.generate_content(
116
  prompt_text, # User prompt
117
  generation_config=generation_config,
118
  stream=False
119
+ # safety_settings=[ # Optional: Adjust safety settings if needed, be very careful
120
+ # {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
121
+ # {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
122
+ # ]
123
  )
124
 
125
  if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
126
  reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
127
+ error_msg = f"Gemini API: Your prompt was blocked. Reason: {reason}. Try rephrasing."
128
  print(f"WARNING: llm_clients.py - {error_msg}")
129
+ return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
130
+
131
+ if not raw_response.candidates: # No candidates usually means it was blocked or an issue.
132
+ error_msg = "Gemini API: No candidates returned in response. Possibly blocked or internal error."
133
+ # Check prompt_feedback again, as it might be populated even if candidates are empty.
134
+ if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
135
+ reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
136
+ error_msg = f"Gemini API: Your prompt was blocked (no candidates). Reason: {reason}. Try rephrasing."
137
+ print(f"WARNING: llm_clients.py - {error_msg}")
138
+ return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
139
+
140
+
141
+ # Check the first candidate
142
+ candidate = raw_response.candidates[0]
143
+ if not candidate.content or not candidate.content.parts:
144
+ finish_reason = str(candidate.finish_reason).upper()
145
+ if finish_reason == "SAFETY":
146
+ error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}."
147
+ elif finish_reason == "RECITATION":
148
+ error_msg = f"Gemini API: Response generation stopped due to recitation policy. Finish Reason: {finish_reason}."
149
+ elif finish_reason == "MAX_TOKENS":
150
+ error_msg = f"Gemini API: Response generation stopped due to max tokens. Consider increasing max_new_tokens. Finish Reason: {finish_reason}."
151
+ # In this case, there might still be partial text.
152
+ # For simplicity, we'll treat it as an incomplete generation here.
153
+ # You could choose to return partial text if desired.
154
+ # return LLMResponse(text="[PARTIAL RESPONSE - MAX TOKENS REACHED]", ..., model_id_used=model_id)
155
  else:
156
+ error_msg = f"Gemini API: Empty response or no content parts. Finish Reason: {finish_reason}."
157
  print(f"WARNING: llm_clients.py - {error_msg}")
158
+ # Try to get text even if finish_reason is not 'STOP' but not ideal
159
+ # This part might need refinement based on how you want to handle partial/stopped responses
160
+ partial_text = ""
161
+ if candidate.content and candidate.content.parts and candidate.content.parts[0].text:
162
+ partial_text = candidate.content.parts[0].text
163
+ if partial_text:
164
+ return LLMResponse(text=partial_text + f"\n[Note: Generation stopped due to {finish_reason}]", raw_response=raw_response, model_id_used=model_id)
165
+ else:
166
+ return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
167
 
168
+ return LLMResponse(text=candidate.content.parts[0].text, raw_response=raw_response, model_id_used=model_id)
169
 
170
  except Exception as e:
171
+ error_msg = f"Gemini API Call Error ({model_id}): {type(e).__name__} - {str(e)}"
172
+ # More specific error messages based on common Google API errors
173
+ if "API key not valid" in str(e) or "PERMISSION_DENIED" in str(e):
174
+ error_msg = f"Gemini API Error ({model_id}): API key invalid or permission denied. Check GOOGLE_API_KEY and ensure Gemini API is enabled. Original: {str(e)}"
175
+ elif "Could not find model" in str(e) or "ील नहीं मिला" in str(e): # Hindi for "model not found"
176
+ error_msg = f"Gemini API Error ({model_id}): Model ID '{model_id}' not found or inaccessible with your key. Original: {str(e)}"
177
+ elif "User location is not supported" in str(e):
178
+ error_msg = f"Gemini API Error ({model_id}): User location not supported for this model/API. Original: {str(e)}"
179
+ elif "Quota exceeded" in str(e):
180
+ error_msg = f"Gemini API Error ({model_id}): API quota exceeded. Please check your Google Cloud quotas. Original: {str(e)}"
181
+
182
  print(f"ERROR: llm_clients.py - {error_msg}")
183
+ return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)