nyasukun committed on
Commit
e09e1bb
·
verified ·
1 Parent(s): b37f1c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -63
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, pipeline
 
4
  import logging
5
  import spaces
6
 
@@ -11,17 +12,25 @@ logging.basicConfig(
11
  )
12
  logger = logging.getLogger(__name__)
13
 
14
- # シンプルなモデル定義(3つのローカルモデル)
15
  TEXT_GENERATION_MODELS = [
16
  {
17
  "name": "Llama-2",
18
  "description": "Known for its robust performance in content analysis",
 
19
  "model_path": "meta-llama/Llama-2-7b-hf"
20
  },
21
  {
22
  "name": "Mistral-7B",
23
  "description": "Offers precise and detailed text evaluation",
 
24
  "model_path": "mistralai/Mistral-7B-v0.1"
 
 
 
 
 
 
25
  }
26
  ]
27
 
@@ -29,59 +38,73 @@ CLASSIFICATION_MODELS = [
29
  {
30
  "name": "Toxic-BERT",
31
  "description": "Fine-tuned for toxic content detection",
 
32
  "model_path": "unitary/toxic-bert"
33
  }
34
  ]
35
 
36
- # グローバル変数でモデルとトークナイザを管理
37
  tokenizers = {}
38
  pipelines = {}
 
39
 
40
- def preload_models():
41
- """アプリケーション起動時にモデルを事前ロード"""
42
- logger.info("Preloading models at application startup...")
 
 
 
 
 
 
 
 
 
 
43
 
44
  # テキスト生成モデル
45
  for model in TEXT_GENERATION_MODELS:
46
- model_path = model["model_path"]
47
- try:
48
- logger.info(f"Preloading text generation model: {model_path}")
49
- tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
50
- pipelines[model_path] = pipeline(
51
- "text-generation",
52
- model=model_path,
53
- tokenizer=tokenizers[model_path],
54
- torch_dtype=torch.bfloat16,
55
- trust_remote_code=True,
56
- device_map="auto"
57
- )
58
- logger.info(f"Model preloaded successfully: {model_path}")
59
- except Exception as e:
60
- logger.error(f"Error preloading model {model_path}: {str(e)}")
 
61
 
62
  # 分類モデル
63
  for model in CLASSIFICATION_MODELS:
64
- model_path = model["model_path"]
65
- try:
66
- logger.info(f"Preloading classification model: {model_path}")
67
- tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
68
- pipelines[model_path] = pipeline(
69
- "text-classification",
70
- model=model_path,
71
- tokenizer=tokenizers[model_path],
72
- torch_dtype=torch.bfloat16,
73
- trust_remote_code=True,
74
- device_map="auto"
75
- )
76
- logger.info(f"Model preloaded successfully: {model_path}")
77
- except Exception as e:
78
- logger.error(f"Error preloading model {model_path}: {str(e)}")
 
79
 
80
  @spaces.GPU
81
- def generate_text(model_path, text):
82
- """テキスト生成の実行"""
83
  try:
84
- logger.info(f"Running text generation with {model_path}")
85
  outputs = pipelines[model_path](
86
  text,
87
  max_new_tokens=100,
@@ -90,35 +113,69 @@ def generate_text(model_path, text):
90
  )
91
  return outputs[0]["generated_text"]
92
  except Exception as e:
93
- logger.error(f"Error in text generation with {model_path}: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  return f"Error: {str(e)}"
95
 
96
  @spaces.GPU
97
- def classify_text(model_path, text):
98
- """テキスト分類の実行"""
99
  try:
100
- logger.info(f"Running classification with {model_path}")
101
  result = pipelines[model_path](text)
102
  return str(result)
103
  except Exception as e:
104
- logger.error(f"Error in classification with {model_path}: {str(e)}")
 
 
 
 
 
 
 
 
 
 
105
  return f"Error: {str(e)}"
106
 
107
- def handle_invoke(text):
108
- """すべてのモデルで分析を実行"""
109
  results = []
110
 
111
  # テキスト生成モデルの実行
112
  for model in TEXT_GENERATION_MODELS:
113
- model_path = model["model_path"]
114
- result = generate_text(model_path, text)
115
- results.append(result)
 
 
 
116
 
117
  # 分類モデルの実行
118
  for model in CLASSIFICATION_MODELS:
119
- model_path = model["model_path"]
120
- result = classify_text(model_path, text)
121
- results.append(result)
 
 
 
 
 
 
 
122
 
123
  return results
124
 
@@ -127,8 +184,8 @@ def create_ui():
127
  with gr.Blocks() as demo:
128
  # ヘッダー
129
  gr.Markdown("""
130
- # Toxic Eye (3 Models Version)
131
- This system evaluates the toxicity level of input text using 3 local models.
132
  """)
133
 
134
  # 入力セクション
@@ -139,6 +196,16 @@ def create_ui():
139
  lines=3
140
  )
