# Spaces: Running
import os | |
import uvicorn | |
import asyncio | |
from concurrent.futures import ThreadPoolExecutor | |
from fastapi import FastAPI, HTTPException, BackgroundTasks | |
from fastapi.responses import HTMLResponse | |
from pydantic import BaseModel | |
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
import torch | |
from typing import Optional, Dict | |
import time | |
import logging | |
# Configure root logging at INFO and obtain a logger named after this module
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Initialize the FastAPI application
app = FastAPI(title="LyonPoy AI Chat - Optimized")
# Optimized model configuration - prioritize smaller, faster models | |
MODELS = { | |
"distil-gpt-2": { | |
"name": "DistilGPT-2", | |
"model_path": "Lyon28/Distil_GPT-2", | |
"task": "text-generation", | |
"priority": 1 # Highest priority - smallest model | |
}, | |
"gpt-2-tinny": { | |
"name": "GPT-2 Tinny", | |
"model_path": "Lyon28/GPT-2-Tinny", | |
"task": "text-generation", | |
"priority": 2 | |
}, | |
"tinny-llama": { | |
"name": "Tinny Llama", | |
"model_path": "Lyon28/Tinny-Llama", | |
"task": "text-generation", | |
"priority": 3 | |
}, | |
"gpt-2": { | |
"name": "GPT-2", | |
"model_path": "Lyon28/GPT-2", | |
"task": "text-generation", | |
"priority": 4 | |
}, | |
"bert-tinny": { | |
"name": "BERT Tinny", | |
"model_path": "Lyon28/Bert-Tinny", | |
"task": "text-classification", | |
"priority": 5 | |
}, | |
"albert-base-v2": { | |
"name": "ALBERT Base V2", | |
"model_path": "Lyon28/Albert-Base-V2", | |
"task": "text-classification", | |
"priority": 6 | |
}, | |
"distilbert-base-uncased": { | |
"name": "DistilBERT", | |
"model_path": "Lyon28/Distilbert-Base-Uncased", | |
"task": "text-classification", | |
"priority": 7 | |
}, | |
"electra-small": { | |
"name": "ELECTRA Small", | |
"model_path": "Lyon28/Electra-Small", | |
"task": "text-classification", | |
"priority": 8 | |
}, | |
"t5-small": { | |
"name": "T5 Small", | |
"model_path": "Lyon28/T5-Small", | |
"task": "text2text-generation", | |
"priority": 9 | |
}, | |
"pythia": { | |
"name": "Pythia", | |
"model_path": "Lyon28/Pythia", | |
"task": "text-generation", | |
"priority": 10 | |
}, | |
"gpt-neo": { | |
"name": "GPT-Neo", | |
"model_path": "Lyon28/GPT-Neo", | |
"task": "text-generation", | |
"priority": 11 # Largest model - lowest priority | |
} | |
} | |
class ChatRequest(BaseModel):
    """JSON body accepted by the chat endpoint."""

    # Text the user wants the model to respond to
    message: str
    # Key into MODELS; omitted -> the fastest model
    model: Optional[str] = "distil-gpt-2"
# Shared runtime state stored on the app object:
#   pipelines      - model_id -> loaded transformers pipeline
#   loading_models - ids whose load is currently in progress
#   executor       - worker threads that run blocking load/inference calls
app.state.pipelines = dict()
app.state.loading_models = set()
app.state.executor = ThreadPoolExecutor(max_workers=2)
# Optimized model loading
# Completion futures for loads currently in flight, keyed by model id, so a
# second caller can wait for an ongoing load instead of being told it failed.
_pending_loads: Dict[str, asyncio.Future] = {}


