nyasukun committed
Commit 6f6b422 · 1 Parent(s): b711b66
Files changed (1):
  1. app.py +44 -22
app.py CHANGED
@@ -6,6 +6,7 @@ from enum import Enum, auto
 import torch
 from transformers import AutoTokenizer, pipeline
 import spaces
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 # Logger setup
 logging.basicConfig(
@@ -114,12 +115,15 @@ def generate_text_local(model_path, text):
     """Text generation with a local model"""
     try:
         logger.info(f"Running local text generation with {model_path}")
-        outputs = pipelines[model_path](
+        pipeline = pipelines[model_path]
+        pipeline.to("cuda")
+        outputs = pipeline(
             text,
             max_new_tokens=40,
             do_sample=False,
             num_return_sequences=1
         )
+        pipeline.to("cpu")
         return outputs[0]["generated_text"]
     except Exception as e:
         logger.error(f"Error in local text generation with {model_path}: {str(e)}")
@@ -144,7 +148,10 @@ def classify_text_local(model_path, text):
     """Text classification with a local model"""
     try:
         logger.info(f"Running local classification with {model_path}")
-        result = pipelines[model_path](text)
+        pipeline = pipelines[model_path]
+        pipeline.to("cuda")
+        result = pipeline(text)
+        pipeline.to("cpu")
         return str(result)
     except Exception as e:
         logger.error(f"Error in local classification with {model_path}: {str(e)}")
@@ -164,28 +171,43 @@ def classify_text_api(model_id, text):
 def handle_invoke(text, selected_types):
     """Handler for the Invoke button"""
     results = []
+    futures_to_model = {}  # maps each future back to its model
 
-    # Run the text generation models
-    for model in TEXT_GENERATION_MODELS:
-        if model["type"] in selected_types:
-            if model["type"] == LOCAL:
-                result = generate_text_local(model["model_path"], text)
-            else:  # api
-                result = generate_text_api(model["model_id"], text)
-            results.append(f"{model['name']}: {result}")
-
-    # Run the classification models
-    for model in CLASSIFICATION_MODELS:
-        if model["type"] in selected_types:
-            if model["type"] == LOCAL:
-                result = classify_text_local(model["model_path"], text)
-            else:  # api
-                result = classify_text_api(model["model_id"], text)
-            results.append(f"{model['name']}: {result}")
-
-    # Pad the results list to the expected length
-    while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
-        results.append("")
+    with ThreadPoolExecutor(max_workers=len(selected_types)) as executor:
+        futures = []
+
+        # Run the text generation models
+        for model in TEXT_GENERATION_MODELS:
+            if model["type"] in selected_types:
+                if model["type"] == LOCAL:
+                    future = executor.submit(generate_text_local, model["model_path"], text)
+                    futures.append(future)
+                    futures_to_model[future] = model
+                else:  # api
+                    future = executor.submit(generate_text_api, model["model_id"], text)
+                    futures.append(future)
+                    futures_to_model[future] = model
+
+        # Run the classification models
+        for model in CLASSIFICATION_MODELS:
+            if model["type"] in selected_types:
+                if model["type"] == LOCAL:
+                    future = executor.submit(classify_text_local, model["model_path"], text)
+                    futures.append(future)
+                    futures_to_model[future] = model
+                else:  # api
+                    future = executor.submit(classify_text_api, model["model_id"], text)
+                    futures.append(future)
+                    futures_to_model[future] = model
+
+        # Collect the results (preserving model order)
+        all_models = TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS
+        results = [""] * len(all_models)  # pre-initialize the results list
+
+        for future in as_completed(futures):
+            model = futures_to_model[future]
+            model_index = all_models.index(model)
+            results[model_index] = future.result()
 
     return results
 