app.py
CHANGED
@@ -64,11 +64,46 @@ class LocalModelManager:
|
|
64 |
self.models = {}
|
65 |
self.tokenizers = {}
|
66 |
self.pipelines = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
async def load_model(self, model_path: str, task: str = "text-generation"):
|
69 |
-
"""
|
70 |
-
if model_path not in self.
|
71 |
-
logger.info(f"Loading model: {model_path}")
|
72 |
try:
|
73 |
self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
|
74 |
|
@@ -140,6 +175,7 @@ class ModelManager:
|
|
140 |
self.api_clients = {}
|
141 |
self.local_manager = LocalModelManager()
|
142 |
self._initialize_clients()
|
|
|
143 |
|
144 |
def _initialize_clients(self):
|
145 |
"""Inference APIクライアントの初期化"""
|
@@ -150,6 +186,26 @@ class ModelManager:
|
|
150 |
token=True # これによりHFトークンを使用
|
151 |
)
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
|
154 |
"""テキスト生成モデルの実行"""
|
155 |
results = []
|
|
|
64 |
self.models = {}
|
65 |
self.tokenizers = {}
|
66 |
self.pipelines = {}
|
67 |
+
|
68 |
+
def preload_models(self, model_paths, tasks=None):
    """Preload models at application startup.

    Args:
        model_paths: Iterable of Hugging Face model paths/ids to load.
        tasks: Optional mapping of model_path -> task name. Any path not
            present in the mapping defaults to "text-generation".

    For each path this loads a tokenizer into ``self.tokenizers`` and a
    pipeline into ``self.pipelines``. A failure for one model is logged
    and the remaining models are still attempted.
    """
    if tasks is None:
        tasks = {}  # no explicit task mapping; everything falls back to the default below

    logger.info("Preloading models at application startup...")
    for model_path in model_paths:
        task = tasks.get(model_path, "text-generation")
        try:
            logger.info(f"Preloading model: {model_path} for task: {task}")
            self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)

            # The original code had two branches whose pipeline kwargs were
            # identical except for the task string; any task other than
            # "text-generation" was (and still is) treated as classification.
            pipeline_task = (
                "text-generation" if task == "text-generation" else "text-classification"
            )
            self.pipelines[model_path] = pipeline(
                pipeline_task,
                model=model_path,
                tokenizer=self.tokenizers[model_path],
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
            )
            logger.info(f"Model preloaded successfully: {model_path}")
        except Exception as e:
            logger.error(f"Error preloading model {model_path}: {str(e)}")
            # Continue with the next model; the error is recorded in the log.
102 |
|
103 |
async def load_model(self, model_path: str, task: str = "text-generation"):
|
104 |
+
"""モデルが既にロードされているか確認し、なければロード"""
|
105 |
+
if model_path not in self.pipelines:
|
106 |
+
logger.info(f"Loading model on demand: {model_path}")
|
107 |
try:
|
108 |
self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
|
109 |
|
|
|
175 |
self.api_clients = {}
|
176 |
self.local_manager = LocalModelManager()
|
177 |
self._initialize_clients()
|
178 |
+
self._preload_local_models()
|
179 |
|
180 |
def _initialize_clients(self):
|
181 |
"""Inference APIクライアントの初期化"""
|
|
|
186 |
token=True # これによりHFトークンを使用
|
187 |
)
|
188 |
|
189 |
+
def _preload_local_models(self):
    """Collect every LOCAL model from the registries and preload it.

    Scans TEXT_GENERATION_MODELS and CLASSIFICATION_MODELS for entries of
    type ModelType.LOCAL that carry a model_path, builds the path list and
    the path -> task mapping, and hands both to the local manager.
    """
    to_load = []
    task_map = {}

    # (registry, pipeline task) pairs — generation first, then classification,
    # preserving the original load order.
    registries = (
        (TEXT_GENERATION_MODELS, "text-generation"),
        (CLASSIFICATION_MODELS, "text-classification"),
    )
    for registry, task_name in registries:
        for model in registry:
            if model.type == ModelType.LOCAL and model.model_path:
                to_load.append(model.model_path)
                task_map[model.model_path] = task_name

    # Kick off the actual preload.
    self.local_manager.preload_models(to_load, task_map)
|
208 |
+
|
209 |
async def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
|
210 |
"""テキスト生成モデルの実行"""
|
211 |
results = []
|