# algoforge_prime/core/llm_clients.py
import os
import google.generativeai as genai
from huggingface_hub import InferenceClient

# --- Configuration ---
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

GEMINI_API_CONFIGURED = False
HF_API_CONFIGURED = False

hf_inference_client = None
google_gemini_models = {} # To store initialized Gemini model instances

# --- Initialization ---
def initialize_clients():
    global GEMINI_API_CONFIGURED, HF_API_CONFIGURED, hf_inference_client, google_gemini_models
    
    # Google Gemini
    if GOOGLE_API_KEY:
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            GEMINI_API_CONFIGURED = True
            print("INFO: llm_clients.py - Google Gemini API configured.")
        except Exception as e:
            print(f"ERROR: llm_clients.py - Failed to configure Google Gemini API: {e}")
    else:
        print("WARNING: llm_clients.py - GOOGLE_API_KEY not found.")

    # Hugging Face
    if HF_TOKEN:
        try:
            hf_inference_client = InferenceClient(token=HF_TOKEN)
            HF_API_CONFIGURED = True
            print("INFO: llm_clients.py - Hugging Face InferenceClient initialized.")
        except Exception as e:
            print(f"ERROR: llm_clients.py - Failed to initialize Hugging Face InferenceClient: {e}")
    else:
        print("WARNING: llm_clients.py - HF_TOKEN not found.")

# NOTE: initialize_clients() is deliberately not called at import time.
# Gradio apps can re-import modules on reload, which would re-run any
# import-time side effects; instead, the app entry point (app.py) should
# call initialize_clients() once, explicitly, before any LLM call is made.
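# Illustrative call site (a sketch only; the surrounding app.py code is an
# assumption, not part of this module):
#
#     from core.llm_clients import initialize_clients
#     initialize_clients()  # configure Gemini + HF once, before building the UI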

def get_gemini_model_instance(model_id, system_instruction=None):
    """Gets or creates a Gemini model instance."""
    if not GEMINI_API_CONFIGURED:
        raise ConnectionError("Google Gemini API not configured.")
    
    # Key on the model id plus the system instruction itself, so that two
    # different system instructions for the same model map to distinct cached instances.
    instance_key = f"{model_id}::{hash(system_instruction) if system_instruction else 'no_sys'}"
    if instance_key not in google_gemini_models:
        try:
            google_gemini_models[instance_key] = genai.GenerativeModel(
                model_name=model_id,
                system_instruction=system_instruction
            )
            print(f"INFO: Initialized Gemini Model Instance: {instance_key}")
        except Exception as e:
            print(f"ERROR: Failed to initialize Gemini model {model_id}: {e}")
            raise  # Re-raise the exception to be caught by the caller
    return google_gemini_models[instance_key]


class LLMResponse:
    """Uniform result wrapper for LLM calls; success is False whenever an error is set."""
    def __init__(self, text=None, error=None, success=True, raw_response=None):
        self.text = text
        self.error = error
        self.success = success and error is None  # an error always implies failure
        self.raw_response = raw_response  # original API response, kept for debugging

    def __str__(self):
        if self.success:
            return self.text if self.text is not None else ""
        return f"ERROR: {self.error}"


def call_huggingface_api(prompt_text, model_id, temperature=0.7, max_new_tokens=350, system_prompt_text=None):
    if not HF_API_CONFIGURED or not hf_inference_client:
        return LLMResponse(error="Hugging Face API not configured.", success=False)

    full_prompt = prompt_text
    if system_prompt_text:
        # Llama-2-style instruct template as a simple default; the correct
        # chat formatting ultimately depends on the target model.
        full_prompt = f"<s>[INST] <<SYS>>\n{system_prompt_text}\n<</SYS>>\n\n{prompt_text} [/INST]"

    try:
        use_sample = temperature > 0.0
        raw_response = hf_inference_client.text_generation(
            full_prompt, model=model_id, max_new_tokens=max_new_tokens,
            temperature=temperature if use_sample else None,
            do_sample=use_sample, stream=False
        )
        return LLMResponse(text=raw_response, raw_response=raw_response)
    except Exception as e:
        error_msg = f"HF API Error ({model_id}): {type(e).__name__} - {e}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e)

def call_gemini_api(prompt_text, model_id, temperature=0.7, max_new_tokens=400, system_prompt_text=None):
    if not GEMINI_API_CONFIGURED:
        return LLMResponse(error="Google Gemini API not configured.", success=False)

    try:
        model_instance = get_gemini_model_instance(model_id, system_instruction=system_prompt_text)
        
        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_new_tokens
        )
        raw_response = model_instance.generate_content(
            prompt_text, # User prompt
            generation_config=generation_config,
            stream=False
        )

        if raw_response.prompt_feedback and raw_response.prompt_feedback.block_reason:
            reason = raw_response.prompt_feedback.block_reason_message or raw_response.prompt_feedback.block_reason
            error_msg = f"Gemini API: Prompt blocked due to safety. Reason: {reason}"
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response)

        first_candidate = raw_response.candidates[0] if raw_response.candidates else None
        if not (first_candidate and first_candidate.content and first_candidate.content.parts):
            finish_reason = first_candidate.finish_reason if first_candidate else "Unknown"
            # finish_reason is an enum; match by name so e.g. "FinishReason.SAFETY" is caught
            if "SAFETY" in str(finish_reason).upper():
                error_msg = f"Gemini API: Response generation stopped by safety filters. Finish Reason: {finish_reason}"
            else:
                error_msg = f"Gemini API: Empty response or no content. Finish Reason: {finish_reason}"
            print(f"WARNING: llm_clients.py - {error_msg}")
            return LLMResponse(error=error_msg, success=False, raw_response=raw_response)
        
        return LLMResponse(text=raw_response.candidates[0].content.parts[0].text, raw_response=raw_response)

    except Exception as e:
        error_msg = f"Gemini API Error ({model_id}): {type(e).__name__} - {e}"
        print(f"ERROR: llm_clients.py - {error_msg}")
        return LLMResponse(error=error_msg, success=False, raw_response=e)
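
# --- Minimal smoke test (a sketch, not part of the app's runtime path).
# The model ids below are assumptions for illustration; substitute models
# that your GOOGLE_API_KEY / HF_TOKEN actually have access to.
if __name__ == "__main__":
    initialize_clients()
    if GEMINI_API_CONFIGURED:
        print(call_gemini_api("Reply with one word: hello", "gemini-1.5-flash", max_new_tokens=10))
    if HF_API_CONFIGURED:
        print(call_huggingface_api("Reply with one word: hello", "mistralai/Mistral-7B-Instruct-v0.2", max_new_tokens=10))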