jhansss commited on
Commit
6843cf8
·
1 Parent(s): 4b87b66

Fix Gemini LLM implementation and imports

Browse files
modules/llm/__init__.py CHANGED
@@ -2,10 +2,11 @@ from .base import AbstractLLMModel
2
  from .registry import LLM_MODEL_REGISTRY, get_llm_model, register_llm_model
3
  from .hf_pipeline import HFTextGenerationLLM
4
  from .qwen import QwenLLM
 
5
 
6
  __all__ = [
7
  "AbstractLLMModel",
8
  "get_llm_model",
9
  "register_llm_model",
10
  "LLM_MODEL_REGISTRY",
11
- ]
 
2
  from .registry import LLM_MODEL_REGISTRY, get_llm_model, register_llm_model
3
  from .hf_pipeline import HFTextGenerationLLM
4
  from .qwen import QwenLLM
5
+ from .gemini import GeminiLLM
6
 
7
# Public API of modules.llm: keep this list in sync with the names imported
# above — the concrete model classes were previously imported but not
# re-exported, so `from modules.llm import *` silently dropped them.
__all__ = [
    "AbstractLLMModel",
    "get_llm_model",
    "register_llm_model",
    "LLM_MODEL_REGISTRY",
    "HFTextGenerationLLM",
    "QwenLLM",
    "GeminiLLM",
]
modules/llm/gemini.py CHANGED
@@ -1,23 +1,31 @@
 
 
 
 
1
  import google.generativeai as genai
2
- from abc import ABC, abstractmethod
3
  from .base import AbstractLLMModel
4
  from .registry import register_llm_model
5
 
6
 
7
-
8
- GEMINI_TOKEN = os.getenv("GEMINI_API_KEY")
9
 
10
 
11
  @register_llm_model("gemini-2.5-flash")
12
- class GeminiModel(AbstractLLMModel):
13
  def __init__(
14
  self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
15
  ):
 
 
 
 
 
 
16
  super().__init__(model_id=model_id, **kwargs)
17
- genai.configure(api_key=GEMINI_API_KEY)
18
  self.model = genai.GenerativeModel(model_id)
19
 
20
  def generate(self, prompt: str, **kwargs) -> str:
21
  response = self.model.generate_content(prompt, **kwargs)
22
  return response.text
23
-
 
1
+ import os
2
+
3
+ os.environ["XDG_CACHE_HOME"] = "./.cache" # must be set before importing google.generativeai
4
+
5
  import google.generativeai as genai
6
+
7
  from .base import AbstractLLMModel
8
  from .registry import register_llm_model
9
 
10
 
11
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
 
12
 
13
 
14
  @register_llm_model("gemini-2.5-flash")
15
+ class GeminiLLM(AbstractLLMModel):
16
  def __init__(
17
  self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
18
  ):
19
+ if os.environ.get("XDG_CACHE_HOME") != cache_dir:
20
+ raise RuntimeError(
21
+ f"XDG_CACHE_HOME must be set to '{cache_dir}' before importing this module."
22
+ )
23
+ if not GOOGLE_API_KEY:
24
+ raise ValueError("Please set the GOOGLE_API_KEY environment variable to use Gemini.")
25
  super().__init__(model_id=model_id, **kwargs)
26
+ genai.configure(api_key=GOOGLE_API_KEY)
27
  self.model = genai.GenerativeModel(model_id)
28
 
29
  def generate(self, prompt: str, **kwargs) -> str:
30
  response = self.model.generate_content(prompt, **kwargs)
31
  return response.text
 
tests/test_llm_infer.py CHANGED
@@ -11,6 +11,7 @@ if __name__ == "__main__":
11
  # "deepseek-ai/DeepSeek-R1-0528",
12
  # "openai-community/gpt2-xl",
13
  # "google/gemma-2-2b",
 
14
  ]
15
  for model_id in supported_llms:
16
  try:
 
11
  # "deepseek-ai/DeepSeek-R1-0528",
12
  # "openai-community/gpt2-xl",
13
  # "google/gemma-2-2b",
14
+ # "gemini-2.5-flash",
15
  ]
16
  for model_id in supported_llms:
17
  try: