Han Jionghao commited on
Commit
bf79c76
·
unverified ·
2 Parent(s): 0e589bd 6843cf8

Merge pull request #2 from Masao-Someki/feature/add_gemini

Browse files
modules/llm/__init__.py CHANGED
@@ -2,10 +2,11 @@ from .base import AbstractLLMModel
2
  from .registry import LLM_MODEL_REGISTRY, get_llm_model, register_llm_model
3
  from .hf_pipeline import HFTextGenerationLLM
4
  from .qwen import QwenLLM
 
5
 
6
  __all__ = [
7
  "AbstractLLMModel",
8
  "get_llm_model",
9
  "register_llm_model",
10
  "LLM_MODEL_REGISTRY",
11
- ]
 
2
  from .registry import LLM_MODEL_REGISTRY, get_llm_model, register_llm_model
3
  from .hf_pipeline import HFTextGenerationLLM
4
  from .qwen import QwenLLM
5
+ from .gemini import GeminiLLM
6
 
7
  __all__ = [
8
  "AbstractLLMModel",
9
  "get_llm_model",
10
  "register_llm_model",
11
  "LLM_MODEL_REGISTRY",
12
+ ]
modules/llm/gemini.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["XDG_CACHE_HOME"] = "./.cache" # must be set before importing google.generativeai
4
+
5
+ import google.generativeai as genai
6
+
7
+ from .base import AbstractLLMModel
8
+ from .registry import register_llm_model
9
+
10
+
11
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
12
+
13
+
14
+ @register_llm_model("gemini-2.5-flash")
15
+ class GeminiLLM(AbstractLLMModel):
16
+ def __init__(
17
+ self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
18
+ ):
19
+ if os.environ.get("XDG_CACHE_HOME") != cache_dir:
20
+ raise RuntimeError(
21
+ f"XDG_CACHE_HOME must be set to '{cache_dir}' before importing this module."
22
+ )
23
+ if not GOOGLE_API_KEY:
24
+ raise ValueError("Please set the GOOGLE_API_KEY environment variable to use Gemini.")
25
+ super().__init__(model_id=model_id, **kwargs)
26
+ genai.configure(api_key=GOOGLE_API_KEY)
27
+ self.model = genai.GenerativeModel(model_id)
28
+
29
+ def generate(self, prompt: str, **kwargs) -> str:
30
+ response = self.model.generate_content(prompt, **kwargs)
31
+ return response.text
requirements.txt CHANGED
@@ -12,9 +12,9 @@ pykakasi
12
  basic-pitch[onnx]
13
  audiobox_aesthetics
14
  transformers
15
- s3prl
16
  zhconv
17
  git+https://github.com/sea-turt1e/kanjiconv
18
  soundfile
19
  PyYAML
20
  gradio
 
 
12
  basic-pitch[onnx]
13
  audiobox_aesthetics
14
  transformers
 
15
  zhconv
16
  git+https://github.com/sea-turt1e/kanjiconv
17
  soundfile
18
  PyYAML
19
  gradio
20
+ google-generativeai
tests/test_llm_infer.py CHANGED
@@ -11,6 +11,7 @@ if __name__ == "__main__":
11
  # "deepseek-ai/DeepSeek-R1-0528",
12
  # "openai-community/gpt2-xl",
13
  # "google/gemma-2-2b",
 
14
  ]
15
  for model_id in supported_llms:
16
  try:
 
11
  # "deepseek-ai/DeepSeek-R1-0528",
12
  # "openai-community/gpt2-xl",
13
  # "google/gemma-2-2b",
14
+ # "gemini-2.5-flash",
15
  ]
16
  for model_id in supported_llms:
17
  try: