smallagent / app.py
EnzGamers's picture
Update app.py
ee9fb8a verified
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Extra
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import uuid
import json
from typing import Optional, List, Union, Dict, Any
# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
DEVICE = "cpu"
# Qwen/Qwen3-1.7B
# deepseek-ai/deepseek-coder-1.3b-instruct
# Qwen/Qwen2.5-Coder-0.5B-Instruct
# --- Chargement du modèle ---
print(f"Début du chargement du modèle : {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("Modèle et tokenizer chargés avec succès sur le CPU.")
# --- Création de l'application API ---
app = FastAPI()
# --- Modèles de données pour accepter la structure complexe de l'extension ---
class ContentPart(BaseModel):
type: str
text: str
class ChatMessage(BaseModel):
role: str
content: Union[str, List[ContentPart]]
class ChatCompletionRequest(BaseModel):
model: Optional[str] = None
messages: List[ChatMessage]
stream: Optional[bool] = False
class Config:
extra = Extra.ignore
class ModelData(BaseModel):
id: str
object: str = "model"
owned_by: str = "user"
class ModelList(BaseModel):
object: str = "list"
data: List[ModelData]
# --- Définition des API ---
@app.get("/models", response_model=ModelList)
async def list_models():
"""Répond à la requête GET /models pour satisfaire l'extension."""
return ModelList(data=[ModelData(id=MODEL_ID)])
@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""Endpoint principal qui gère la génération de texte en streaming."""
# On extrait le prompt de l'utilisateur de la structure complexe
user_prompt = ""
last_message = request.messages[-1]
if isinstance(last_message.content, list):
for part in last_message.content:
if part.type == 'text':
user_prompt += part.text + "\n"
elif isinstance(last_message.content, str):
user_prompt = last_message.content
if not user_prompt:
return {"error": "Prompt non trouvé."}
# Préparation pour le modèle DeepSeek
messages_for_model = [{'role': 'user', 'content': user_prompt}]
inputs = tokenizer.apply_chat_template(messages_for_model, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
# Génération de la réponse complète
outputs = model.generate(inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
# Fonction génératrice pour le streaming
async def stream_generator():
response_id = f"chatcmpl-{uuid.uuid4()}"
# On envoie la réponse caractère par caractère, au format attendu
for char in response_text:
chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [{
"index": 0,
"delta": {"content": char},
"finish_reason": None
}]
}
yield f"data: {json.dumps(chunk)}\n\n"
await asyncio.sleep(0.01) # Petite pause pour simuler un flux
# On envoie le chunk final de fin
final_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
}
yield f"data: {json.dumps(final_chunk)}\n\n"
# On envoie le signal [DONE]
yield "data: [DONE]\n\n"
# Si l'extension demande un stream, on renvoie le générateur
if request.stream:
return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
# Code de secours si le stream n'est pas demandé (peu probable)
return {"choices": [{"message": {"role": "assistant", "content": response_text}}]}
@app.get("/")
def root():
return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}
# On a besoin de asyncio pour la pause dans le stream
import asyncio
@app.get("/")
def root():
return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}
# On a besoin de asyncio pour la pause dans le stream
import asyncio