Spaces:
Sleeping
Sleeping
File size: 4,274 Bytes
fb7cb35 e879476 c06bfc1 289cf5a c06bfc1 297c4f7 e879476 c016b94 c06bfc1 a4d044a 16578c7 289cf5a c06bfc1 857e3f5 a4d044a c06bfc1 e879476 c06bfc1 289cf5a c06bfc1 289cf5a c06bfc1 e879476 ee60681 a4d044a c06bfc1 49ce3ab e879476 a4d044a 16578c7 289cf5a 16578c7 289cf5a 16578c7 a4d044a c06bfc1 289cf5a c06bfc1 07257e8 d97d510 289cf5a c06bfc1 289cf5a 07257e8 289cf5a e879476 c06bfc1 857e3f5 a4d044a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Extra
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import uuid
import json
from typing import Optional, List, Union, Dict, Any
import asyncio
# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"  # Hugging Face hub id of the served chat model
DEVICE = "cpu"  # inference device; everything below assumes CPU-only

# --- Model loading (log messages are runtime strings, kept verbatim) ---
print(f"Début du chargement du modèle : {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # NOTE(review): bfloat16 on CPU — confirm the target CPU handles bf16 efficiently.
    torch_dtype=torch.bfloat16,
    device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Ensure the tokenizer has a padding token (required by the padding=True call
# made when tokenizing prompts in the /chat/completions endpoint).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("Le pad_token a été défini sur eos_token.")
print("Modèle et tokenizer chargés avec succès sur le CPU.")

# --- API application ---
app = FastAPI()
# --- Modèles de données (inchangés) ---
class ContentPart(BaseModel):
    """One part of a structured (OpenAI-style) message content, e.g. {"type": "text", "text": "..."}."""
    type: str  # part kind; only text parts are produced by the clients this serves
    text: str  # the text payload of the part
class ChatMessage(BaseModel):
    """A single chat turn, matching the OpenAI messages format."""
    role: str  # "system", "user" or "assistant"
    content: Union[str, List[ContentPart]]  # plain string or list of structured parts
class ChatCompletionRequest(BaseModel):
    """Incoming body of POST /chat/completions (OpenAI-compatible subset).

    Unknown fields sent by stock OpenAI clients (temperature, max_tokens, ...)
    are silently ignored instead of rejected, so those clients keep working.
    """
    model: Optional[str] = None  # requested model name; ignored — MODEL_ID is always served
    messages: List[ChatMessage]  # conversation history, oldest turn first
    stream: Optional[bool] = False  # True -> respond as a Server-Sent-Events stream

    class Config:
        # Plain string instead of the deprecated `pydantic.Extra` enum:
        # accepted by pydantic v1 and (with the legacy Config class) v2.
        extra = "ignore"
class ChatCompletionResponseChoice(BaseModel):
    """One generated alternative in a (non-streaming) chat completion response."""
    index: int = 0  # position of this choice in the choices list
    message: ChatMessage  # the generated assistant message
    finish_reason: str = "stop"  # why generation ended; this server always reports "stop"
class ChatCompletionResponse(BaseModel):
    """Top-level OpenAI-shaped response for a non-streaming completion."""
    id: str  # unique response id, e.g. "chatcmpl-<uuid>"
    object: str = "chat.completion"  # fixed OpenAI object tag
    created: int  # Unix timestamp of creation
    model: str  # model id that produced the completion
    choices: List[ChatCompletionResponseChoice]  # generated alternatives (this server emits one)
class ModelData(BaseModel):
    """Description of one available model, as returned by GET /models."""
    id: str  # model identifier (the hub id)
    object: str = "model"  # fixed OpenAI object tag
    owned_by: str = "user"  # ownership label expected by OpenAI clients
class ModelList(BaseModel):
    """OpenAI-shaped list wrapper for GET /models."""
    object: str = "list"  # fixed OpenAI object tag
    data: List[ModelData]  # the available models
# --- Définition des API ---
@app.get("/models", response_model=ModelList)
async def list_models():
    """GET /models — advertise the single served model so OpenAI-style clients accept this server."""
    available = [ModelData(id=MODEL_ID)]
    return ModelList(data=available)
@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """POST /chat/completions — OpenAI-compatible text generation.

    Builds a prompt from the request messages via the model's chat template,
    generates a completion, then returns it either as a single JSON payload
    or — when request.stream is true — as an SSE stream of
    chat.completion.chunk events followed by a [DONE] sentinel.
    """
    # Convert the pydantic messages into plain dicts for the chat template;
    # simpler and more robust than extracting the prompt text by hand.
    messages_for_model = [msg.dict() for msg in request.messages]

    # Apply the chat template; Qwen's tokenizer knows this message structure.
    text_prompt = tokenizer.apply_chat_template(
        messages_for_model, tokenize=False, add_generation_prompt=True
    )

    # Tokenize the text so we explicitly get input_ids AND attention_mask.
    inputs = tokenizer(text_prompt, return_tensors="pt", padding=True).to(DEVICE)

    # model.generate is CPU-bound and can run for seconds. The original called
    # it inline inside this async handler, which froze the asyncio event loop
    # (and every other request) for the whole generation. Running it in a
    # worker thread keeps the server responsive; the result is identical.
    outputs = await asyncio.to_thread(
        model.generate,
        **inputs,
        max_new_tokens=250,
        do_sample=True,
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Strip the prompt tokens; decode only the newly generated completion.
    response_text = tokenizer.decode(
        outputs[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True
    )

    async def stream_generator():
        """Yield the completion char-by-char as OpenAI-style SSE chunks."""
        response_id = f"chatcmpl-{uuid.uuid4()}"
        for char in response_text:
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [{"index": 0, "delta": {"content": char}, "finish_reason": None}],
            }
            yield f"data: {json.dumps(chunk)}\n\n"
            # Small pause so clients render the stream progressively.
            await asyncio.sleep(0.01)
        final_chunk = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": MODEL_ID,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        }
        yield f"data: {json.dumps(final_chunk)}\n\n"
        yield "data: [DONE]\n\n"

    if request.stream:
        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    # Non-streaming: minimal OpenAI-shaped payload (choices only).
    return {"choices": [{"message": {"role": "assistant", "content": response_text}}]}
@app.get("/")
def root():
    """Health-check endpoint: reports server status and the served model id."""
    status_message = "API compatible OpenAI en ligne (avec streaming)"
    return {"status": status_message, "model_id": MODEL_ID}