jhansss committed
Commit 7974242 · 1 Parent(s): ced727c
modules/llm/__init__.py CHANGED
@@ -1,7 +1,7 @@
  from .base import AbstractLLMModel
  from .registry import LLM_MODEL_REGISTRY, get_llm_model, register_llm_model
  from .hf_pipeline import HFTextGenerationLLM
- from .qwen import QwenLLM
+ from .qwen3 import Qwen3LLM
  from .gemini import GeminiLLM

  __all__ = [
modules/llm/qwen3.py ADDED
@@ -0,0 +1,37 @@
+ from .base import AbstractLLMModel
+ from .registry import register_llm_model
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+ @register_llm_model("Qwen/Qwen3-")
+ class Qwen3LLM(AbstractLLMModel):
+     def __init__(
+         self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
+     ):
+         super().__init__(model_id, device, cache_dir, **kwargs)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_id, device_map=device, torch_dtype="auto", cache_dir=cache_dir
+         ).eval()
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
+
+     def generate(self, prompt: str, enable_thinking: bool = True, max_new_tokens: int = 32768, **kwargs) -> str:
+         messages = [{"role": "user", "content": prompt}]
+         text = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True,
+             enable_thinking=enable_thinking,
+         )
+         model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+         generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
+         output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
+         # parse thinking content
+         if enable_thinking:
+             try:
+                 # rindex finding 151668 (</think>)
+                 index = len(output_ids) - output_ids[::-1].index(151668)
+             except ValueError:
+                 index = 0
+             output_ids = output_ids[index:]
+
+         return self.tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
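For context, a minimal usage sketch of the new class through the registry. It assumes `get_llm_model` matches a model id against the registered "Qwen/Qwen3-" prefix and forwards keyword arguments to `Qwen3LLM.__init__`; the checkpoint "Qwen/Qwen3-8B" is only an illustrative choice, not pinned by this commit.

```python
# Hypothetical usage sketch, not part of this commit.
# Assumes get_llm_model resolves the "Qwen/Qwen3-" prefix and passes
# device/cache_dir through to Qwen3LLM; "Qwen/Qwen3-8B" is illustrative.
from modules.llm import get_llm_model

llm = get_llm_model("Qwen/Qwen3-8B", device="auto", cache_dir="cache")
# enable_thinking=False disables the <think> block, so only the final answer is returned
print(llm.generate("Give a one-sentence summary of transformers.", enable_thinking=False))
```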
tests/test_llm_infer.py CHANGED
@@ -3,8 +3,8 @@ from modules.llm import get_llm_model
  if __name__ == "__main__":
      supported_llms = [
          # "MiniMaxAI/MiniMax-M1-80k", # -> load with custom code
-         # "Qwen/Qwen-1_8B",
-         # "meta-llama/Llama-3.1-8B-Instruct", # pending for approval
+         # "Qwen/Qwen3-8B",
+         # "meta-llama/Llama-3.1-8B-Instruct",
          # "tiiuae/Falcon-H1-1B-Base",
          # "tiiuae/Falcon-H1-3B-Instruct",
          # "tencent/Hunyuan-A13B-Instruct", # -> load with custom code