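"""Multi-model inference service.

Exposes a small set of Hugging Face models through a Flask JSON API
(port 5000) and a Gradio UI (port 7860), loading models on demand.
"""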
import logging
import threading
import time

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel,
    pipeline,
    T5ForConditionalGeneration
)
import gradio as gr
from flask import Flask, request, jsonify

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MultiModelAPI:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.pipelines = {}
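        # Maps each repo id to the pipeline task used for it;
        # unlisted models default to 'causal-lm' in load_model().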
        self.model_configs = {
            'Lyon28/Tinny-Llama': 'causal-lm',
            'Lyon28/Pythia': 'causal-lm', 
            'Lyon28/Bert-Tinny': 'feature-extraction',
            'Lyon28/Albert-Base-V2': 'feature-extraction',
            'Lyon28/T5-Small': 'text2text-generation',
            'Lyon28/GPT-2': 'causal-lm',
            'Lyon28/GPT-Neo': 'causal-lm',
            'Lyon28/Distilbert-Base-Uncased': 'feature-extraction',
            'Lyon28/Distil_GPT-2': 'causal-lm',
            'Lyon28/GPT-2-Tinny': 'causal-lm',
            'Lyon28/Electra-Small': 'feature-extraction'
        }
        
    def load_model(self, model_name: str):
        """Load a specific model"""
        try:
            logger.info(f"Loading model: {model_name}")
            
            if model_name in self.models:
                logger.info(f"Model {model_name} already loaded")
                return True
                
            model_type = self.model_configs.get(model_name, 'causal-lm')
            
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name, 
                trust_remote_code=True,
                cache_dir="/app/cache"
            )
            
            # Add a pad token if the tokenizer doesn't define one
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            # Load model based on type
            if model_type == 'causal-lm':
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None,
                    cache_dir="/app/cache"
                )
                # Create the pipeline. When device_map="auto" is used, accelerate
                # has already placed the model, so passing a device here would
                # conflict with that placement.
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer
                )
                
            elif model_type == 'text2text-generation':
                model = T5ForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "text2text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1
                )
                
            else:  # feature-extraction or other BERT-like models
                model = AutoModel.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "feature-extraction",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1
                )
            
            self.models[model_name] = model
            self.tokenizers[model_name] = tokenizer
            self.pipelines[model_name] = pipe
            
            logger.info(f"Successfully loaded model: {model_name}")
            return True
            
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            return False
    
    def generate_text(self, model_name: str, prompt: str, **kwargs):
        """Generate text using specified model"""
        try:
            if model_name not in self.pipelines:
                if not self.load_model(model_name):
                    return {"error": f"Failed to load model {model_name}"}
            
            pipe = self.pipelines[model_name]
            model_type = self.model_configs.get(model_name, 'causal-lm')
            
            # Set default parameters
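            # Note: max_length counts the prompt tokens too, so long prompts
            # leave less room for newly generated text.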
            max_length = kwargs.get('max_length', 100)
            temperature = kwargs.get('temperature', 0.7)
            top_p = kwargs.get('top_p', 0.9)
            do_sample = kwargs.get('do_sample', True)
            
            if model_type == 'causal-lm':
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    pad_token_id=pipe.tokenizer.eos_token_id
                )
                return {"generated_text": result[0]['generated_text']}
                
            elif model_type == 'text2text-generation':
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=do_sample
                )
                return {"generated_text": result[0]['generated_text']}
                
            else:  # feature extraction
                result = pipe(prompt)
                return {"embeddings": result}
                
        except Exception as e:
            logger.error(f"Error generating text with {model_name}: {str(e)}")
            return {"error": str(e)}
    
    def get_model_info(self):
        """Get information about loaded models"""
        return {
            "available_models": list(self.model_configs.keys()),
            "loaded_models": list(self.models.keys()),
            "model_types": self.model_configs
        }

# Initialize API
api = MultiModelAPI()

# Flask API
app = Flask(__name__)

@app.route('/api/models', methods=['GET'])
def get_models():
    """Get available models"""
    return jsonify(api.get_model_info())

@app.route('/api/load_model', methods=['POST'])
def load_model():
    """Load a specific model"""
    data = request.get_json(silent=True) or {}
    model_name = data.get('model_name')
    
    if not model_name:
        return jsonify({"error": "model_name is required"}), 400
    
    success = api.load_model(model_name)
    if success:
        return jsonify({"message": f"Model {model_name} loaded successfully"})
    else:
        return jsonify({"error": f"Failed to load model {model_name}"}), 500

@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate text using specified model"""
    data = request.json
    model_name = data.get('model_name')
    prompt = data.get('prompt')
    
    if not model_name or not prompt:
        return jsonify({"error": "model_name and prompt are required"}), 400
    
    # Extract generation parameters
    params = {
        'max_length': data.get('max_length', 100),
        'temperature': data.get('temperature', 0.7),
        'top_p': data.get('top_p', 0.9),
        'do_sample': data.get('do_sample', True)
    }
    
    result = api.generate_text(model_name, prompt, **params)
    return jsonify(result)

@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy", "loaded_models": len(api.models)})

# Gradio Interface
def gradio_interface():
    def generate_text_ui(model_name, prompt, max_length, temperature, top_p):
        if not model_name or not prompt:
            return "Please select a model and enter a prompt"
        
        params = {
            'max_length': int(max_length),
            'temperature': float(temperature),
            'top_p': float(top_p),
            'do_sample': True
        }
        
        result = api.generate_text(model_name, prompt, **params)
        
        if 'error' in result:
            return f"Error: {result['error']}"
        
        return result.get('generated_text', str(result))
    
    def load_model_ui(model_name):
        if not model_name:
            return "Please select a model"
        
        success = api.load_model(model_name)
        if success:
            return f"✅ Model {model_name} loaded successfully"
        else:
            return f"❌ Failed to load model {model_name}"
    
    with gr.Blocks(title="Multi-Model API Interface") as interface:
        gr.Markdown("# Multi-Model API Interface")
        gr.Markdown("Load and interact with multiple Hugging Face models")
        
        with gr.Tab("Model Management"):
            model_dropdown = gr.Dropdown(
                choices=list(api.model_configs.keys()),
                label="Select Model",
                value=None
            )
            load_btn = gr.Button("Load Model")
            load_status = gr.Textbox(label="Status", interactive=False)
            
            load_btn.click(
                load_model_ui,
                inputs=[model_dropdown],
                outputs=[load_status]
            )
        
        with gr.Tab("Text Generation"):
            with gr.Row():
                with gr.Column():
                    gen_model = gr.Dropdown(
                        choices=list(api.model_configs.keys()),
                        label="Model",
                        value=None
                    )
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3
                    )
                    
                    with gr.Row():
                        max_length = gr.Slider(10, 500, value=100, label="Max Length")
                        temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
                        top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top P")
                    
                    generate_btn = gr.Button("Generate")
                
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Generated Text",
                        lines=10,
                        interactive=False
                    )
            
            generate_btn.click(
                generate_text_ui,
                inputs=[gen_model, prompt_input, max_length, temperature, top_p],
                outputs=[output_text]
            )
        
        with gr.Tab("API Documentation"):
            gr.Markdown("""
            ## API Endpoints
            
            ### GET /api/models
            Get list of available and loaded models
            
            ### POST /api/load_model
            Load a specific model
            ```json
            {
                "model_name": "Lyon28/GPT-2"
            }
            ```
            
            ### POST /api/generate
            Generate text using a model
            ```json
            {
                "model_name": "Lyon28/GPT-2",
                "prompt": "Hello world",
                "max_length": 100,
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": true
            }
            ```
            
            ### GET /health
            Health check endpoint
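            
            ### Example client (Python)
            A minimal sketch using `requests`, assuming the API is reachable
            at `http://localhost:5000` (the port this file binds):
            ```python
            import requests

            resp = requests.post(
                "http://localhost:5000/api/generate",
                json={"model_name": "Lyon28/GPT-2", "prompt": "Hello world", "max_length": 50}
            )
            print(resp.json().get("generated_text"))
            ```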
            """)
    
    return interface

def run_flask():
    """Run Flask API server"""
    app.run(host="0.0.0.0", port=5000, debug=False)

def main():
    """Main function to run both Flask and Gradio"""
    # Start Flask in a separate thread
    flask_thread = threading.Thread(target=run_flask, daemon=True)
    flask_thread.start()
    
    # Give Flask time to start
    time.sleep(2)
    
    # Create and launch Gradio interface
    interface = gradio_interface()
    
    # Launch Gradio on port 7860 (HF Spaces default)
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )

if __name__ == "__main__":
    main()