async def load_model_async(model_id: str) -> bool:
    """Load ``MODELS[model_id]`` into ``app.state.pipelines`` off the event loop.

    The blocking ``transformers.pipeline`` construction runs on
    ``app.state.executor``. While it is in progress the id is listed in
    ``app.state.loading_models``.

    Args:
        model_id: Key into ``MODELS``.

    Returns:
        True if the pipeline is available afterwards: it was already loaded,
        this call loaded it, or a concurrent load this call waited on
        succeeded. False if loading failed or ``model_id`` is unknown. Errors
        are logged, not raised.
    """
    if model_id in app.state.pipelines:
        # Already loaded: don't build (and replace) a second pipeline
        return True

    pending = _pending_loads.get(model_id)
    if pending is not None:
        # Another coroutine is loading this model; share its outcome rather
        # than reporting failure. shield() keeps our own cancellation from
        # cancelling the shared future.
        return await asyncio.shield(pending)

    loop = asyncio.get_running_loop()
    done = loop.create_future()
    _pending_loads[model_id] = done
    app.state.loading_models.add(model_id)
    success = False
    try:
        model_config = MODELS[model_id]
        logger.info(f"π Loading {model_config['name']}...")

        def load_model():
            use_cuda = torch.cuda.is_available()
            extra_kwargs = {}
            if "gpt" in model_id:
                # GPT-family ids get an explicit pad id (presumably GPT-2's
                # end-of-text token); other models keep their own defaults.
                extra_kwargs["pad_token_id"] = 50256
            return pipeline(
                task=model_config["task"],
                model=model_config["model_path"],
                device=0 if use_cuda else -1,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                use_fast=True,
                trust_remote_code=True,
                # low_cpu_mem_usage is a from_pretrained() option; pipeline()
                # forwards those through model_kwargs.
                model_kwargs={"low_cpu_mem_usage": True},
                **extra_kwargs,
            )

        # Load in a worker thread to avoid blocking the event loop
        pipeline_obj = await loop.run_in_executor(app.state.executor, load_model)
        app.state.pipelines[model_id] = pipeline_obj
        logger.info(f"β {model_config['name']} loaded successfully")
        success = True
    except Exception as e:
        logger.error(f"β Failed to load {model_id}: {e}")
    finally:
        app.state.loading_models.discard(model_id)
        _pending_loads.pop(model_id, None)
        # Wake any callers waiting on this load
        done.set_result(success)
    return success
@app.on_event("startup")
async def load_models():
    """Pre-load the three highest-priority models when the app starts.

    Registered as a FastAPI startup handler. The loads run concurrently
    through ``load_model_async``, which logs and swallows its own load
    failures, so one bad model does not abort startup.
    """
    # NOTE(review): huggingface_hub/transformers appear to read HF_HOME when
    # they are first imported, and transformers is imported at the top of this
    # module, so assigning it here may not move the cache. Consider exporting
    # HF_HOME before the imports (or in the Space settings) - confirm.
    os.environ['HF_HOME'] = './cache/huggingface'  # Persistent cache
    os.makedirs(os.environ['HF_HOME'], exist_ok=True)

    # Pre-load top 3 fastest models (lowest priority number first)
    priority_models = sorted(MODELS, key=lambda m: MODELS[m]['priority'])[:3]

    # Load models concurrently
    await asyncio.gather(
        *(load_model_async(model_id) for model_id in priority_models),
        return_exceptions=True,
    )
    logger.info("π LyonPoy AI Chat Ready!")
# Optimized inference
async def run_inference(model_id: str, message: str) -> str:
    """Run ``message`` through the pipeline for ``model_id`` and return text.

    The model is loaded on demand if it is not already in
    ``app.state.pipelines``. The blocking pipeline call runs on
    ``app.state.executor`` so the event loop stays responsive.

    Args:
        model_id: Key into ``MODELS``.
        message: User prompt.

    Returns:
        - text-generation: the continuation with the echoed prompt removed,
          capped at 200 characters (plus "...").
        - text-classification: "Analisis: <label> (Keyakinan: <score>)".
        - text2text-generation: the generated text.

    Raises:
        HTTPException: 503 if the model could not be loaded.
        ValueError: if the model's task is not one of the three above.
        Exception: any pipeline error is logged and re-raised unchanged.
    """
    if model_id not in app.state.pipelines:
        # Try to load model if not available
        success = await load_model_async(model_id)
        if not success:
            raise HTTPException(status_code=503, detail=f"Model {model_id} unavailable")

    pipe = app.state.pipelines[model_id]
    model_config = MODELS[model_id]
    task = model_config["task"]
    loop = asyncio.get_running_loop()

    def inference():
        start_time = time.perf_counter()
        try:
            if task == "text-generation":
                eos_id = getattr(pipe.tokenizer, 'eos_token_id', None)
                # Shorter budget for longer prompts, but never below one token
                max_new = max(1, min(50, 150 - len(message.split())))
                result = pipe(
                    message,
                    max_new_tokens=max_new,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    repetition_penalty=1.1,
                    # Fall back to 50256 when the tokenizer has no EOS id
                    pad_token_id=eos_id if eos_id is not None else 50256,
                )[0]['generated_text']
                # Drop the echoed prompt, if present
                if result.startswith(message):
                    result = result[len(message):].strip()
                # Limit response length
                if len(result) > 200:
                    result = result[:200] + "..."
            elif task == "text-classification":
                output = pipe(message)[0]
                result = f"Analisis: {output['label']} (Keyakinan: {output['score']:.2f})"
            elif task == "text2text-generation":
                result = pipe(message, max_length=100, num_beams=2)[0]['generated_text']
            else:
                # Without this, `result` would be unbound below
                raise ValueError(f"Unsupported task {task!r} for model {model_id}")
            inference_time = time.perf_counter() - start_time
            logger.info(f"β‘ Inference time: {inference_time:.2f}s for {model_config['name']}")
            return result
        except Exception as e:
            logger.error(f"Inference error: {e}")
            raise

    return await loop.run_in_executor(app.state.executor, inference)