141
 
 
 
 
 
 
 
 
 
 
 
142
  # 実行ボタン
143
  with gr.Row():
144
  invoke_button = gr.Button(
@@ -148,48 +215,50 @@ def create_ui():
148
  )
149
 
150
  # モデル出力表示エリア
151
- gen_outputs = []
152
- class_outputs = []
153
 
154
  with gr.Tabs():
155
  # テキスト生成モデルのタブ
156
  with gr.Tab("Text Generation Models"):
157
  for model in TEXT_GENERATION_MODELS:
158
  with gr.Group():
159
- gr.Markdown(f"### {model['name']}")
160
  output = gr.Textbox(
161
  label=f"{model['name']} Output",
162
  lines=5,
163
  interactive=False,
164
  info=model["description"]
165
  )
166
- gen_outputs.append(output)
167
 
168
  # 分類モデルのタブ
169
  with gr.Tab("Classification Models"):
170
  for model in CLASSIFICATION_MODELS:
171
  with gr.Group():
172
- gr.Markdown(f"### {model['name']}")
173
  output = gr.Textbox(
174
  label=f"{model['name']} Output",
175
  lines=5,
176
  interactive=False,
177
  info=model["description"]
178
  )
179
- class_outputs.append(output)
180
 
181
  # イベント接続
182
  invoke_button.click(
183
  fn=handle_invoke,
184
- inputs=[input_text],
185
- outputs=gen_outputs + class_outputs
186
  )
187
 
188
  return demo
189
 
190
  def main():
191
- # モデルを事前ロード
192
- preload_models()
 
 
 
193
 
194
  # UIを作成して起動
195
  demo = create_ui()
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, pipeline
4
+ from huggingface_hub import InferenceClient
5
  import logging
6
  import spaces
7
 
 
12
  )
13
  logger = logging.getLogger(__name__)
14
 
15
+ # モデル定義(ローカルモデルとAPIモデルの両方)
16
  TEXT_GENERATION_MODELS = [
17
  {
18
  "name": "Llama-2",
19
  "description": "Known for its robust performance in content analysis",
20
+ "type": "local",
21
  "model_path": "meta-llama/Llama-2-7b-hf"
22
  },
23
  {
24
  "name": "Mistral-7B",
25
  "description": "Offers precise and detailed text evaluation",
26
+ "type": "local",
27
  "model_path": "mistralai/Mistral-7B-v0.1"
28
+ },
29
+ {
30
+ "name": "Zephyr-7B",
31
+ "description": "Specialized in understanding context and nuance",
32
+ "type": "api",
33
+ "model_id": "HuggingFaceH4/zephyr-7b-beta"
34
  }
35
  ]
36
 
 
38
  {
39
  "name": "Toxic-BERT",
40
  "description": "Fine-tuned for toxic content detection",
41
+ "type": "local",
42
  "model_path": "unitary/toxic-bert"
43
  }
44
  ]
45
 
46
+ # グローバル変数でモデルとAPIクライアントを管理
47
  tokenizers = {}
48
  pipelines = {}
49
+ api_clients = {}
50
 
