# smallagent / app.py
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Extra
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import uuid
import json
from typing import Optional, List, Union, Dict, Any
import asyncio
# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DEVICE = "cpu"
# --- Model loading ---
print(f"Loading model: {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Make sure the tokenizer has a padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("pad_token was set to eos_token.")
print("Model and tokenizer loaded successfully on the CPU.")
# --- API application setup ---
app = FastAPI()
# --- Data models (unchanged) ---
class ContentPart(BaseModel):
type: str
text: str
class ChatMessage(BaseModel):
role: str
content: Union[str, List[ContentPart]]
class ChatCompletionRequest(BaseModel):
model: Optional[str] = None
messages: List[ChatMessage]
stream: Optional[bool] = False
class Config:
extra = Extra.ignore
class ChatCompletionResponseChoice(BaseModel):
index: int = 0
message: ChatMessage
finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[ChatCompletionResponseChoice]
class ModelData(BaseModel):
id: str
object: str = "model"
owned_by: str = "user"
class ModelList(BaseModel):
object: str = "list"
data: List[ModelData]
# --- API endpoints ---
@app.get("/models", response_model=ModelList)
async def list_models():
"""Répond à la requête GET /models pour satisfaire l'extension."""
return ModelList(data=[ModelData(id=MODEL_ID)])
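# For illustration only: a request body the /chat/completions endpoint below is expected to accept.
# The shape follows the OpenAI chat-completions format; the values are example assumptions, not
# taken from the original file.
#
# {
#   "model": "Qwen/Qwen2.5-0.5B-Instruct",
#   "messages": [{"role": "user", "content": "Hello, who are you?"}],
#   "stream": true
# }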
@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""Endpoint principal qui gère la génération de texte en streaming."""
# --- LA CORRECTION EST ICI ---
# On convertit les messages de la requête en un format que le tokenizer peut utiliser.
# C'est plus simple et plus robuste que de chercher le prompt manuellement.
messages_for_model = [msg.dict() for msg in request.messages]
# On applique le template. Le tokenizer de Qwen sait comment gérer cette structure.
text_prompt = tokenizer.apply_chat_template(messages_for_model, tokenize=False, add_generation_prompt=True)
# On tokenize le texte pour obtenir explicitement input_ids ET attention_mask
inputs = tokenizer(text_prompt, return_tensors="pt", padding=True).to(DEVICE)
# On passe les inputs au modèle en utilisant ** pour déballer le dictionnaire
outputs = model.generate(**inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
response_text = tokenizer.decode(outputs[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    async def stream_generator():
        response_id = f"chatcmpl-{uuid.uuid4()}"
        for char in response_text:
            chunk = {
                "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {"content": char}, "finish_reason": None}],
            }
            yield f"data: {json.dumps(chunk)}\n\n"
            await asyncio.sleep(0.01)
        final_chunk = {
            "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        }
        yield f"data: {json.dumps(final_chunk)}\n\n"
        yield "data: [DONE]\n\n"
if request.stream:
return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        # Return a complete, non-streaming OpenAI-style response using the models defined above.
        choice = ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=response_text))
        return ChatCompletionResponse(id=f"chatcmpl-{uuid.uuid4()}", created=int(time.time()), model=MODEL_ID, choices=[choice])
@app.get("/")
def root():
return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}