import logging
import os
import sysconfig
from datetime import datetime
from typing import Any, Dict, List

import geocoder
import torch
from pydantic import Field

from ...utils.registry import registry
from .base import BaseLLM
from .schemas import Content, Message
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Template for the default system message. It contains exactly two positional
# ``{}`` placeholders, filled in order through ``str.format``:
#   1. the current datetime
#   2. the operating system description
BASIC_SYS_PROMPT = """You are an intelligent agent that can help in many regions.
Following are some basic information about your working environment, please try your best to answer the questions based on them if needed.
Be confident about these information and don't let others feel these information are presets.
Be concise.
---BASIC INFORMATION---
Current Datetime: {}
Operating System: {}"""
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @registry.register_llm() | 
					
						
						|  | class Qwen2LLM(BaseLLM): | 
					
						
						|  | model_name: str = Field(default=os.getenv("MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct"), description="The Hugging Face model name") | 
					
						
						|  | max_tokens: int = Field(default=200, description="The maximum number of tokens for the model") | 
					
						
						|  | temperature: float = Field(default=0.1, description="The sampling temperature for generation") | 
					
						
						|  | use_default_sys_prompt: bool = Field(default=True, description="Whether to use the default system prompt") | 
					
						
						|  | device: str = Field(default="cuda" if torch.cuda.is_available() else "cpu", description="The device to run the model on") | 
					
						
						|  | vision: bool = Field(default=False, description="Whether the model supports vision") | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, **data: Any) -> None: | 
					
						
						|  | super().__init__(**data) | 
					
						
						|  | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline | 
					
						
						|  | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) | 
					
						
						|  | self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to(self.device) | 
					
						
						|  |  | 
					
						
						|  | def _call(self, records: List[Message], **kwargs) -> Dict: | 
					
						
						|  | prompts = self._generate_prompt(records) | 
					
						
						|  | text = self.tokenizer.apply_chat_template( | 
					
						
						|  | prompts, | 
					
						
						|  | tokenize=False, | 
					
						
						|  | add_generation_prompt=True | 
					
						
						|  | ) | 
					
						
						|  | model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) | 
					
						
						|  | generated_ids = self.model.generate( | 
					
						
						|  | **model_inputs, | 
					
						
						|  | max_new_tokens=self.max_tokens | 
					
						
						|  | ) | 
					
						
						|  | generated_ids = [ | 
					
						
						|  | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) | 
					
						
						|  | ] | 
					
						
						|  | response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) | 
					
						
						|  | return {"responses": response} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | async def _acall(self, records: List[Message], **kwargs) -> Dict: | 
					
						
						|  | raise NotImplementedError("Async calls are not yet supported for Hugging Face models.") | 
					
						
						|  |  | 
					
						
						|  | def _generate_prompt(self, records: List[Message]) -> List[str]: | 
					
						
						|  | messages = [ | 
					
						
						|  | {"role": "user" if "user" in str(message.role) else "system", "content": self._get_content(message.content)} | 
					
						
						|  | for message in records | 
					
						
						|  | ] | 
					
						
						|  | if self.use_default_sys_prompt: | 
					
						
						|  | messages = [self._generate_default_sys_prompt()] + messages | 
					
						
						|  | print ("messages:",messages) | 
					
						
						|  | return messages | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _generate_default_sys_prompt(self) -> Dict: | 
					
						
						|  | loc = self._get_location() | 
					
						
						|  | os = self._get_linux_distribution() | 
					
						
						|  | current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | 
					
						
						|  | promt_str = BASIC_SYS_PROMPT.format(current_time, loc, os) | 
					
						
						|  | return {"role": "system", "content": promt_str} | 
					
						
						|  |  | 
					
						
						|  | def _get_linux_distribution(self) -> str: | 
					
						
						|  | platform = sysconfig.get_platform() | 
					
						
						|  | if "linux" in platform: | 
					
						
						|  | if os.path.exists("/etc/lsb-release"): | 
					
						
						|  | with open("/etc/lsb-release", "r") as f: | 
					
						
						|  | for line in f: | 
					
						
						|  | if line.startswith("DISTRIB_DESCRIPTION="): | 
					
						
						|  | return line.split("=")[1].strip() | 
					
						
						|  | elif os.path.exists("/etc/os-release"): | 
					
						
						|  | with open("/etc/os-release", "r") as f: | 
					
						
						|  | for line in f: | 
					
						
						|  | if line.startswith("PRETTY_NAME="): | 
					
						
						|  | return line.split("=")[1].strip() | 
					
						
						|  | return platform | 
					
						
						|  |  | 
					
						
						|  | def _get_location(self) -> str: | 
					
						
						|  | g = geocoder.ip("me") | 
					
						
						|  | if g.ok: | 
					
						
						|  | return g.city + "," + g.country | 
					
						
						|  | else: | 
					
						
						|  | return "unknown" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _get_content(content: Content | List[Content]) -> str: | 
					
						
						|  | if isinstance(content, list): | 
					
						
						|  | return " ".join(c.text for c in content if c.type == "text") | 
					
						
						|  | elif isinstance(content, Content) and content.type == "text": | 
					
						
						|  | return content.text | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError("Invalid content type") |