51
def initialize_api_clients():
    """Create an Inference API client for every model entry typed "api".

    Walks the text-generation entries followed by the classification entries
    and stores one ``InferenceClient`` per matching entry in the module-level
    ``api_clients`` dict, keyed by the entry's ``model_id``. Entries of any
    other type, or without a ``model_id`` key, are left untouched.
    """
    for entry in [*TEXT_GENERATION_MODELS, *CLASSIFICATION_MODELS]:
        if entry["type"] != "api" or "model_id" not in entry:
            continue

        logger.info(f"Initializing API client for {entry['name']}")
        remote_id = entry["model_id"]
        # token=True: authenticate with the Hugging Face token
        # (presumably the locally stored one; confirm against the hub docs).
        api_clients[remote_id] = InferenceClient(remote_id, token=True)
60
+
61
def _load_local_pipeline(task, model_path):
    """Load the tokenizer and the ``task`` pipeline for one local model.

    Both objects are stored in the module-level ``tokenizers`` and
    ``pipelines`` dicts under ``model_path``.
    """
    tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
    # NOTE(review): trust_remote_code=True allows the model repository to run
    # its own Python code at load time; keep it only for trusted repositories.
    pipelines[model_path] = pipeline(
        task,
        model=model_path,
        tokenizer=tokenizers[model_path],
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )


def preload_local_models():
    """Preload every model entry typed "local" at application startup.

    Text-generation entries are loaded first, then classification entries.
    Entries of another type, or without a ``model_path`` key, are skipped.
    Loading is best-effort: a failure for one model is logged (with its
    traceback) and the remaining models are still attempted.
    """
    logger.info("Preloading local models at application startup...")

    # (pipeline task, label used in log messages, model table)
    model_groups = (
        ("text-generation", "text generation", TEXT_GENERATION_MODELS),
        ("text-classification", "classification", CLASSIFICATION_MODELS),
    )

    for task, label, models in model_groups:
        for model in models:
            if model["type"] != "local" or "model_path" not in model:
                continue

            model_path = model["model_path"]
            try:
                logger.info(f"Preloading {label} model: {model_path}")
                _load_local_pipeline(task, model_path)
                logger.info(f"Model preloaded successfully: {model_path}")
            except Exception as e:
                # Startup boundary: record the full traceback, keep going so
                # the remaining models can still be loaded.
                logger.exception(f"Error preloading model {model_path}: {e}")
102
 
103
  @spaces.GPU
104
+ def generate_text_local(model_path, text):
105
+ """ローカルモデルでのテキスト生成"""
106
  try:
107
+ logger.info(f"Running local text generation with {model_path}")
108
  outputs = pipelines[model_path](
109
  text,
110
  max_new_tokens=100,
 
113
  )
114
  return outputs[0]["generated_text"]
115
  except Exception as e:
116
+ logger.error(f"Error in local text generation with {model_path}: {str(e)}")
117
+ return f"Error: {str(e)}"
118
+
119
def generate_text_api(model_id, text):
    """Generate text for ``text`` through the API client registered for ``model_id``.

    Returns the value produced by the client's ``text_generation`` call. If
    anything raises (including a missing client for ``model_id``), the error
    is logged and an ``"Error: ..."`` string is returned instead.
    """
    try:
        logger.info(f"Running API text generation with {model_id}")
        client = api_clients[model_id]
        generation_options = {"max_new_tokens": 100, "temperature": 0.7}
        return client.text_generation(text, **generation_options)
    except Exception as exc:
        logger.error(f"Error in API text generation with {model_id}: {str(exc)}")
        return f"Error: {str(exc)}"
132
 
133
  @spaces.GPU
134
def classify_text_local(model_path, text):
    """Classify ``text`` with the preloaded local pipeline for ``model_path``.

    Returns the pipeline result converted with ``str``. If anything raises
    (including a pipeline that was never loaded), the error is logged and an
    ``"Error: ..."`` string is returned instead.
    """
    try:
        logger.info(f"Running local classification with {model_path}")
        classifier = pipelines[model_path]
        prediction = classifier(text)
        return str(prediction)
    except Exception as exc:
        logger.error(f"Error in local classification with {model_path}: {str(exc)}")
        return f"Error: {str(exc)}"
