# algoforge_prime/core/llm_clients.py
import os

import google.generativeai as genai
from huggingface_hub import InferenceClient

# --- Configuration ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

GEMINI_API_CONFIGURED = False
HF_API_CONFIGURED = False
hf_inference_client = None
google_gemini_models = {}  # To store initialized Gemini model instances

# --- Initialization ---
def initialize_clients():
    global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client, google_gemini_models

    # Google Gemini
    if GOOGLE_API_KEY:
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            GEMINI_API_CONFIGURED = True
            print("INFO: llm_clients.py - Google Gemini API configured.")
        except Exception as e:
            print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
    else:
        print("WARNING: llm_clients.py - GOOGLE_API_KEY not found.")

    # Hugging Face
    if HF_TOKEN:
        try:
            hf_inference_client = InferenceClient(token=HF_TOKEN)
            HF_API_CONFIGURED = True
            print("INFO: llm_clients.py - Hugging Face InferenceClient initialized.")
        except Exception as e:
            print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
    else:
        print("WARNING: llm_clients.py - HF_TOKEN not found.")

# initialize_clients() could be called here, on first import. For Gradio apps
# that may reload modules, it is safer to call it explicitly once from app.py's
# main scope, so client setup runs exactly once.
# initialize_clients()  # Or call this from app.py
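# A minimal sketch of that explicit call from app.py (hypothetical entry point;
# `demo` stands in for whatever Gradio Blocks/Interface object the app defines):
#
#   from core.llm_clients import initialize_clients
#
#   initialize_clients()  # configure Gemini / HF clients once at startup
#   demo.launch()
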
def get_gemini_model_instance(model_id, system_instruction=None):
    """Gets or creates (and caches) a Gemini model instance."""
    if not GEMINI_API_CONFIGURED:
        raise ConnectionError("Google Gemini API not configured.")
    # Key on the system instruction as well as the model id, so two callers
    # using the same model with different system prompts don't share an instance.
    instance_key = f"{model_id}::{hash(system_instruction) if system_instruction else 'nosys'}"
    if instance_key not in google_gemini_models:
        try:
            google_gemini_models[instance_key] = genai.GenerativeModel(
                model_name=model_id,
                system_instruction=system_instruction
            )
            print(f"INFO: Initialized Gemini Model Instance: {instance_key}")
        except Exception as e:
            print(f"ERROR: Failed to initialize Gemini model {model_id}: {e}")
            raise  # Re-raise the exception to be caught by the caller
    return google_gemini_models[instance_key]

class LLMResponse:
    def __init__(self, text=None, error=None, success=True, raw_response=None):
        self.text = text
        self.error = error
        self.success = success
        self.raw_response = raw_response  # Store original API response if needed

    def __str__(self):
        if self.success:
            return self.text if self.text is not None else ""
        return f"ERROR: {self.error}"
def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=350, system_prompt_text=None):
    if not HF_API_CONFIGURED or not hf_inference_client:
        return LLMResponse(error="Hugging Face API not configured.", success=False)
    full_prompt = prompt_text
    if system_prompt_text:  # Simple prepend; the exact template depends on the model
        full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]"  # Llama-style
    try:
        use_sample = temperature > 0.0
        raw_response = hf_inference_client.text_generation(
            full_prompt, model=model_id, max_new_tokens=max_new_tokens,
            temperature=temperature if use_sample else None,
            do_sample=use_sample, stream=False
        )
        return LLMResponse(text=raw_response, raw_response=raw_response)
    except Exception as e:
        error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {e}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e)
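# Example call (the model id is illustrative; any HF text-generation model the
# token can access works):
#
#   resp = call_huggingface_api(
#       "Write a haiku about recursion.",
#       model_id="mistralai/Mistral-7B-Instruct-v0.2",
#       system_prompt_text="You are a terse poet.",
#   )
#   print(resp)
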
def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=400, system_prompt_text=None):
    if not GEMINI_API_CONFIGURED:
        return LLMResponse(error="Google Gemini API not configured.", success=False)
    try:
        model_instance = get_gemini_model_instance(model_id, system_instruction=system_prompt_text)
        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_new_tokens
        )
        raw_response = model_instance.generate_content(
            prompt_text,  # User prompt
            generation_config=generation_config,
            stream=False
        )
        if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
            reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
            error_msg = f"Gemini API: Prompt blocked due to safety. Reason: {reason}"
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response)
        if not raw_response.candidates or not raw_response.candidates[0].content.parts:
            finish_reason = raw_response.candidates[0].finish_reason if raw_response.candidates else "Unknown"
            if str(finish_reason).upper() == "SAFETY":
                error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}"
            else:
                error_msg = f"Gemini API: Empty response or no content. Finish Reason: {finish_reason}"
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response)
        return LLMResponse(text=raw_response.candidates[0].content.parts[0].text, raw_response=raw_response)
    except Exception as e:
        error_msg = f"Gemini API Error ({model_id}): {type(e).__name__} - {e}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e)
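
# Optional smoke test: a minimal sketch for running this module directly. It
# assumes GOOGLE_API_KEY and/or HF_TOKEN are set in the environment; the model
# ids are illustrative, not prescriptive.
if __name__ == "__main__":
    initialize_clients()
    if GEMINI_API_CONFIGURED:
        print(call_gemini_api("Say hello in one sentence.", "gemini-1.5-flash-latest"))
    if HF_API_CONFIGURED:
        print(call_huggingface_api("Say hello in one sentence.", "mistralai/Mistral-7B-Instruct-v0.2"))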