from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Extra
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import uuid
import json
from typing import Optional, List, Union, Dict, Any
import asyncio
# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DEVICE = "cpu"

# --- Model loading ---
print(f"Loading model: {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Make sure the tokenizer has a padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("pad_token has been set to eos_token.")
print("Model and tokenizer loaded successfully on CPU.")
# --- API application setup ---
app = FastAPI()
# --- Data models ---
class ContentPart(BaseModel):
    type: str
    text: str

class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[ContentPart]]

class ChatCompletionRequest(BaseModel):
    model: Optional[str] = None
    messages: List[ChatMessage]
    stream: Optional[bool] = False

    class Config:
        extra = Extra.ignore  # silently ignore any extra fields sent by clients

class ChatCompletionResponseChoice(BaseModel):
    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"

class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]

class ModelData(BaseModel):
    id: str
    object: str = "model"
    owned_by: str = "user"

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelData]
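# For reference, a minimal request body accepted by these models (an
# illustrative example, not part of the original file):
#
#   {
#     "model": "Qwen/Qwen2.5-0.5B-Instruct",
#     "stream": true,
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
#     ]
#   }
#
# `content` may be either a plain string or a list of {type, text} parts,
# and unknown fields are dropped thanks to `Extra.ignore`.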
# --- API endpoints ---
@app.get("/models")
async def list_models():
    """Responds to GET /models so the client extension can discover the model."""
    return ModelList(data=[ModelData(id=MODEL_ID)])
@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Main endpoint handling text generation, with optional streaming."""
    # Convert the request messages into a format the tokenizer can use.
    # This is simpler and more robust than extracting the prompt manually.
    messages_for_model = [msg.dict() for msg in request.messages]

    # Apply the chat template; Qwen's tokenizer knows how to handle this structure.
    text_prompt = tokenizer.apply_chat_template(messages_for_model, tokenize=False, add_generation_prompt=True)

    # Tokenize the text to obtain both input_ids AND attention_mask explicitly.
    inputs = tokenizer(text_prompt, return_tensors="pt", padding=True).to(DEVICE)

    # Pass the inputs to the model, unpacking the dictionary with **.
    outputs = model.generate(
        **inputs,
        max_new_tokens=250,
        do_sample=True,
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, skipping the prompt.
    response_text = tokenizer.decode(outputs[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    async def stream_generator():
        response_id = f"chatcmpl-{uuid.uuid4()}"
        # Generation already finished above; this replays the final text
        # character by character as OpenAI-style SSE chunks.
        for char in response_text:
            chunk = {"id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID, "choices": [{"index": 0, "delta": {"content": char}, "finish_reason": None}]}
            yield f"data: {json.dumps(chunk)}\n\n"
            await asyncio.sleep(0.01)
        final_chunk = {"id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
        yield f"data: {json.dumps(final_chunk)}\n\n"
        yield "data: [DONE]\n\n"

    if request.stream:
        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        # Build a complete OpenAI-style response using the Pydantic models above.
        return ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4()}",
            created=int(time.time()),
            model=MODEL_ID,
            choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=response_text))],
        )
@app.get("/")
def root():
    return {"status": "OpenAI-compatible API online (with streaming)", "model_id": MODEL_ID}