143
+
144
def classify_text_api(model_id, text):
    """Classify ``text`` through the API client registered for ``model_id``.

    Returns the client's ``text_classification`` result converted with
    ``str``. If anything raises (including a missing client for
    ``model_id``), the error is logged and an ``"Error: ..."`` string is
    returned instead.
    """
    try:
        logger.info(f"Running API classification with {model_id}")
        client = api_clients[model_id]
        prediction = client.text_classification(text)
        return str(prediction)
    except Exception as exc:
        logger.error(f"Error in API classification with {model_id}: {str(exc)}")
        return f"Error: {str(exc)}"
153
 
154
def handle_invoke(text, selected_types):
    """Run the models whose type is selected and return one result per model.

    Args:
        text: The input text to analyse.
        selected_types: Iterable of model type names to run ("local",
            "api"); ``None`` is treated as "nothing selected".

    Returns:
        A list with exactly one string per configured model, ordered as all
        of ``TEXT_GENERATION_MODELS`` followed by all of
        ``CLASSIFICATION_MODELS``. A model that ran contributes
        ``"<name>: <result>"``; a model whose type is not selected
        contributes ``""`` at its own position.
    """
    chosen_types = set(selected_types or ())

    # (model table, runner for local models, runner for API models)
    model_groups = (
        (TEXT_GENERATION_MODELS, generate_text_local, generate_text_api),
        (CLASSIFICATION_MODELS, classify_text_local, classify_text_api),
    )

    results = []
    for models, run_local, run_api in model_groups:
        for model in models:
            if model["type"] not in chosen_types:
                # Outputs are matched to the UI components by position, so a
                # filtered-out model must still occupy its slot; otherwise
                # later results shift into the wrong model's textbox.
                results.append("")
                continue

            if model["type"] == "local":
                result = run_local(model["model_path"], text)
            else:  # api
                result = run_api(model["model_id"], text)
            results.append(f"{model['name']}: {result}")

    return results
181
 
 
184
  with gr.Blocks() as demo:
185
  # ヘッダー
186
  gr.Markdown("""
187
+ # Toxic Eye (Local + API Version)
188
+ This system evaluates the toxicity level of input text using both local models and Inference API.
189
  """)
190
 
191
  # 入力セクション
 
196
  lines=3
197
  )
198
 
199
+ # フィルターセクション
200
+ with gr.Row():
201
+ filter_checkboxes = gr.CheckboxGroup(
202
+ choices=["local", "api"],
203
+ value=["local", "api"],
204
+ label="Filter Models",
205
+ info="Choose which types of models to use",
206
+ interactive=True
207
+ )
208
+
209
  # 実行ボタン
210
  with gr.Row():
211
  invoke_button = gr.Button(
 
215
  )
216
 
217
  # モデル出力表示エリア
218
+ all_outputs = []
 
219
 
220
  with gr.Tabs():
221
  # テキスト生成モデルのタブ
222
  with gr.Tab("Text Generation Models"):
223
  for model in TEXT_GENERATION_MODELS:
224
  with gr.Group():
225
+ gr.Markdown(f"### {model['name']} ({model['type']})")
226
  output = gr.Textbox(
227
  label=f"{model['name']} Output",
228
  lines=5,
229
  interactive=False,
230
  info=model["description"]
231
  )
232
+ all_outputs.append(output)
233
 
234
  # 分類モデルのタブ
235
  with gr.Tab("Classification Models"):
236
  for model in CLASSIFICATION_MODELS:
237
  with gr.Group():
238
+ gr.Markdown(f"### {model['name']} ({model['type']})")
239
  output = gr.Textbox(
240
  label=f"{model['name']} Output",
241
  lines=5,
242
  interactive=False,
243
  info=model["description"]
244
  )
245
+ all_outputs.append(output)
246
 
247
  # イベント接続
248
  invoke_button.click(
249
  fn=handle_invoke,
250
+ inputs=[input_text, filter_checkboxes],
251
+ outputs=all_outputs
252
  )
253
 
254
  return demo
255
 
256
  def main():
257
+ # APIクライアントの初期化
258
+ initialize_api_clients()
259
+
260
+ # ローカルモデルを事前ロード
261
+ preload_local_models()
262
 
263
  # UIを作成して起動
264
  demo = create_ui()