File size: 4,855 Bytes
fb7cb35
e879476
 
 
 
c06bfc1
 
 
 
4aedab8
57907bc
ee9fb8a
e879476
2cce33e
 
c08ef1f
c06bfc1
 
 
 
 
 
 
 
 
 
 
857e3f5
 
57907bc
c06bfc1
 
 
 
 
 
 
 
e879476
 
 
 
c06bfc1
 
 
 
 
 
 
 
 
 
 
 
57907bc
c06bfc1
e879476
 
ee60681
a4d044a
c06bfc1
49ce3ab
 
e879476
a4d044a
 
57907bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289cf5a
57907bc
 
 
c06bfc1
57907bc
c06bfc1
07257e8
d97d510
57907bc
d97d510
57907bc
 
 
 
 
 
 
 
 
 
 
c06bfc1
57907bc
07257e8
57907bc
 
 
 
 
 
 
 
 
 
 
 
e879476
57907bc
 
e879476
 
57907bc
c06bfc1
 
 
57907bc
c06bfc1
7c0160e
 
 
 
 
 
 
857e3f5
 
 
57907bc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Extra
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import uuid
import json
from typing import Optional, List, Union, Dict, Any

# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
DEVICE = "cpu"
# Qwen/Qwen3-1.7B
# deepseek-ai/deepseek-coder-1.3b-instruct
# Qwen/Qwen2.5-Coder-0.5B-Instruct 
# --- Chargement du modèle ---
print(f"Début du chargement du modèle : {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("Modèle et tokenizer chargés avec succès sur le CPU.")

# --- Création de l'application API ---
app = FastAPI()

# --- Modèles de données pour accepter la structure complexe de l'extension ---
class ContentPart(BaseModel):
    type: str
    text: str

class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[ContentPart]]

class ChatCompletionRequest(BaseModel):
    model: Optional[str] = None
    messages: List[ChatMessage]
    stream: Optional[bool] = False
    
    class Config:
        extra = Extra.ignore

class ModelData(BaseModel):
    id: str
    object: str = "model"
    owned_by: str = "user"

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelData]

# --- Définition des API ---

@app.get("/models", response_model=ModelList)
async def list_models():
    """Répond à la requête GET /models pour satisfaire l'extension."""
    return ModelList(data=[ModelData(id=MODEL_ID)])

@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Endpoint principal qui gère la génération de texte en streaming."""
    
    # On extrait le prompt de l'utilisateur de la structure complexe
    user_prompt = ""
    last_message = request.messages[-1]
    if isinstance(last_message.content, list):
        for part in last_message.content:
            if part.type == 'text':
                user_prompt += part.text + "\n"
    elif isinstance(last_message.content, str):
        user_prompt = last_message.content

    if not user_prompt:
        return {"error": "Prompt non trouvé."}

    # Préparation pour le modèle DeepSeek
    messages_for_model = [{'role': 'user', 'content': user_prompt}]
    inputs = tokenizer.apply_chat_template(messages_for_model, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
    
    # Génération de la réponse complète
    outputs = model.generate(inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
    response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)

    # Fonction génératrice pour le streaming
    async def stream_generator():
        response_id = f"chatcmpl-{uuid.uuid4()}"
        
        # On envoie la réponse caractère par caractère, au format attendu
        for char in response_text:
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [{
                    "index": 0,
                    "delta": {"content": char},
                    "finish_reason": None
                }]
            }
            yield f"data: {json.dumps(chunk)}\n\n"
            await asyncio.sleep(0.01) # Petite pause pour simuler un flux
        
        # On envoie le chunk final de fin
        final_chunk = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk)}\n\n"
        
        # On envoie le signal [DONE]
        yield "data: [DONE]\n\n"

    # Si l'extension demande un stream, on renvoie le générateur
    if request.stream:
        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        # Code de secours si le stream n'est pas demandé (peu probable)
        return {"choices": [{"message": {"role": "assistant", "content": response_text}}]}

@app.get("/")
def root():
    return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}

# On a besoin de asyncio pour la pause dans le stream
import asyncio

@app.get("/")
def root():
    return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}

# On a besoin de asyncio pour la pause dans le stream
import asyncio