smallagent / app.py
EnzGamers's picture
Update app.py
a8c2b2b verified
raw
history blame
5.66 kB
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Extra
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import uuid
import json
from typing import Optional, List, Union, Dict, Any
# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DEVICE = "cpu"
# --- Chargement du modèle ---
print(f"Début du chargement du modèle : {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("Modèle et tokenizer chargés avec succès sur le CPU.")
# --- Création de l'application API ---
app = FastAPI()
# --- Modèles de données pour accepter la structure complexe de l'extension ---
class ContentPart(BaseModel):
type: str
text: str
class ChatMessage(BaseModel):
role: str
content: Union[str, List[ContentPart]]
class ChatCompletionRequest(BaseModel):
model: Optional[str] = None
messages: List[ChatMessage]
stream: Optional[bool] = False
class Config:
extra = Extra.ignore
class ModelData(BaseModel):
id: str
object: str = "model"
owned_by: str = "user"
class ModelList(BaseModel):
object: str = "list"
data: List[ModelData]
# --- Définition des API ---
@app.get("/models", response_model=ModelList)
async def list_models():
"""Répond à la requête GET /models pour satisfaire l'extension."""
return ModelList(data=[ModelData(id=MODEL_ID)])
@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""Endpoint principal qui gère la génération de texte en streaming."""
# On extrait le prompt de l'utilisateur de la structure complexe
user_prompt = ""
last_message = request.messages[-1]
if isinstance(last_message.content, list):
for part in last_message.content:
if part.type == 'text':
user_prompt += part.text + "\n"
elif isinstance(last_message.content, str):
user_prompt = last_message.content
if not user_prompt:
return {"error": "Prompt non trouvé."}
# Préparation pour le modèle DeepSeek
messages_for_model = [{'role': 'user', 'content': user_prompt}]
inputs = tokenizer.apply_chat_template(messages_for_model, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
# Génération de la réponse complète
outputs = model.generate(inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
# Fonction génératrice pour le streaming
async def stream_generator():
response_id = f"chatcmpl-{uuid.uuid4()}"
# On envoie la réponse caractère par caractère, au format attendu
for char in response_text:
chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [{
"index": 0,
"delta": {"content": char},
"finish_reason": None
}]
}
yield f"data: {json.dumps(chunk)}\n\n"
await asyncio.sleep(0.01) # Petite pause pour simuler un flux
# On envoie le chunk final de fin
final_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
}
yield f"data: {json.dumps(final_chunk)}\n\n"
# On envoie le signal [DONE]
yield "data: [DONE]\n\n"
# Si l'extension demande un stream, on renvoie le générateur
if request.stream:
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
# Code de secours si le stream n'est pas demandé (peu probable)
return {"choices": [{"message": {"role": "assistant", "content": response_text}}]}
# ... (tout votre code existant reste inchangé) ...
# Fonction génératrice pour le streaming
async def stream_generator():
# ... (le contenu de cette fonction ne change pas) ...
# Si l'extension demande un stream, on renvoie le générateur
if request.stream:
# ... (cette partie ne change pas) ...
# ===============================================================
# AJOUTEZ LE CODE CI-DESSOUS
# ===============================================================
@app.post("/spend/calculate")
async def spend_calculate():
"""
Endpoint factice pour satisfaire le client qui essaie de calculer les coûts.
Ne fait rien et renvoie une réponse de succès vide.
"""
return {} # Renvoie un JSON vide avec un statut 200 OK par défaut
# ===============================================================
# FIN DE L'AJOUT
# ===============================================================
@app.get("/")
def root():
return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}
# On a besoin de asyncio pour la pause dans le stream
import asyncio
@app.get("/")
def root():
return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}
# On a besoin de asyncio pour la pause dans le stream
import asyncio