nyasukun committed on
Commit 0da498f · 1 Parent(s): 9975206
Files changed (1)
  1. app.py +225 -332
app.py CHANGED
--- a/app.py
@@ -1,11 +1,8 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
-from typing import List, Dict, Optional, Union
 import logging
-from dataclasses import dataclass
-from enum import Enum, auto
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline
 import spaces
 
 # Logger setup
@@ -15,99 +12,67 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-# Model type definitions
-class ModelType(Enum):
-    LOCAL = "local"
-    INFERENCE_API = "inference_api"
-
-@dataclass
-class ModelConfig:
-    name: str
-    description: str
-    type: ModelType
-    model_id: Optional[str] = None
-    model_path: Optional[str] = None
-
-# Expanded model definitions
 TEXT_GENERATION_MODELS = [
-    ModelConfig(
-        name="Zephyr-7B",
-        description="Specialized in understanding context and nuance",
-        type=ModelType.INFERENCE_API,
-        model_id="HuggingFaceH4/zephyr-7b-beta"
-    ),
-    ModelConfig(
-        name="Llama-2",
-        description="Known for its robust performance in content analysis",
-        type=ModelType.LOCAL,
-        model_path="meta-llama/Llama-2-7b-hf"
-    ),
-    ModelConfig(
-        name="Mistral-7B",
-        description="Offers precise and detailed text evaluation",
-        type=ModelType.LOCAL,
-        model_path="mistralai/Mistral-7B-v0.1"
-    )
 ]
 
 CLASSIFICATION_MODELS = [
-    ModelConfig(
-        name="Toxic-BERT",
-        description="Fine-tuned for toxic content detection",
-        type=ModelType.LOCAL,
-        model_path="unitary/toxic-bert"
-    )
 ]
 
-class LocalModelManager:
     def __init__(self):
-        self.models = {}
         self.tokenizers = {}
         self.pipelines = {}
-
-    def preload_models(self, model_paths, tasks=None):
-        """Preload models at application startup"""
-        if tasks is None:
-            tasks = {}  # default to an empty dict
-
-        logger.info("Preloading models at application startup...")
-        for model_path in model_paths:
-            task = tasks.get(model_path, "text-generation")
-            try:
-                logger.info(f"Preloading model: {model_path} for task: {task}")
-                self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
-
-                if task == "text-generation":
-                    self.pipelines[model_path] = pipeline(
-                        "text-generation",
-                        model=model_path,
-                        tokenizer=self.tokenizers[model_path],
-                        torch_dtype=torch.bfloat16,
-                        trust_remote_code=True,
-                        device_map="auto"
-                    )
-                else:  # classification
-                    self.pipelines[model_path] = pipeline(
-                        "text-classification",
-                        model=model_path,
-                        tokenizer=self.tokenizers[model_path],
-                        torch_dtype=torch.bfloat16,
-                        trust_remote_code=True,
-                        device_map="auto"
-                    )
-                logger.info(f"Model preloaded successfully: {model_path}")
-            except Exception as e:
-                logger.error(f"Error preloading model {model_path}: {str(e)}")
-                # continue, but log the error
 
-    def load_model(self, model_path: str, task: str = "text-generation"):
-        """Check whether the model is already loaded; load it if not"""
-        if model_path not in self.pipelines:
-            logger.info(f"Loading model on demand: {model_path}")
-            try:
-                self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
-
-                if task == "text-generation":
                     self.pipelines[model_path] = pipeline(
                         "text-generation",
                         model=model_path,
@@ -116,7 +81,17 @@ class LocalModelManager:
                         trust_remote_code=True,
                         device_map="auto"
                    )
-                else:  # classification
                    self.pipelines[model_path] = pipeline(
                        "text-classification",
                        model=model_path,
@@ -125,277 +100,195 @@ class LocalModelManager:
                        trust_remote_code=True,
                        device_map="auto"
                    )
-
-                logger.info(f"Model loaded successfully: {model_path}")
-            except Exception as e:
-                logger.error(f"Error loading model {model_path}: {str(e)}")
-                raise
-
-    @spaces.GPU
-    def _generate_text_sync(self, pipeline, text: str) -> str:
-        """Run text generation synchronously"""
-        outputs = pipeline(
-            text,
-            max_new_tokens=100,
-            do_sample=False,
-            num_return_sequences=1
-        )
-        return outputs[0]["generated_text"]
 
-    def generate_text(self, model_path: str, text: str) -> str:
-        """Run text generation"""
-        if model_path not in self.pipelines:
-            self.load_model(model_path, "text-generation")
-
         try:
-            return self._generate_text_sync(self.pipelines[model_path], text)
         except Exception as e:
-            logger.error(f"Error in text generation with {model_path}: {str(e)}")
-            raise
 
-    @spaces.GPU
-    def _classify_text_sync(self, pipeline, text: str) -> str:
-        """Run text classification synchronously"""
-        result = pipeline(text)
-        return str(result)
-
-    def classify_text(self, model_path: str, text: str) -> str:
-        """Run text classification"""
-        if model_path not in self.pipelines:
-            self.load_model(model_path, "text-classification")
-
         try:
-            return self._classify_text_sync(self.pipelines[model_path], text)
         except Exception as e:
-            logger.error(f"Error in classification with {model_path}: {str(e)}")
-            raise
 
-class ModelManager:
-    def __init__(self):
-        self.api_clients = {}
-        self.local_manager = LocalModelManager()
-        self._initialize_clients()
-        self._preload_local_models()
 
-    def _initialize_clients(self):
-        """Initialize Inference API clients"""
-        for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
-            if model.type == ModelType.INFERENCE_API and model.model_id:
-                self.api_clients[model.model_id] = InferenceClient(
-                    model.model_id,
-                    token=True  # this uses the HF token
-                )
 
-    def _preload_local_models(self):
-        """Preload the local models"""
-        models_to_preload = []
-        tasks = {}
 
-        # add the text generation models
        for model in TEXT_GENERATION_MODELS:
-            if model.type == ModelType.LOCAL and model.model_path:
-                models_to_preload.append(model.model_path)
-                tasks[model.model_path] = "text-generation"
 
-        # add the classification models
        for model in CLASSIFICATION_MODELS:
-            if model.type == ModelType.LOCAL and model.model_path:
-                models_to_preload.append(model.model_path)
-                tasks[model.model_path] = "text-classification"
 
-        # run the preload
-        self.local_manager.preload_models(models_to_preload, tasks)
-
-    def run_text_generation(self, text: str, selected_types: List[str]) -> List[str]:
-        """Run the text generation models"""
-        results = []
-        for model in TEXT_GENERATION_MODELS:
-            if model.type.value in selected_types:
-                try:
-                    if model.type == ModelType.INFERENCE_API:
-                        logger.info(f"Running API text generation: {model.name}")
-                        response = self.api_clients[model.model_id].text_generation(
-                            text, max_new_tokens=100, temperature=0.7
-                        )
-                        results.append(f"{model.name}: {response}")
-                    else:
-                        logger.info(f"Running local text generation: {model.name}")
-                        response = self.local_manager.generate_text(model.model_path, text)
-                        results.append(f"{model.name}: {response}")
-                except Exception as e:
-                    logger.error(f"Error in {model.name}: {str(e)}")
-                    results.append(f"{model.name}: Error - {str(e)}")
-        return results
-
-    def run_classification(self, text: str, selected_types: List[str]) -> List[str]:
-        """Run the classification models"""
-        results = []
-        for model in CLASSIFICATION_MODELS:
-            if model.type.value in selected_types:
-                try:
-                    if model.type == ModelType.INFERENCE_API:
-                        logger.info(f"Running API classification: {model.name}")
-                        response = self.api_clients[model.model_id].text_classification(text)
-                        results.append(f"{model.name}: {response}")
-                    else:
-                        logger.info(f"Running local classification: {model.name}")
-                        response = self.local_manager.classify_text(model.model_path, text)
-                        results.append(f"{model.name}: {response}")
-                except Exception as e:
-                    logger.error(f"Error in {model.name}: {str(e)}")
-                    results.append(f"{model.name}: Error - {str(e)}")
        return results
 
-class UIComponents:
-    def __init__(self):
-        self.input_text = None
-        self.filter_checkboxes = None
-        self.invoke_button = None
-        self.gen_model_outputs = []
-        self.class_model_outputs = []
-        self.community_output = None
-
-    def create_header(self):
-        """Create the header section"""
-        return gr.Markdown("""
-        # Toxic Eye
-        This system evaluates the toxicity level of input text using multiple approaches.
-        """)
-
-    def create_input_section(self):
-        """Create the input section"""
-        with gr.Row():
-            self.input_text = gr.Textbox(
-                label="Input Text",
-                placeholder="Enter text to analyze...",
-                lines=3
-            )
-
-    def create_filter_section(self):
-        """Create the filter section"""
-        with gr.Row():
-            self.filter_checkboxes = gr.CheckboxGroup(
-                choices=[t.value for t in ModelType],
-                value=[t.value for t in ModelType],
-                label="Filter Models",
-                info="Choose which types of models to display",
-                interactive=True
-            )
-
-    def create_invoke_button(self):
-        """Create the invoke button"""
-        with gr.Row():
-            self.invoke_button = gr.Button(
-                "Invoke Selected Models",
-                variant="primary",
-                size="lg"
-            )
 
-    def create_model_grid(self, models: List[ModelConfig]) -> List[Dict]:
-        """Create the model grid"""
-        outputs = []
-        with gr.Column() as container:
-            for i in range(0, len(models), 2):
-                with gr.Row() as row:
-                    for j in range(min(2, len(models) - i)):
-                        model = models[i + j]
-                        with gr.Column():
-                            with gr.Group() as group:
-                                gr.Markdown(f"### {model.name}")
-                                gr.Markdown(f"Type: {model.type.value}")
-                                output = gr.Textbox(
-                                    label="Model Output",
-                                    lines=5,
-                                    interactive=False,
-                                    info=model.description
-                                )
-                                outputs.append({
-                                    "type": model.type.value,
-                                    "name": model.name,
-                                    "output": output,
-                                    "group": group
-                                })
-        return outputs
 
-    def create_model_tabs(self):
-        """Create the model tabs"""
-        with gr.Tabs():
-            with gr.Tab("Text Generation LLM"):
-                self.gen_model_outputs = self.create_model_grid(TEXT_GENERATION_MODELS)
-            with gr.Tab("Classification LLM"):
-                self.class_model_outputs = self.create_model_grid(CLASSIFICATION_MODELS)
-            with gr.Tab("Community (Not implemented)"):
-                with gr.Column():
-                    self.community_output = gr.Textbox(
-                        label="Related Community Topics",
-                        lines=5,
-                        interactive=False
-                    )
-
-class ToxicityApp:
-    def __init__(self):
-        self.ui = UIComponents()
-        self.model_manager = ModelManager()
-
-    def update_model_visibility(self, selected_types: List[str]) -> List[gr.update]:
-        """Update the visibility of the model outputs"""
-        logger.info(f"Updating visibility for types: {selected_types}")
-
-        updates = []
-        for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]:
-            for output in outputs:
-                visible = output["type"] in selected_types
-                logger.info(f"Model {output['name']} (type: {output['type']}): visible = {visible}")
-                updates.append(gr.update(visible=visible))
-        return updates
-
-    def handle_invoke(self, text: str, selected_types: List[str]) -> List[str]:
-        """Handler for the invoke button"""
-        gen_results = self.model_manager.run_text_generation(text, selected_types)
-        class_results = self.model_manager.run_classification(text, selected_types)
-
-        # pad the result lists to a fixed length
-        gen_results.extend([""] * (len(TEXT_GENERATION_MODELS) - len(gen_results)))
-        class_results.extend([""] * (len(CLASSIFICATION_MODELS) - len(class_results)))
 
-        return gen_results + class_results
-
     def create_ui(self):
         """Create the UI"""
         with gr.Blocks() as demo:
-            self.ui.create_header()
-            self.ui.create_input_section()
-            self.ui.create_filter_section()
-            self.ui.create_invoke_button()
-            self.ui.create_model_tabs()
-
-            # set up the event handlers
-            self.ui.filter_checkboxes.change(
-                fn=self.update_model_visibility,
-                inputs=[self.ui.filter_checkboxes],
-                outputs=[
-                    output["group"]
-                    for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
-                    for output in outputs
-                ]
-            )
-
-            self.ui.invoke_button.click(
                fn=self.handle_invoke,
-                inputs=[self.ui.input_text, self.ui.filter_checkboxes],
-                outputs=[
-                    output["output"]
-                    for outputs in [self.ui.gen_model_outputs, self.ui.class_model_outputs]
-                    for output in outputs
-                ]
            )
-
        return demo
 
 def main():
     app = ToxicityApp()
-    demo = app.create_ui()
-    demo.launch()
 
 if __name__ == "__main__":
     main()
 
+++ b/app.py
@@ -1,11 +1,8 @@
 import gradio as gr
+import torch
+from transformers import AutoTokenizer, pipeline
 from huggingface_hub import InferenceClient
 import logging
 import spaces
 
 # Logger setup
@@ -15,99 +12,67 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+# Model definitions (both local and API models)
 TEXT_GENERATION_MODELS = [
+    {
+        "name": "Llama-2",
+        "description": "Known for its robust performance in content analysis",
+        "type": "local",
+        "model_path": "meta-llama/Llama-2-7b-hf"
+    },
+    {
+        "name": "Mistral-7B",
+        "description": "Offers precise and detailed text evaluation",
+        "type": "local",
+        "model_path": "mistralai/Mistral-7B-v0.1"
+    },
+    {
+        "name": "Zephyr-7B",
+        "description": "Specialized in understanding context and nuance",
+        "type": "api",
+        "model_id": "HuggingFaceH4/zephyr-7b-beta"
+    }
 ]
 
 CLASSIFICATION_MODELS = [
+    {
+        "name": "Toxic-BERT",
+        "description": "Fine-tuned for toxic content detection",
+        "type": "local",
+        "model_path": "unitary/toxic-bert"
+    }
 ]
 
+# Define the class without GPU-related decorators
+class ModelManager:
     def __init__(self):
         self.tokenizers = {}
         self.pipelines = {}
+        self.api_clients = {}
+        self._initialize_api_clients()
+        self._preload_local_models()
 
+    def _initialize_api_clients(self):
+        """Initialize Inference API clients"""
+        for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
+            if model["type"] == "api" and "model_id" in model:
+                logger.info(f"Initializing API client for {model['name']}")
+                self.api_clients[model["model_id"]] = InferenceClient(
+                    model["model_id"],
+                    token=True  # this uses the HF token
+                )
+
+    def _preload_local_models(self):
+        """Preload the local models"""
+        logger.info("Preloading local models at application startup...")
+
+        # text generation models
+        for model in TEXT_GENERATION_MODELS:
+            if model["type"] == "local" and "model_path" in model:
+                model_path = model["model_path"]
+                try:
+                    logger.info(f"Preloading text generation model: {model_path}")
+                    self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
                     self.pipelines[model_path] = pipeline(
                         "text-generation",
                         model=model_path,
@@ -116,7 +81,17 @@ class LocalModelManager:
                        trust_remote_code=True,
                        device_map="auto"
                    )
+                    logger.info(f"Model preloaded successfully: {model_path}")
+                except Exception as e:
+                    logger.error(f"Error preloading model {model_path}: {str(e)}")
+
+        # classification models
+        for model in CLASSIFICATION_MODELS:
+            if model["type"] == "local" and "model_path" in model:
+                model_path = model["model_path"]
+                try:
+                    logger.info(f"Preloading classification model: {model_path}")
+                    self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
                    self.pipelines[model_path] = pipeline(
                        "text-classification",
                        model=model_path,
@@ -125,277 +100,195 @@ class LocalModelManager:
                        trust_remote_code=True,
                        device_map="auto"
                    )
+                    logger.info(f"Model preloaded successfully: {model_path}")
+                except Exception as e:
+                    logger.error(f"Error preloading model {model_path}: {str(e)}")
 
+    def generate_text_local(self, model_path, text):
+        """Text generation with a local model (the GPU decorator is applied outside the class)"""
         try:
+            logger.info(f"Running local text generation with {model_path}")
+            outputs = self.pipelines[model_path](
+                text,
+                max_new_tokens=100,
+                do_sample=False,
+                num_return_sequences=1
+            )
+            return outputs[0]["generated_text"]
         except Exception as e:
+            logger.error(f"Error in local text generation with {model_path}: {str(e)}")
+            return f"Error: {str(e)}"
 
+    def generate_text_api(self, model_id, text):
+        """Text generation via the Inference API"""
         try:
+            logger.info(f"Running API text generation with {model_id}")
+            response = self.api_clients[model_id].text_generation(
+                text,
+                max_new_tokens=100,
+                temperature=0.7
+            )
+            return response
         except Exception as e:
+            logger.error(f"Error in API text generation with {model_id}: {str(e)}")
+            return f"Error: {str(e)}"
 
+    def classify_text_local(self, model_path, text):
+        """Text classification with a local model (the GPU decorator is applied outside the class)"""
+        try:
+            logger.info(f"Running local classification with {model_path}")
+            result = self.pipelines[model_path](text)
+            return str(result)
+        except Exception as e:
+            logger.error(f"Error in local classification with {model_path}: {str(e)}")
+            return f"Error: {str(e)}"
 
+    def classify_text_api(self, model_id, text):
+        """Text classification via the Inference API"""
+        try:
+            logger.info(f"Running API classification with {model_id}")
+            response = self.api_clients[model_id].text_classification(text)
+            return str(response)
+        except Exception as e:
+            logger.error(f"Error in API classification with {model_id}: {str(e)}")
+            return f"Error: {str(e)}"
 
+    def run_models(self, text, selected_types):
+        """Run analysis with the models of the selected types"""
+        results = []
 
+        # run the text generation models
        for model in TEXT_GENERATION_MODELS:
+            if model["type"] in selected_types:
+                if model["type"] == "local":
+                    # assumes the GPU decorator is applied outside the class
+                    result = gpu_wrapper_generate(self, model["model_path"], text)
+                else:  # api
+                    result = self.generate_text_api(model["model_id"], text)
+                results.append(f"{model['name']}: {result}")
 
+        # run the classification models
        for model in CLASSIFICATION_MODELS:
+            if model["type"] in selected_types:
+                if model["type"] == "local":
+                    # assumes the GPU decorator is applied outside the class
+                    result = gpu_wrapper_classify(self, model["model_path"], text)
+                else:  # api
+                    result = self.classify_text_api(model["model_id"], text)
+                results.append(f"{model['name']}: {result}")
+
+        # pad the result list to a fixed length
+        while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
+            results.append("")
 
        return results
 
+# Wrap the class methods to apply the GPU decorator
+@spaces.GPU
+def gpu_wrapper_generate(manager, model_path, text):
+    return manager.generate_text_local(model_path, text)
 
+@spaces.GPU
+def gpu_wrapper_classify(manager, model_path, text):
+    return manager.classify_text_local(model_path, text)
 
+# Class that creates and manages the UI
+class UIManager:
+    def __init__(self, model_manager):
+        self.model_manager = model_manager
 
     def create_ui(self):
         """Create the UI"""
         with gr.Blocks() as demo:
+            # header
+            gr.Markdown("""
+            # Toxic Eye (Class-based Version with GPU Wrappers)
+            This system evaluates the toxicity level of input text using both local models and Inference API.
+            """)
+
+            # input section
+            with gr.Row():
+                input_text = gr.Textbox(
+                    label="Input Text",
+                    placeholder="Enter text to analyze...",
+                    lines=3
+                )
+
+            # filter section
+            with gr.Row():
+                filter_checkboxes = gr.CheckboxGroup(
+                    choices=["local", "api"],
+                    value=["local", "api"],
+                    label="Filter Models",
+                    info="Choose which types of models to use",
+                    interactive=True
+                )
+
+            # run button
+            with gr.Row():
+                invoke_button = gr.Button(
+                    "Analyze Text",
+                    variant="primary",
+                    size="lg"
+                )
+
+            # model output display area
+            all_outputs = []
+
+            with gr.Tabs():
+                # text generation models tab
+                with gr.Tab("Text Generation Models"):
+                    for model in TEXT_GENERATION_MODELS:
+                        with gr.Group():
+                            gr.Markdown(f"### {model['name']} ({model['type']})")
+                            output = gr.Textbox(
+                                label=f"{model['name']} Output",
+                                lines=5,
+                                interactive=False,
+                                info=model["description"]
+                            )
+                            all_outputs.append(output)
+
+                # classification models tab
+                with gr.Tab("Classification Models"):
+                    for model in CLASSIFICATION_MODELS:
+                        with gr.Group():
+                            gr.Markdown(f"### {model['name']} ({model['type']})")
+                            output = gr.Textbox(
+                                label=f"{model['name']} Output",
+                                lines=5,
+                                interactive=False,
+                                info=model["description"]
+                            )
+                            all_outputs.append(output)
+
+            # wire up the event
+            invoke_button.click(
                fn=self.handle_invoke,
+                inputs=[input_text, filter_checkboxes],
+                outputs=all_outputs
            )
+
        return demo
+
+    def handle_invoke(self, text, selected_types):
+        """Handle model execution"""
+        return self.model_manager.run_models(text, selected_types)
+
+# Main application class
+class ToxicityApp:
+    def __init__(self):
+        self.model_manager = ModelManager()
+        self.ui_manager = UIManager(self.model_manager)
+
+    def run(self):
+        """Launch the app"""
+        demo = self.ui_manager.create_ui()
+        demo.launch()
 
 def main():
     app = ToxicityApp()
+    app.run()
 
 if __name__ == "__main__":
     main()
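
The pattern worth noting in this commit: @spaces.GPU is no longer applied to instance methods (as it was on LocalModelManager._generate_text_sync and _classify_text_sync); the decorator now sits on module-level wrapper functions (gpu_wrapper_generate, gpu_wrapper_classify) that receive the manager instance as an explicit argument and delegate to plain, undecorated methods. A minimal sketch of that pattern, assuming a ZeroGPU Space where the spaces package is available (ManagerSketch and its method are hypothetical stand-ins, not code from the commit):

import spaces

class ManagerSketch:
    # plain class: no GPU decorators on any method
    def heavy_task(self, prompt: str) -> str:
        # in the real app this would call a preloaded transformers pipeline
        return f"processed: {prompt}"

# module-level wrapper: the GPU decorator is applied here, outside the class,
# and the instance is passed in explicitly instead of arriving as a bound `self`
@spaces.GPU
def gpu_heavy_task(manager: ManagerSketch, prompt: str) -> str:
    return manager.heavy_task(prompt)

Per the commit's own comments, keeping the decorated entry points at module level (and the class free of GPU decorators) is the point of the refactor: the wrappers give ZeroGPU a plain function to schedule, while model state stays on the manager instance.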