File size: 4,274 Bytes
fb7cb35
e879476
 
 
 
c06bfc1
 
 
 
289cf5a
c06bfc1
 
297c4f7
e879476
c016b94
c06bfc1
 
 
 
 
 
 
 
a4d044a
16578c7
289cf5a
 
 
 
c06bfc1
 
 
857e3f5
 
a4d044a
c06bfc1
 
 
 
 
 
 
 
e879476
 
 
 
c06bfc1
 
 
 
289cf5a
 
 
 
 
 
 
 
 
 
 
 
c06bfc1
 
 
 
 
 
 
 
289cf5a
c06bfc1
e879476
 
ee60681
a4d044a
c06bfc1
49ce3ab
 
e879476
a4d044a
 
16578c7
 
 
 
 
 
289cf5a
16578c7
 
289cf5a
 
16578c7
a4d044a
c06bfc1
289cf5a
c06bfc1
 
07257e8
d97d510
 
289cf5a
c06bfc1
289cf5a
07257e8
289cf5a
e879476
 
 
c06bfc1
 
 
 
857e3f5
 
 
a4d044a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Extra
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import uuid
import json
from typing import Optional, List, Union, Dict, Any
import asyncio

# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
DEVICE = "cpu"

# --- Chargement du modèle ---
print(f"Début du chargement du modèle : {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# On s'assure que le tokenizer a un token de padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("Le pad_token a été défini sur eos_token.")

print("Modèle et tokenizer chargés avec succès sur le CPU.")

# --- Création de l'application API ---
app = FastAPI()

# --- Modèles de données (inchangés) ---
class ContentPart(BaseModel):
    type: str
    text: str

class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[ContentPart]]

class ChatCompletionRequest(BaseModel):
    model: Optional[str] = None
    messages: List[ChatMessage]
    stream: Optional[bool] = False
    
    class Config:
        extra = Extra.ignore

class ChatCompletionResponseChoice(BaseModel):
    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"

class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]

class ModelData(BaseModel):
    id: str
    object: str = "model"
    owned_by: str = "user"

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelData]
    
# --- Définition des API ---

@app.get("/models", response_model=ModelList)
async def list_models():
    """Répond à la requête GET /models pour satisfaire l'extension."""
    return ModelList(data=[ModelData(id=MODEL_ID)])

@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Endpoint principal qui gère la génération de texte en streaming."""
    
    # --- LA CORRECTION EST ICI ---
    # On convertit les messages de la requête en un format que le tokenizer peut utiliser.
    # C'est plus simple et plus robuste que de chercher le prompt manuellement.
    messages_for_model = [msg.dict() for msg in request.messages]

    # On applique le template. Le tokenizer de Qwen sait comment gérer cette structure.
    text_prompt = tokenizer.apply_chat_template(messages_for_model, tokenize=False, add_generation_prompt=True)
    
    # On tokenize le texte pour obtenir explicitement input_ids ET attention_mask
    inputs = tokenizer(text_prompt, return_tensors="pt", padding=True).to(DEVICE)
    
    # On passe les inputs au modèle en utilisant ** pour déballer le dictionnaire
    outputs = model.generate(**inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
    
    response_text = tokenizer.decode(outputs[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    async def stream_generator():
        response_id = f"chatcmpl-{uuid.uuid4()}"
        
        for char in response_text:
            chunk = { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID, "choices": [{"index": 0, "delta": {"content": char}, "finish_reason": None }] }
            yield f"data: {json.dumps(chunk)}\n\n"
            await asyncio.sleep(0.01)
        
        final_chunk = { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop" }] }
        yield f"data: {json.dumps(final_chunk)}\n\n"
        yield "data: [DONE]\n\n"

    if request.stream:
        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        return {"choices": [{"message": {"role": "assistant", "content": response_text}}]}

@app.get("/")
def root():
    return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}