nyasukun committed on
Commit
0a54f6b
·
1 Parent(s): 9ace18a
Files changed (1) hide show
  1. app.py +59 -3
app.py CHANGED
@@ -64,11 +64,46 @@ class LocalModelManager:
64
  self.models = {}
65
  self.tokenizers = {}
66
  self.pipelines = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  async def load_model(self, model_path: str, task: str = "text-generation"):
69
- """モデルの遅延ロード"""
70
- if model_path not in self.models:
71
- logger.info(f"Loading model: {model_path}")
72
  try:
73
  self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
74
 
@@ -140,6 +175,7 @@ class ModelManager:
140
  self.api_clients = {}
141
  self.local_manager = LocalModelManager()
142
  self._initialize_clients()
 
143
 
144
  def _initialize_clients(self):
145
  """Inference APIクライアントの初期化"""
@@ -150,6 +186,26 @@ class ModelManager:
150
  token=True # これによりHFトークンを使用
151
  )
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
154
  """テキスト生成モデルの実行"""
155
  results = []
 
64
  self.models = {}
65
  self.tokenizers = {}
66
  self.pipelines = {}
67
+
68
def preload_models(self, model_paths, tasks=None):
    """Eagerly load models at application startup.

    Args:
        model_paths: iterable of Hugging Face model paths to load.
        tasks: optional mapping of model_path -> task name; any path not
            listed defaults to "text-generation".

    Errors are logged and the loop continues, so one bad model does not
    prevent the remaining models from loading.
    """
    if tasks is None:
        tasks = {}  # default: every model falls back to text-generation

    logger.info("Preloading models at application startup...")
    for model_path in model_paths:
        # Skip models that are already loaded so repeated calls are idempotent
        # and we never re-download an already-resident pipeline.
        if model_path in self.pipelines:
            continue
        task = tasks.get(model_path, "text-generation")
        try:
            logger.info(f"Preloading model: {model_path} for task: {task}")
            self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)

            # Any task other than "text-generation" is treated as classification,
            # mirroring the two model categories this app serves. The pipeline
            # construction is otherwise identical, so build it once.
            pipeline_task = (
                "text-generation" if task == "text-generation" else "text-classification"
            )
            self.pipelines[model_path] = pipeline(
                pipeline_task,
                model=model_path,
                tokenizer=self.tokenizers[model_path],
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
            )
            logger.info(f"Model preloaded successfully: {model_path}")
        except Exception as e:
            # Log and continue: a single failed model must not abort startup.
            logger.error(f"Error preloading model {model_path}: {str(e)}")
102
 
103
  async def load_model(self, model_path: str, task: str = "text-generation"):
104
+ """モデルが既にロードされているか確認し、なければロード"""
105
+ if model_path not in self.pipelines:
106
+ logger.info(f"Loading model on demand: {model_path}")
107
  try:
108
  self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
109
 
 
175
  self.api_clients = {}
176
  self.local_manager = LocalModelManager()
177
  self._initialize_clients()
178
+ self._preload_local_models()
179
 
180
  def _initialize_clients(self):
181
  """Inference APIクライアントの初期化"""
 
186
  token=True # これによりHFトークンを使用
187
  )
188
 
189
def _preload_local_models(self):
    """Gather every LOCAL model path (generation first, then classification)
    and hand the batch to the local manager for eager loading."""
    paths_to_load = []
    task_by_path = {}

    # Walk both model groups with their pipeline task names; only models
    # marked LOCAL with a non-empty model_path are preloaded.
    model_groups = (
        (TEXT_GENERATION_MODELS, "text-generation"),
        (CLASSIFICATION_MODELS, "text-classification"),
    )
    for model_list, task_name in model_groups:
        for model in model_list:
            if model.type == ModelType.LOCAL and model.model_path:
                paths_to_load.append(model.model_path)
                task_by_path[model.model_path] = task_name

    # Kick off the actual loading in the shared local manager.
    self.local_manager.preload_models(paths_to_load, task_by_path)
208
+
209
  async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
210
  """テキスト生成モデルの実行"""
211
  results = []