from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Extra
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import asyncio  # used for the small pause in the streaming generator
import time
import uuid
import json
from typing import Optional, List, Union, Dict, Any
# --- Configuration ---
MODEL_ID = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
DEVICE = "cpu"
# Alternative model IDs:
# Qwen/Qwen3-1.7B
# deepseek-ai/deepseek-coder-1.3b-instruct
# Qwen/Qwen2.5-Coder-0.5B-Instruct
# --- Model loading ---
print(f"Loading model: {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("Model and tokenizer loaded successfully on the CPU.")
# --- API application ---
app = FastAPI()
# --- Data models that accept the nested request structure sent by the extension ---
class ContentPart(BaseModel):
    type: str
    text: str

class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[ContentPart]]

class ChatCompletionRequest(BaseModel):
    model: Optional[str] = None
    messages: List[ChatMessage]
    stream: Optional[bool] = False

    class Config:
        extra = Extra.ignore  # ignore any extra fields the client sends

class ModelData(BaseModel):
    id: str
    object: str = "model"
    owned_by: str = "user"

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelData]
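
# Hedged example (not taken from the original file) of a request body these models
# accept, covering the nested "content parts" format some editor extensions send:
#   {"model": "any", "stream": true,
#    "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}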
# --- API endpoints ---

# Route path assumed from the docstring: the extension calls GET /models.
@app.get("/models")
async def list_models():
    """Responds to GET /models so the extension can discover the model."""
    return ModelList(data=[ModelData(id=MODEL_ID)])
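
# Quick check of the endpoint (hedged sketch, not part of the app): assuming a local
# run on port 7860, the usual Hugging Face Spaces port,
#   curl http://localhost:7860/models
# should return JSON shaped like:
#   {"object": "list", "data": [{"id": "Qwen/Qwen2.5-Coder-0.5B-Instruct",
#                                "object": "model", "owned_by": "user"}]}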
# Route path assumed: OpenAI-compatible clients typically POST to /chat/completions
# (or /v1/chat/completions, depending on the base URL they are configured with).
@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Main endpoint: handles text generation and streams the response."""
    # Extract the user prompt from the (possibly nested) message structure
    user_prompt = ""
    last_message = request.messages[-1]
    if isinstance(last_message.content, list):
        for part in last_message.content:
            if part.type == 'text':
                user_prompt += part.text + "\n"
    elif isinstance(last_message.content, str):
        user_prompt = last_message.content

    if not user_prompt:
        return {"error": "Prompt not found."}

    # Build the chat-template input expected by the model
    messages_for_model = [{'role': 'user', 'content': user_prompt}]
    inputs = tokenizer.apply_chat_template(messages_for_model, add_generation_prompt=True, return_tensors="pt").to(DEVICE)

    # Generate the full response up front
    outputs = model.generate(inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
    response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)

    # Generator that re-streams the already generated text
    async def stream_generator():
        response_id = f"chatcmpl-{uuid.uuid4()}"
        # Send the response character by character, in the expected SSE chunk format
        for char in response_text:
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [{
                    "index": 0,
                    "delta": {"content": char},
                    "finish_reason": None
                }]
            }
            yield f"data: {json.dumps(chunk)}\n\n"
            await asyncio.sleep(0.01)  # small pause to simulate a live stream

        # Send the final chunk that marks the end of the completion
        final_chunk = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": MODEL_ID,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(final_chunk)}\n\n"
        # Send the [DONE] signal
        yield "data: [DONE]\n\n"

    # If the extension asked for a stream, return the generator
    if request.stream:
        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        # Fallback if streaming is not requested (unlikely)
        return {"choices": [{"message": {"role": "assistant", "content": response_text}}]}
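
# Hedged usage sketch (not part of the app): assuming a local run on port 7860 and the
# route above, a streaming request would look like:
#   curl -N http://localhost:7860/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "Qwen/Qwen2.5-Coder-0.5B-Instruct", "stream": true,
#          "messages": [{"role": "user", "content": "Write a hello world in Python"}]}'
# Each SSE line is "data: {...chat.completion.chunk...}", followed by "data: [DONE]".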

@app.get("/")
def root():
    return {"status": "OpenAI-compatible API online (with streaming)", "model_id": MODEL_ID}
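
# Minimal local-run sketch (an assumption, not taken from the original file): Hugging
# Face Spaces usually launches the app itself; for a local test, port 7860 mirrors the
# Spaces default, but any free port works.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)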