Spaces:
Sleeping
Sleeping
Update core/llm_clients.py
Browse files- core/llm_clients.py +106 -60
core/llm_clients.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
import os
|
3 |
import google.generativeai as genai
|
4 |
from huggingface_hub import InferenceClient
|
|
|
5 |
|
6 |
# --- Configuration ---
|
7 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
@@ -11,127 +12,172 @@ GEMINI_API_CONFIGURED = False
|
|
11 |
HF_API_CONFIGURED = False
|
12 |
|
13 |
hf_inference_client = None
|
14 |
-
|
15 |
|
16 |
-
# --- Initialization ---
|
17 |
-
def
|
18 |
-
global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client
|
19 |
|
20 |
# Google Gemini
|
21 |
if GOOGLE_API_KEY:
|
22 |
try:
|
23 |
genai.configure(api_key=GOOGLE_API_KEY)
|
24 |
GEMINI_API_CONFIGURED = True
|
25 |
-
print("INFO: llm_clients.py - Google Gemini API configured.")
|
26 |
except Exception as e:
|
|
|
27 |
print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
|
28 |
else:
|
29 |
-
print("WARNING: llm_clients.py - GOOGLE_API_KEY not found.")
|
30 |
|
31 |
# Hugging Face
|
32 |
if HF_TOKEN:
|
33 |
try:
|
34 |
hf_inference_client = InferenceClient(token=HF_TOKEN)
|
35 |
HF_API_CONFIGURED = True
|
36 |
-
print("INFO: llm_clients.py - Hugging Face InferenceClient initialized.")
|
37 |
except Exception as e:
|
|
|
38 |
print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
|
39 |
else:
|
40 |
-
print("WARNING: llm_clients.py - HF_TOKEN not found.")
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
"""
|
49 |
if not GEMINI_API_CONFIGURED:
|
50 |
-
raise ConnectionError("Google Gemini API not configured.")
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
print(f"ERROR: Failed to initialize Gemini model {model_id}: {e}")
|
62 |
-
raise # Re-raise the exception to be caught by the caller
|
63 |
-
return google_gemini_models[instance_key]
|
64 |
-
|
65 |
|
66 |
class LLMResponse:
|
67 |
-
def __init__(self, text=None, error=None, success=True, raw_response=None):
|
68 |
self.text = text
|
69 |
self.error = error
|
70 |
self.success = success
|
71 |
-
self.raw_response = raw_response
|
|
|
72 |
|
73 |
def __str__(self):
|
74 |
if self.success:
|
75 |
return self.text if self.text is not None else ""
|
76 |
-
return f"ERROR: {self.error}"
|
77 |
-
|
78 |
|
79 |
-
def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=
|
80 |
if not HF_API_CONFIGURED or not hf_inference_client:
|
81 |
-
return LLMResponse(error="Hugging Face API not configured.", success=False)
|
82 |
|
83 |
full_prompt = prompt_text
|
84 |
-
|
85 |
-
|
|
|
86 |
|
87 |
try:
|
88 |
-
use_sample = temperature > 0.0
|
89 |
raw_response = hf_inference_client.text_generation(
|
90 |
full_prompt, model=model_id, max_new_tokens=max_new_tokens,
|
91 |
-
temperature=temperature if use_sample else None,
|
92 |
-
do_sample=use_sample,
|
|
|
|
|
93 |
)
|
94 |
-
return LLMResponse(text=raw_response, raw_response=raw_response)
|
95 |
except Exception as e:
|
96 |
-
error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {e}"
|
97 |
print(f"ERROR: llm_clients.py - {error_msg}")
|
98 |
-
return LLMResponse(error=error_msg, success=False, raw_response=e)
|
99 |
|
100 |
-
def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=
|
101 |
if not GEMINI_API_CONFIGURED:
|
102 |
-
return LLMResponse(error="Google Gemini API not configured.", success=False)
|
103 |
|
104 |
try:
|
105 |
-
model_instance =
|
106 |
|
107 |
generation_config = genai.types.GenerationConfig(
|
108 |
temperature=temperature,
|
109 |
max_output_tokens=max_new_tokens
|
|
|
110 |
)
|
|
|
111 |
raw_response = model_instance.generate_content(
|
112 |
prompt_text, # User prompt
|
113 |
generation_config=generation_config,
|
114 |
stream=False
|
|
|
|
|
|
|
|
|
115 |
)
|
116 |
|
117 |
if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
|
118 |
reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
|
119 |
-
error_msg = f"Gemini API:
|
120 |
print(f"WARNING: llm_clients.py - {error_msg}")
|
121 |
-
return LLMResponse(error=error_msg, success=False, raw_response=raw_response)
|
122 |
-
|
123 |
-
if not raw_response.candidates or
|
124 |
-
|
125 |
-
if
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
else:
|
128 |
-
error_msg = f"Gemini API: Empty response or no content. Finish Reason: {finish_reason}"
|
129 |
print(f"WARNING: llm_clients.py - {error_msg}")
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
-
return LLMResponse(text=
|
133 |
|
134 |
except Exception as e:
|
135 |
-
error_msg = f"Gemini API Error ({model_id}): {type(e).__name__} - {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
print(f"ERROR: llm_clients.py - {error_msg}")
|
137 |
-
return LLMResponse(error=error_msg, success=False, raw_response=e)
|
|
|
2 |
import os
|
3 |
import google.generativeai as genai
|
4 |
from huggingface_hub import InferenceClient
|
5 |
+
import time # For potential retries or delays
|
6 |
|
7 |
# --- Configuration ---
|
8 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
|
|
12 |
HF_API_CONFIGURED = False
|
13 |
|
14 |
hf_inference_client = None
|
15 |
+
google_gemini_model_instances = {} # To cache initialized Gemini model instances
|
16 |
|
17 |
+
# --- Initialization Function (to be called from app.py) ---
|
18 |
+
def initialize_all_clients():
|
19 |
+
global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client
|
20 |
|
21 |
# Google Gemini
|
22 |
if GOOGLE_API_KEY:
|
23 |
try:
|
24 |
genai.configure(api_key=GOOGLE_API_KEY)
|
25 |
GEMINI_API_CONFIGURED = True
|
26 |
+
print("INFO: llm_clients.py - Google Gemini API configured successfully.")
|
27 |
except Exception as e:
|
28 |
+
GEMINI_API_CONFIGURED = False # Ensure it's False on error
|
29 |
print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
|
30 |
else:
|
31 |
+
print("WARNING: llm_clients.py - GOOGLE_API_KEY not found in environment variables.")
|
32 |
|
33 |
# Hugging Face
|
34 |
if HF_TOKEN:
|
35 |
try:
|
36 |
hf_inference_client = InferenceClient(token=HF_TOKEN)
|
37 |
HF_API_CONFIGURED = True
|
38 |
+
print("INFO: llm_clients.py - Hugging Face InferenceClient initialized successfully.")
|
39 |
except Exception as e:
|
40 |
+
HF_API_CONFIGURED = False # Ensure it's False on error
|
41 |
print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
|
42 |
else:
|
43 |
+
print("WARNING: llm_clients.py - HF_TOKEN not found in environment variables.")
|
44 |
+
|
45 |
+
def _get_gemini_model_instance(model_id, system_instruction=None):
|
46 |
+
"""
|
47 |
+
Manages Gemini model instances.
|
48 |
+
Gemini's genai.GenerativeModel is fairly lightweight to create,
|
49 |
+
but caching can avoid repeated setup if system_instruction is complex or model loading is slow.
|
50 |
+
For now, creating a new one each time is fine unless performance becomes an issue.
|
51 |
+
"""
|
52 |
if not GEMINI_API_CONFIGURED:
|
53 |
+
raise ConnectionError("Google Gemini API not configured or configuration failed.")
|
54 |
+
try:
|
55 |
+
# For gemini-1.5 models, system_instruction is preferred.
|
56 |
+
# For older gemini-1.0, system instructions might need to be part of the 'contents'.
|
57 |
+
return genai.GenerativeModel(
|
58 |
+
model_name=model_id,
|
59 |
+
system_instruction=system_instruction
|
60 |
+
)
|
61 |
+
except Exception as e:
|
62 |
+
print(f"ERROR: llm_clients.py - Failed to get Gemini model instance for {model_id}: {e}")
|
63 |
+
raise
|
|
|
|
|
|
|
|
|
64 |
|
65 |
class LLMResponse:
|
66 |
+
def __init__(self, text=None, error=None, success=True, raw_response=None, model_id_used="unknown"):
|
67 |
self.text = text
|
68 |
self.error = error
|
69 |
self.success = success
|
70 |
+
self.raw_response = raw_response
|
71 |
+
self.model_id_used = model_id_used
|
72 |
|
73 |
def __str__(self):
|
74 |
if self.success:
|
75 |
return self.text if self.text is not None else ""
|
76 |
+
return f"ERROR (Model: {self.model_id_used}): {self.error}"
|
|
|
77 |
|
78 |
+
def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=512, system_prompt_text=None):
|
79 |
if not HF_API_CONFIGURED or not hf_inference_client:
|
80 |
+
return LLMResponse(error="Hugging Face API not configured (HF_TOKEN missing or client init failed).", success=False, model_id_used=model_id)
|
81 |
|
82 |
full_prompt = prompt_text
|
83 |
+
# Llama-style system prompt formatting; adjust if using other HF model families
|
84 |
+
if system_prompt_text:
|
85 |
+
full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]"
|
86 |
|
87 |
try:
|
88 |
+
use_sample = temperature > 0.001 # API might treat 0 as no sampling
|
89 |
raw_response = hf_inference_client.text_generation(
|
90 |
full_prompt, model=model_id, max_new_tokens=max_new_tokens,
|
91 |
+
temperature=temperature if use_sample else None, # None or omit if not sampling
|
92 |
+
do_sample=use_sample,
|
93 |
+
# top_p=0.9 if use_sample else None, # Optional
|
94 |
+
stream=False
|
95 |
)
|
96 |
+
return LLMResponse(text=raw_response, raw_response=raw_response, model_id_used=model_id)
|
97 |
except Exception as e:
|
98 |
+
error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {str(e)}"
|
99 |
print(f"ERROR: llm_clients.py - {error_msg}")
|
100 |
+
return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
|
101 |
|
102 |
+
def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=768, system_prompt_text=None):
|
103 |
if not GEMINI_API_CONFIGURED:
|
104 |
+
return LLMResponse(error="Google Gemini API not configured (GOOGLE_API_KEY missing or config failed).", success=False, model_id_used=model_id)
|
105 |
|
106 |
try:
|
107 |
+
model_instance = _get_gemini_model_instance(model_id, system_instruction=system_prompt_text)
|
108 |
|
109 |
generation_config = genai.types.GenerationConfig(
|
110 |
temperature=temperature,
|
111 |
max_output_tokens=max_new_tokens
|
112 |
+
# top_p=0.9 # Optional
|
113 |
)
|
114 |
+
# For Gemini, the main prompt goes directly to generate_content if system_instruction is used.
|
115 |
raw_response = model_instance.generate_content(
|
116 |
prompt_text, # User prompt
|
117 |
generation_config=generation_config,
|
118 |
stream=False
|
119 |
+
# safety_settings=[ # Optional: Adjust safety settings if needed, be very careful
|
120 |
+
# {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
121 |
+
# {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
122 |
+
# ]
|
123 |
)
|
124 |
|
125 |
if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
|
126 |
reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
|
127 |
+
error_msg = f"Gemini API: Your prompt was blocked. Reason: {reason}. Try rephrasing."
|
128 |
print(f"WARNING: llm_clients.py - {error_msg}")
|
129 |
+
return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
|
130 |
+
|
131 |
+
if not raw_response.candidates: # No candidates usually means it was blocked or an issue.
|
132 |
+
error_msg = "Gemini API: No candidates returned in response. Possibly blocked or internal error."
|
133 |
+
# Check prompt_feedback again, as it might be populated even if candidates are empty.
|
134 |
+
if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
|
135 |
+
reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
|
136 |
+
error_msg = f"Gemini API: Your prompt was blocked (no candidates). Reason: {reason}. Try rephrasing."
|
137 |
+
print(f"WARNING: llm_clients.py - {error_msg}")
|
138 |
+
return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
|
139 |
+
|
140 |
+
|
141 |
+
# Check the first candidate
|
142 |
+
candidate = raw_response.candidates[0]
|
143 |
+
if not candidate.content or not candidate.content.parts:
|
144 |
+
finish_reason = str(candidate.finish_reason).upper()
|
145 |
+
if finish_reason == "SAFETY":
|
146 |
+
error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}."
|
147 |
+
elif finish_reason == "RECITATION":
|
148 |
+
error_msg = f"Gemini API: Response generation stopped due to recitation policy. Finish Reason: {finish_reason}."
|
149 |
+
elif finish_reason == "MAX_TOKENS":
|
150 |
+
error_msg = f"Gemini API: Response generation stopped due to max tokens. Consider increasing max_new_tokens. Finish Reason: {finish_reason}."
|
151 |
+
# In this case, there might still be partial text.
|
152 |
+
# For simplicity, we'll treat it as an incomplete generation here.
|
153 |
+
# You could choose to return partial text if desired.
|
154 |
+
# return LLMResponse(text="[PARTIAL RESPONSE - MAX TOKENS REACHED]", ..., model_id_used=model_id)
|
155 |
else:
|
156 |
+
error_msg = f"Gemini API: Empty response or no content parts. Finish Reason: {finish_reason}."
|
157 |
print(f"WARNING: llm_clients.py - {error_msg}")
|
158 |
+
# Try to get text even if finish_reason is not 'STOP' but not ideal
|
159 |
+
# This part might need refinement based on how you want to handle partial/stopped responses
|
160 |
+
partial_text = ""
|
161 |
+
if candidate.content and candidate.content.parts and candidate.content.parts[0].text:
|
162 |
+
partial_text = candidate.content.parts[0].text
|
163 |
+
if partial_text:
|
164 |
+
return LLMResponse(text=partial_text + f"\n[Note: Generation stopped due to {finish_reason}]", raw_response=raw_response, model_id_used=model_id)
|
165 |
+
else:
|
166 |
+
return LLMResponse(error=error_msg, success=False, raw_response=raw_response, model_id_used=model_id)
|
167 |
|
168 |
+
return LLMResponse(text=candidate.content.parts[0].text, raw_response=raw_response, model_id_used=model_id)
|
169 |
|
170 |
except Exception as e:
|
171 |
+
error_msg = f"Gemini API Call Error ({model_id}): {type(e).__name__} - {str(e)}"
|
172 |
+
# More specific error messages based on common Google API errors
|
173 |
+
if "API key not valid" in str(e) or "PERMISSION_DENIED" in str(e):
|
174 |
+
error_msg = f"Gemini API Error ({model_id}): API key invalid or permission denied. Check GOOGLE_API_KEY and ensure Gemini API is enabled. Original: {str(e)}"
|
175 |
+
elif "Could not find model" in str(e) or "ील नहीं मिला" in str(e): # Hindi for "model not found"
|
176 |
+
error_msg = f"Gemini API Error ({model_id}): Model ID '{model_id}' not found or inaccessible with your key. Original: {str(e)}"
|
177 |
+
elif "User location is not supported" in str(e):
|
178 |
+
error_msg = f"Gemini API Error ({model_id}): User location not supported for this model/API. Original: {str(e)}"
|
179 |
+
elif "Quota exceeded" in str(e):
|
180 |
+
error_msg = f"Gemini API Error ({model_id}): API quota exceeded. Please check your Google Cloud quotas. Original: {str(e)}"
|
181 |
+
|
182 |
print(f"ERROR: llm_clients.py - {error_msg}")
|
183 |
+
return LLMResponse(error=error_msg, success=False, raw_response=e, model_id_used=model_id)
|