# Frontend route - simplified HTML
@app.get("/", response_class=HTMLResponse)
async def get_frontend():
    """Serve the single-page chat UI.

    The page POSTs ``{"message", "model"}`` JSON to ``/chat`` and renders the
    reply (or an error line) together with the client-side round-trip time.
    """
    html_content = '''
<!DOCTYPE html>
<html lang="id">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>LyonPoy AI Chat - Fast Mode</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body { font-family: system-ui; background: #f5f5f5; padding: 20px; }
        .container { max-width: 600px; margin: 0 auto; background: white; border-radius: 10px; overflow: hidden; }
        .header { background: #007bff; color: white; padding: 15px; }
        .chat { height: 400px; overflow-y: auto; padding: 15px; background: #fafafa; }
        .message { margin: 10px 0; padding: 8px 12px; border-radius: 8px; }
        .user { background: #007bff; color: white; margin-left: 20%; }
        .bot { background: white; border: 1px solid #ddd; margin-right: 20%; }
        .input-area { padding: 15px; display: flex; gap: 10px; }
        input { flex: 1; padding: 10px; border: 1px solid #ddd; border-radius: 5px; }
        button { padding: 10px 15px; background: #007bff; color: white; border: none; border-radius: 5px; cursor: pointer; }
        select { padding: 5px; margin-left: 10px; }
        .loading { color: #666; font-style: italic; }
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>π LyonPoy AI - Fast Mode</h1>
            <select id="model">
                <option value="distil-gpt-2">DistilGPT-2 (Fastest)</option>
                <option value="gpt-2-tinny">GPT-2 Tinny</option>
                <option value="tinny-llama">Tinny Llama</option>
                <option value="gpt-2">GPT-2</option>
                <option value="bert-tinny">BERT Tinny</option>
                <option value="albert-base-v2">ALBERT Base V2</option>
                <option value="distilbert-base-uncased">DistilBERT</option>
                <option value="electra-small">ELECTRA Small</option>
                <option value="t5-small">T5 Small</option>
                <option value="pythia">Pythia</option>
                <option value="gpt-neo">GPT-Neo (Slowest)</option>
            </select>
        </div>
        <div class="chat" id="chat"></div>
        <div class="input-area">
            <input type="text" id="message" placeholder="Ketik pesan..." maxlength="200">
            <button onclick="sendMessage()">Kirim</button>
        </div>
    </div>
    <script>
        const chat = document.getElementById('chat');
        const messageInput = document.getElementById('message');
        const modelSelect = document.getElementById('model');
        function addMessage(content, isUser = false) {
            const div = document.createElement('div');
            div.className = `message ${isUser ? 'user' : 'bot'}`;
            div.textContent = content;
            chat.appendChild(div);
            chat.scrollTop = chat.scrollHeight;
            return div;
        }
        async function sendMessage() {
            const message = messageInput.value.trim();
            if (!message) return;
            addMessage(message, true);
            messageInput.value = '';
            // Keep a handle on this request's placeholder: more messages may be
            // appended before the reply arrives, so lastElementChild is unreliable.
            const loading = addMessage('β³ Thinking...', false);
            const startTime = Date.now();
            try {
                const response = await fetch('/chat', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        message: message,
                        model: modelSelect.value
                    })
                });
                const data = await response.json();
                const responseTime = ((Date.now() - startTime) / 1000).toFixed(1);
                // Remove loading message
                loading.remove();
                if (data.status === 'success') {
                    addMessage(`${data.response} (${responseTime}s)`, false);
                } else {
                    addMessage('β Error occurred', false);
                }
            } catch (error) {
                loading.remove();
                addMessage('β Connection error', false);
            }
        }
        messageInput.addEventListener('keypress', (e) => {
            if (e.key === 'Enter') sendMessage();
        });
        // Show welcome message
        addMessage('π Halo! Pilih model dan mulai chat. Model DistilGPT-2 paling cepat!', false);
    </script>
</body>
</html>
'''
    return HTMLResponse(content=html_content)
