Lyon28 committed
Commit 61b2823 (verified) · Parent: 9d16a35

Create app.py

Files changed (1): app.py (+365, -0)
app.py ADDED
import os
import json
import logging
from typing import Dict, List, Optional, Any
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModel,
    pipeline,
    T5ForConditionalGeneration,
    T5Tokenizer
)
import gradio as gr
from flask import Flask, request, jsonify
import threading
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MultiModelAPI:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.pipelines = {}
        self.model_configs = {
            'Lyon28/Tinny-Llama': 'causal-lm',
            'Lyon28/Pythia': 'causal-lm',
            'Lyon28/Bert-Tinny': 'feature-extraction',
            'Lyon28/Albert-Base-V2': 'feature-extraction',
            'Lyon28/T5-Small': 'text2text-generation',
            'Lyon28/GPT-2': 'causal-lm',
            'Lyon28/GPT-Neo': 'causal-lm',
            'Lyon28/Distilbert-Base-Uncased': 'feature-extraction',
            'Lyon28/Distil_GPT-2': 'causal-lm',
            'Lyon28/GPT-2-Tinny': 'causal-lm',
            'Lyon28/Electra-Small': 'feature-extraction'
        }

    def load_model(self, model_name: str):
        """Load a specific model and build its inference pipeline."""
        try:
            logger.info(f"Loading model: {model_name}")

            if model_name in self.models:
                logger.info(f"Model {model_name} already loaded")
                return True

            model_type = self.model_configs.get(model_name, 'causal-lm')

            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                cache_dir="/app/cache"
            )

            # Add a pad token if the tokenizer does not define one
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Load model based on type
            if model_type == 'causal-lm':
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None,
                    cache_dir="/app/cache"
                )
                # The model already carries its placement via device_map, so the
                # pipeline must not also receive an explicit device argument.
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer
                )

            elif model_type == 'text2text-generation':
                model = T5ForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "text2text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1
                )

            else:  # feature-extraction and other BERT-like models
                model = AutoModel.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "feature-extraction",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1
                )

            self.models[model_name] = model
            self.tokenizers[model_name] = tokenizer
            self.pipelines[model_name] = pipe

            logger.info(f"Successfully loaded model: {model_name}")
            return True

        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            return False

    def generate_text(self, model_name: str, prompt: str, **kwargs):
        """Generate text (or embeddings) with the specified model, loading it on demand."""
        try:
            if model_name not in self.pipelines:
                if not self.load_model(model_name):
                    return {"error": f"Failed to load model {model_name}"}

            pipe = self.pipelines[model_name]
            model_type = self.model_configs.get(model_name, 'causal-lm')

            # Default generation parameters
            max_length = kwargs.get('max_length', 100)
            temperature = kwargs.get('temperature', 0.7)
            top_p = kwargs.get('top_p', 0.9)
            do_sample = kwargs.get('do_sample', True)

            if model_type == 'causal-lm':
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    pad_token_id=pipe.tokenizer.eos_token_id
                )
                return {"generated_text": result[0]['generated_text']}

            elif model_type == 'text2text-generation':
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=do_sample
                )
                return {"generated_text": result[0]['generated_text']}

            else:  # feature extraction
                result = pipe(prompt)
                return {"embeddings": result}

        except Exception as e:
            logger.error(f"Error generating text with {model_name}: {str(e)}")
            return {"error": str(e)}

    def get_model_info(self):
        """Get information about available and loaded models."""
        return {
            "available_models": list(self.model_configs.keys()),
            "loaded_models": list(self.models.keys()),
            "model_types": self.model_configs
        }


# Initialize API
api = MultiModelAPI()

# Flask API
app = Flask(__name__)


@app.route('/api/models', methods=['GET'])
def get_models():
    """Get available models"""
    return jsonify(api.get_model_info())


@app.route('/api/load_model', methods=['POST'])
def load_model():
    """Load a specific model"""
    data = request.get_json(silent=True) or {}
    model_name = data.get('model_name')

    if not model_name:
        return jsonify({"error": "model_name is required"}), 400

    success = api.load_model(model_name)
    if success:
        return jsonify({"message": f"Model {model_name} loaded successfully"})
    else:
        return jsonify({"error": f"Failed to load model {model_name}"}), 500


@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate text using the specified model"""
    data = request.get_json(silent=True) or {}
    model_name = data.get('model_name')
    prompt = data.get('prompt')

    if not model_name or not prompt:
        return jsonify({"error": "model_name and prompt are required"}), 400

    # Extract generation parameters
    params = {
        'max_length': data.get('max_length', 100),
        'temperature': data.get('temperature', 0.7),
        'top_p': data.get('top_p', 0.9),
        'do_sample': data.get('do_sample', True)
    }

    result = api.generate_text(model_name, prompt, **params)
    return jsonify(result)


@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy", "loaded_models": len(api.models)})


# Gradio Interface
def gradio_interface():
    def generate_text_ui(model_name, prompt, max_length, temperature, top_p):
        if not model_name or not prompt:
            return "Please select a model and enter a prompt"

        params = {
            'max_length': int(max_length),
            'temperature': float(temperature),
            'top_p': float(top_p),
            'do_sample': True
        }

        result = api.generate_text(model_name, prompt, **params)

        if 'error' in result:
            return f"Error: {result['error']}"

        return result.get('generated_text', str(result))

    def load_model_ui(model_name):
        if not model_name:
            return "Please select a model"

        success = api.load_model(model_name)
        if success:
            return f"✅ Model {model_name} loaded successfully"
        else:
            return f"❌ Failed to load model {model_name}"

    with gr.Blocks(title="Multi-Model API Interface") as interface:
        gr.Markdown("# Multi-Model API Interface")
        gr.Markdown("Load and interact with multiple Hugging Face models")

        with gr.Tab("Model Management"):
            model_dropdown = gr.Dropdown(
                choices=list(api.model_configs.keys()),
                label="Select Model",
                value=None
            )
            load_btn = gr.Button("Load Model")
            load_status = gr.Textbox(label="Status", interactive=False)

            load_btn.click(
                load_model_ui,
                inputs=[model_dropdown],
                outputs=[load_status]
            )

        with gr.Tab("Text Generation"):
            with gr.Row():
                with gr.Column():
                    gen_model = gr.Dropdown(
                        choices=list(api.model_configs.keys()),
                        label="Model",
                        value=None
                    )
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3
                    )

                    with gr.Row():
                        max_length = gr.Slider(10, 500, value=100, label="Max Length")
                        temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
                        top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top P")

                    generate_btn = gr.Button("Generate")

                with gr.Column():
                    output_text = gr.Textbox(
                        label="Generated Text",
                        lines=10,
                        interactive=False
                    )

            generate_btn.click(
                generate_text_ui,
                inputs=[gen_model, prompt_input, max_length, temperature, top_p],
                outputs=[output_text]
            )

        with gr.Tab("API Documentation"):
            gr.Markdown("""
            ## API Endpoints

            ### GET /api/models
            Get the list of available and loaded models.

            ### POST /api/load_model
            Load a specific model.
            ```json
            {
                "model_name": "Lyon28/GPT-2"
            }
            ```

            ### POST /api/generate
            Generate text using a model.
            ```json
            {
                "model_name": "Lyon28/GPT-2",
                "prompt": "Hello world",
                "max_length": 100,
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": true
            }
            ```

            ### GET /health
            Health check endpoint.
            """)

    return interface


def run_flask():
    """Run the Flask API server"""
    app.run(host="0.0.0.0", port=5000, debug=False)


def main():
    """Run the Flask API and the Gradio UI together"""
    # Start Flask in a background thread
    flask_thread = threading.Thread(target=run_flask, daemon=True)
    flask_thread.start()

    # Give Flask time to start
    time.sleep(2)

    # Create and launch the Gradio interface
    interface = gradio_interface()

    # Launch Gradio on port 7860 (HF Spaces default)
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )


if __name__ == "__main__":
    main()
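
For quick testing, here is a minimal client sketch against the Flask endpoints defined above. It assumes the app is already running and the API is reachable at http://localhost:5000; the base URL is an assumption (on a hosted Space the internal Flask port may not be exposed), while the model name comes from the app's own config.

```python
# client_example.py — a minimal sketch; BASE_URL is an assumed local deployment.
import requests

BASE_URL = "http://localhost:5000"  # assumption: Flask is reachable locally

# Health check: confirms the server is up and reports the loaded-model count
print(requests.get(f"{BASE_URL}/health").json())

# List available and loaded models
print(requests.get(f"{BASE_URL}/api/models").json())

# Load a model explicitly (optional — /api/generate also loads on demand)
resp = requests.post(
    f"{BASE_URL}/api/load_model",
    json={"model_name": "Lyon28/GPT-2"},
)
print(resp.json())

# Generate text with the same parameters the API documents
resp = requests.post(
    f"{BASE_URL}/api/generate",
    json={
        "model_name": "Lyon28/GPT-2",
        "prompt": "Hello world",
        "max_length": 100,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
    },
)
print(resp.json().get("generated_text"))
```

The same calls can be made from the Gradio UI on port 7860; the client above is only useful when the Flask port itself is reachable.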