nyasukun committed on
Commit 385b096 · 1 Parent(s): 0da498f
Files changed (1)
  1. app.py +202 -228
app.py CHANGED
@@ -43,252 +43,226 @@ CLASSIFICATION_MODELS = [
  }
  ]

- # Define the class without GPU-related decorators
- class ModelManager:
-     def __init__(self):
-         self.tokenizers = {}
-         self.pipelines = {}
-         self.api_clients = {}
-         self._initialize_api_clients()
-         self._preload_local_models()
-
-     def _initialize_api_clients(self):
-         """Initialize the Inference API clients"""
-         for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
-             if model["type"] == "api" and "model_id" in model:
-                 logger.info(f"Initializing API client for {model['name']}")
-                 self.api_clients[model["model_id"]] = InferenceClient(
-                     model["model_id"],
-                     token=True  # use the HF token
-                 )
-
-     def _preload_local_models(self):
-         """Preload the local models"""
-         logger.info("Preloading local models at application startup...")
-
-         # Text generation models
-         for model in TEXT_GENERATION_MODELS:
-             if model["type"] == "local" and "model_path" in model:
-                 model_path = model["model_path"]
-                 try:
-                     logger.info(f"Preloading text generation model: {model_path}")
-                     self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
-                     self.pipelines[model_path] = pipeline(
-                         "text-generation",
-                         model=model_path,
-                         tokenizer=self.tokenizers[model_path],
-                         torch_dtype=torch.bfloat16,
-                         trust_remote_code=True,
-                         device_map="auto"
-                     )
-                     logger.info(f"Model preloaded successfully: {model_path}")
-                 except Exception as e:
-                     logger.error(f"Error preloading model {model_path}: {str(e)}")
-
-         # Classification models
-         for model in CLASSIFICATION_MODELS:
-             if model["type"] == "local" and "model_path" in model:
-                 model_path = model["model_path"]
-                 try:
-                     logger.info(f"Preloading classification model: {model_path}")
-                     self.tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
-                     self.pipelines[model_path] = pipeline(
-                         "text-classification",
-                         model=model_path,
-                         tokenizer=self.tokenizers[model_path],
-                         torch_dtype=torch.bfloat16,
-                         trust_remote_code=True,
-                         device_map="auto"
-                     )
-                     logger.info(f"Model preloaded successfully: {model_path}")
-                 except Exception as e:
-                     logger.error(f"Error preloading model {model_path}: {str(e)}")
-
-     def generate_text_local(self, model_path, text):
-         """Text generation with a local model (the GPU decorator is applied outside the class)"""
-         try:
-             logger.info(f"Running local text generation with {model_path}")
-             outputs = self.pipelines[model_path](
-                 text,
-                 max_new_tokens=100,
-                 do_sample=False,
-                 num_return_sequences=1
-             )
-             return outputs[0]["generated_text"]
-         except Exception as e:
-             logger.error(f"Error in local text generation with {model_path}: {str(e)}")
-             return f"Error: {str(e)}"
-
-     def generate_text_api(self, model_id, text):
-         """Text generation via the Inference API"""
-         try:
-             logger.info(f"Running API text generation with {model_id}")
-             response = self.api_clients[model_id].text_generation(
-                 text,
-                 max_new_tokens=100,
-                 temperature=0.7
-             )
-             return response
-         except Exception as e:
-             logger.error(f"Error in API text generation with {model_id}: {str(e)}")
-             return f"Error: {str(e)}"
-
-     def classify_text_local(self, model_path, text):
-         """Text classification with a local model (the GPU decorator is applied outside the class)"""
-         try:
-             logger.info(f"Running local classification with {model_path}")
-             result = self.pipelines[model_path](text)
-             return str(result)
-         except Exception as e:
-             logger.error(f"Error in local classification with {model_path}: {str(e)}")
-             return f"Error: {str(e)}"
-
-     def classify_text_api(self, model_id, text):
-         """Text classification via the Inference API"""
-         try:
-             logger.info(f"Running API classification with {model_id}")
-             response = self.api_clients[model_id].text_classification(text)
-             return str(response)
-         except Exception as e:
-             logger.error(f"Error in API classification with {model_id}: {str(e)}")
-             return f"Error: {str(e)}"
-
-     def run_models(self, text, selected_types):
-         """Run analysis with the selected model types"""
-         results = []
-
-         # Run the text generation models
-         for model in TEXT_GENERATION_MODELS:
-             if model["type"] in selected_types:
-                 if model["type"] == "local":
-                     # Assumes the GPU decorator is applied outside the class
-                     result = gpu_wrapper_generate(self, model["model_path"], text)
-                 else:  # api
-                     result = self.generate_text_api(model["model_id"], text)
-                 results.append(f"{model['name']}: {result}")
-
-         # Run the classification models
-         for model in CLASSIFICATION_MODELS:
-             if model["type"] in selected_types:
-                 if model["type"] == "local":
-                     # Assumes the GPU decorator is applied outside the class
-                     result = gpu_wrapper_classify(self, model["model_path"], text)
-                 else:  # api
-                     result = self.classify_text_api(model["model_id"], text)
-                 results.append(f"{model['name']}: {result}")
-
-         # Pad the results list to the expected length
-         while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
-             results.append("")
-
-         return results
-
- # Wrap the class methods so the GPU decorator can be applied
- @spaces.GPU
- def gpu_wrapper_generate(manager, model_path, text):
-     return manager.generate_text_local(model_path, text)
-
- @spaces.GPU
- def gpu_wrapper_classify(manager, model_path, text):
-     return manager.classify_text_local(model_path, text)
-
- # Class that creates and manages the UI
- class UIManager:
-     def __init__(self, model_manager):
-         self.model_manager = model_manager
-
-     def create_ui(self):
-         """Build the UI"""
-         with gr.Blocks() as demo:
-             # Header
-             gr.Markdown("""
-             # Toxic Eye (Class-based Version with GPU Wrappers)
-             This system evaluates the toxicity level of input text using both local models and Inference API.
-             """)
-
-             # Input section
-             with gr.Row():
-                 input_text = gr.Textbox(
-                     label="Input Text",
-                     placeholder="Enter text to analyze...",
-                     lines=3
-                 )
-
-             # Filter section
-             with gr.Row():
-                 filter_checkboxes = gr.CheckboxGroup(
-                     choices=["local", "api"],
-                     value=["local", "api"],
-                     label="Filter Models",
-                     info="Choose which types of models to use",
-                     interactive=True
-                 )
-
-             # Invoke button
-             with gr.Row():
-                 invoke_button = gr.Button(
-                     "Analyze Text",
-                     variant="primary",
-                     size="lg"
-                 )
-
-             # Model output display area
-             all_outputs = []
-
-             with gr.Tabs():
-                 # Tab for the text generation models
-                 with gr.Tab("Text Generation Models"):
-                     for model in TEXT_GENERATION_MODELS:
-                         with gr.Group():
-                             gr.Markdown(f"### {model['name']} ({model['type']})")
-                             output = gr.Textbox(
-                                 label=f"{model['name']} Output",
-                                 lines=5,
-                                 interactive=False,
-                                 info=model["description"]
-                             )
-                             all_outputs.append(output)
-
-                 # Tab for the classification models
-                 with gr.Tab("Classification Models"):
-                     for model in CLASSIFICATION_MODELS:
-                         with gr.Group():
-                             gr.Markdown(f"### {model['name']} ({model['type']})")
-                             output = gr.Textbox(
-                                 label=f"{model['name']} Output",
-                                 lines=5,
-                                 interactive=False,
-                                 info=model["description"]
-                             )
-                             all_outputs.append(output)
-
-             # Wire up the button event
-             invoke_button.click(
-                 fn=self.handle_invoke,
-                 inputs=[input_text, filter_checkboxes],
-                 outputs=all_outputs
-             )
-
-         return demo
-
-     def handle_invoke(self, text, selected_types):
-         """Handle model invocation"""
-         return self.model_manager.run_models(text, selected_types)
-
- # Main application class
- class ToxicityApp:
-     def __init__(self):
-         self.model_manager = ModelManager()
-         self.ui_manager = UIManager(self.model_manager)
-
-     def run(self):
-         """Launch the app"""
-         demo = self.ui_manager.create_ui()
-         demo.launch()
+ # Manage the models and API clients in global variables
+ tokenizers = {}
+ pipelines = {}
+ api_clients = {}
+
+ def initialize_api_clients():
+     """Initialize the Inference API clients"""
+     for model in TEXT_GENERATION_MODELS + CLASSIFICATION_MODELS:
+         if model["type"] == "api" and "model_id" in model:
+             logger.info(f"Initializing API client for {model['name']}")
+             api_clients[model["model_id"]] = InferenceClient(
+                 model["model_id"],
+                 token=True  # use the HF token
+             )
+
+ def preload_local_models():
+     """Preload the local models"""
+     logger.info("Preloading local models at application startup...")
+
+     # Text generation models
+     for model in TEXT_GENERATION_MODELS:
+         if model["type"] == "local" and "model_path" in model:
+             model_path = model["model_path"]
+             try:
+                 logger.info(f"Preloading text generation model: {model_path}")
+                 tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
+                 pipelines[model_path] = pipeline(
+                     "text-generation",
+                     model=model_path,
+                     tokenizer=tokenizers[model_path],
+                     torch_dtype=torch.bfloat16,
+                     trust_remote_code=True,
+                     device_map="auto"
+                 )
+                 logger.info(f"Model preloaded successfully: {model_path}")
+             except Exception as e:
+                 logger.error(f"Error preloading model {model_path}: {str(e)}")
+
+     # Classification models
+     for model in CLASSIFICATION_MODELS:
+         if model["type"] == "local" and "model_path" in model:
+             model_path = model["model_path"]
+             try:
+                 logger.info(f"Preloading classification model: {model_path}")
+                 tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
+                 pipelines[model_path] = pipeline(
+                     "text-classification",
+                     model=model_path,
+                     tokenizer=tokenizers[model_path],
+                     torch_dtype=torch.bfloat16,
+                     trust_remote_code=True,
+                     device_map="auto"
+                 )
+                 logger.info(f"Model preloaded successfully: {model_path}")
+             except Exception as e:
+                 logger.error(f"Error preloading model {model_path}: {str(e)}")
+
+ @spaces.GPU
+ def generate_text_local(model_path, text):
+     """Text generation with a local model"""
+     try:
+         logger.info(f"Running local text generation with {model_path}")
+         outputs = pipelines[model_path](
+             text,
+             max_new_tokens=100,
+             do_sample=False,
+             num_return_sequences=1
+         )
+         return outputs[0]["generated_text"]
+     except Exception as e:
+         logger.error(f"Error in local text generation with {model_path}: {str(e)}")
+         return f"Error: {str(e)}"
+
+ def generate_text_api(model_id, text):
+     """Text generation via the Inference API"""
+     try:
+         logger.info(f"Running API text generation with {model_id}")
+         response = api_clients[model_id].text_generation(
+             text,
+             max_new_tokens=100,
+             temperature=0.7
+         )
+         return response
+     except Exception as e:
+         logger.error(f"Error in API text generation with {model_id}: {str(e)}")
+         return f"Error: {str(e)}"
+
+ @spaces.GPU
+ def classify_text_local(model_path, text):
+     """Text classification with a local model"""
+     try:
+         logger.info(f"Running local classification with {model_path}")
+         result = pipelines[model_path](text)
+         return str(result)
+     except Exception as e:
+         logger.error(f"Error in local classification with {model_path}: {str(e)}")
+         return f"Error: {str(e)}"
+
+ def classify_text_api(model_id, text):
+     """Text classification via the Inference API"""
+     try:
+         logger.info(f"Running API classification with {model_id}")
+         response = api_clients[model_id].text_classification(text)
+         return str(response)
+     except Exception as e:
+         logger.error(f"Error in API classification with {model_id}: {str(e)}")
+         return f"Error: {str(e)}"
+
+ def handle_invoke(text, selected_types):
+     """Run analysis with the selected model types"""
+     results = []
+
+     # Run the text generation models
+     for model in TEXT_GENERATION_MODELS:
+         if model["type"] in selected_types:
+             if model["type"] == "local":
+                 result = generate_text_local(model["model_path"], text)
+             else:  # api
+                 result = generate_text_api(model["model_id"], text)
+             results.append(f"{model['name']}: {result}")
+
+     # Run the classification models
+     for model in CLASSIFICATION_MODELS:
+         if model["type"] in selected_types:
+             if model["type"] == "local":
+                 result = classify_text_local(model["model_path"], text)
+             else:  # api
+                 result = classify_text_api(model["model_id"], text)
+             results.append(f"{model['name']}: {result}")
+
+     # Pad the results list to the expected length
+     while len(results) < len(TEXT_GENERATION_MODELS) + len(CLASSIFICATION_MODELS):
+         results.append("")
+
+     return results
+
+ def create_ui():
+     """Build the UI"""
+     with gr.Blocks() as demo:
+         # Header
+         gr.Markdown("""
+         # Toxic Eye (Function-based Version)
+         This system evaluates the toxicity level of input text using both local models and Inference API.
+         """)
+
+         # Input section
+         with gr.Row():
+             input_text = gr.Textbox(
+                 label="Input Text",
+                 placeholder="Enter text to analyze...",
+                 lines=3
+             )
+
+         # Filter section
+         with gr.Row():
+             filter_checkboxes = gr.CheckboxGroup(
+                 choices=["local", "api"],
+                 value=["local", "api"],
+                 label="Filter Models",
+                 info="Choose which types of models to use",
+                 interactive=True
+             )
+
+         # Invoke button
+         with gr.Row():
+             invoke_button = gr.Button(
+                 "Analyze Text",
+                 variant="primary",
+                 size="lg"
+             )
+
+         # Model output display area
+         all_outputs = []
+
+         with gr.Tabs():
+             # Tab for the text generation models
+             with gr.Tab("Text Generation Models"):
+                 for model in TEXT_GENERATION_MODELS:
+                     with gr.Group():
+                         gr.Markdown(f"### {model['name']} ({model['type']})")
+                         output = gr.Textbox(
+                             label=f"{model['name']} Output",
+                             lines=5,
+                             interactive=False,
+                             info=model["description"]
+                         )
+                         all_outputs.append(output)
+
+             # Tab for the classification models
+             with gr.Tab("Classification Models"):
+                 for model in CLASSIFICATION_MODELS:
+                     with gr.Group():
+                         gr.Markdown(f"### {model['name']} ({model['type']})")
+                         output = gr.Textbox(
+                             label=f"{model['name']} Output",
+                             lines=5,
+                             interactive=False,
+                             info=model["description"]
+                         )
+                         all_outputs.append(output)
+
+         # Wire up the button event
+         invoke_button.click(
+             fn=handle_invoke,
+             inputs=[input_text, filter_checkboxes],
+             outputs=all_outputs
+         )
+
+     return demo

  def main():
-     app = ToxicityApp()
-     app.run()
+     # Initialize the API clients
+     initialize_api_clients()
+
+     # Preload the local models
+     preload_local_models()
+
+     # Create and launch the UI
+     demo = create_ui()
+     demo.launch()

  if __name__ == "__main__":
      main()
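
Note on the pattern this commit adopts: on ZeroGPU Spaces, `spaces.GPU` decorates plain functions, which appears to be why the class methods and the `gpu_wrapper_*` indirection are replaced with module-level functions decorated directly. A minimal sketch of the idea, assuming the `spaces` package available on HF Spaces and a stand-in classification model (not part of this commit):

# Sketch only: @spaces.GPU on a module-level function, as in the new app.py.
import spaces
from transformers import pipeline

# Loaded at startup; ZeroGPU attaches a GPU only while a decorated
# function is executing, so preloading can happen without one.
pipe = pipeline("text-classification",
                model="distilbert-base-uncased-finetuned-sst-2-english")

@spaces.GPU  # a GPU is allocated for the duration of each call
def classify(text: str) -> str:
    return str(pipe(text))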