# Optimized chat endpoint
@app.post("/chat")
async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
    """Answer one chat message with the requested model.

    Served at ``POST /chat``, the path the bundled frontend calls.

    Args:
        request: Parsed body. ``model`` may be omitted or null; either way the
            default ``distil-gpt-2`` is used.
        background_tasks: Used to schedule ``preload_next_model`` after a
            successful reply.

    Returns:
        ``{"response": str, "model": <display name>, "status": "success"}``.

    Raises:
        HTTPException: 400 for an unknown model or an empty message, 503 if the
            model cannot be loaded, 500 for any other inference failure.
    """
    # `model` is Optional, so an explicit null must not reach .lower()
    model_id = (request.model or "distil-gpt-2").lower()
    if model_id not in MODELS:
        raise HTTPException(status_code=400, detail="Model tidak tersedia")

    # Limit message length for faster processing
    message = request.message[:200]  # Max 200 chars
    if not message.strip():
        # Empty prompts would only fail later inside the pipeline
        raise HTTPException(status_code=400, detail="Pesan tidak boleh kosong")

    try:
        result = await run_inference(model_id, message)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail="Terjadi kesalahan") from e

    # Load next priority model in background
    background_tasks.add_task(preload_next_model, model_id)
    return {
        "response": result,
        "model": MODELS[model_id]["name"],
        "status": "success",
    }
async def preload_next_model(current_model: str):
    """Warm up the highest-priority model that is neither loaded nor loading.

    ``current_model`` is accepted but not consulted. The choice depends only
    on ``MODELS`` priorities and the current pipeline/loading state. At most
    one model is loaded per call, and errors are logged rather than raised.

    NOTE(review): each call loads another unloaded model, so repeated calls
    eventually bring every entry in MODELS into memory - confirm that is
    acceptable for the host.
    """
    try:
        already_loaded = set(app.state.pipelines)
        by_priority = sorted(MODELS, key=lambda name: MODELS[name]['priority'])
        candidate = next(
            (
                name
                for name in by_priority
                if name not in already_loaded and name not in app.state.loading_models
            ),
            None,
        )
        if candidate is not None:
            await load_model_async(candidate)
    except Exception as e:
        logger.error(f"Background loading error: {e}")
# Health check with model status
@app.get("/health")
async def health():
    """Report liveness, GPU availability and model load state.

    Returns:
        ``status`` (always "healthy"), ``gpu`` (whether torch sees CUDA), and
        ``loaded_models`` / ``loading_models`` as lists of model ids.
    """
    return {
        "status": "healthy",
        "gpu": torch.cuda.is_available(),
        "loaded_models": list(app.state.pipelines),
        "loading_models": list(app.state.loading_models),
    }
# Model status endpoint
@app.get("/models")
async def get_models():
    """Return display name, load state and priority for every model, keyed by id."""
    return {
        model_id: {
            "name": config["name"],
            "loaded": model_id in app.state.pipelines,
            "loading": model_id in app.state.loading_models,
            "priority": config["priority"],
        }
        for model_id, config in MODELS.items()
    }
# Cleanup on shutdown
@app.on_event("shutdown")
async def cleanup():
    """Shut down the shared worker thread pool when the app stops.

    ``wait=True`` blocks until any in-flight load/inference jobs finish.
    """
    app.state.executor.shutdown(wait=True)
if __name__ == "__main__":
    # Listen on all interfaces; the PORT environment variable overrides 7860.
    server_options = {
        "host": "0.0.0.0",
        "port": int(os.environ.get("PORT", 7860)),
        "log_level": "info",
        # Per-request access logging is off (original rationale: performance)
        "access_log": False,
    }
    uvicorn.run(app, **server_options)