EnzGamers commited on
Commit
a4d044a
·
verified ·
1 Parent(s): 289cf5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -21,6 +21,9 @@ model = AutoModelForCausalLM.from_pretrained(
21
  device_map=DEVICE
22
  )
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
 
24
  if tokenizer.pad_token is None:
25
  tokenizer.pad_token = tokenizer.eos_token
26
  print("Le pad_token a été défini sur eos_token.")
@@ -30,7 +33,7 @@ print("Modèle et tokenizer chargés avec succès sur le CPU.")
30
  # --- Création de l'application API ---
31
  app = FastAPI()
32
 
33
- # --- Modèles de données ---
34
  class ContentPart(BaseModel):
35
  type: str
36
  text: str
@@ -43,13 +46,6 @@ class ChatCompletionRequest(BaseModel):
43
  model: Optional[str] = None
44
  messages: List[ChatMessage]
45
  stream: Optional[bool] = False
46
- max_tokens: Optional[int] = 512 # Augmenté pour des réponses plus longues
47
-
48
- # --- LES NOUVEAUX CHAMPS SONT ICI ---
49
- # Ajout des paramètres de génération avec des valeurs par défaut.
50
- temperature: Optional[float] = 0.4
51
- top_p: Optional[float] = 0.95
52
- top_k: Optional[int] = 50
53
 
54
  class Config:
55
  extra = Extra.ignore
@@ -80,10 +76,13 @@ class ModelList(BaseModel):
80
 
81
  @app.get("/models", response_model=ModelList)
82
  async def list_models():
 
83
  return ModelList(data=[ModelData(id=MODEL_ID)])
84
 
85
  @app.post("/chat/completions")
86
  async def create_chat_completion(request: ChatCompletionRequest):
 
 
87
  user_prompt = ""
88
  last_message = request.messages[-1]
89
  if isinstance(last_message.content, list):
@@ -97,21 +96,17 @@ async def create_chat_completion(request: ChatCompletionRequest):
97
  return {"error": "Prompt non trouvé."}
98
 
99
  messages_for_model = [{'role': 'user', 'content': user_prompt}]
 
 
 
100
  text_prompt = tokenizer.apply_chat_template(messages_for_model, tokenize=False, add_generation_prompt=True)
 
101
  inputs = tokenizer(text_prompt, return_tensors="pt", padding=True).to(DEVICE)
102
 
103
- # --- LA MISE À JOUR EST ICI ---
104
- # On utilise maintenant les paramètres de la requête pour la génération.
105
- outputs = model.generate(
106
- **inputs,
107
- max_new_tokens=request.max_tokens,
108
- do_sample=True, # do_sample doit être True pour que temp, top_p et top_k aient un effet
109
- temperature=request.temperature,
110
- top_p=request.top_p,
111
- top_k=request.top_k,
112
- eos_token_id=tokenizer.eos_token_id
113
- )
114
 
 
115
  response_text = tokenizer.decode(outputs[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
116
 
117
  async def stream_generator():
@@ -133,4 +128,4 @@ async def create_chat_completion(request: ChatCompletionRequest):
133
 
134
  @app.get("/")
135
  def root():
136
- return {"status": "API compatible OpenAI en ligne (avec streaming et paramètres dynamiques)", "model_id": MODEL_ID}
 
21
  device_map=DEVICE
22
  )
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
24
+
25
+ # --- LA CORRECTION EST ICI (Partie 1) ---
26
+ # On s'assure que le tokenizer a un token de padding. S'il n'en a pas, on utilise le token de fin de phrase.
27
  if tokenizer.pad_token is None:
28
  tokenizer.pad_token = tokenizer.eos_token
29
  print("Le pad_token a été défini sur eos_token.")
 
33
  # --- Création de l'application API ---
34
  app = FastAPI()
35
 
36
+ # --- Modèles de données (inchangés) ---
37
  class ContentPart(BaseModel):
38
  type: str
39
  text: str
 
46
  model: Optional[str] = None
47
  messages: List[ChatMessage]
48
  stream: Optional[bool] = False
 
 
 
 
 
 
 
49
 
50
  class Config:
51
  extra = Extra.ignore
 
76
 
77
  @app.get("/models", response_model=ModelList)
78
  async def list_models():
79
+ """Répond à la requête GET /models pour satisfaire l'extension."""
80
  return ModelList(data=[ModelData(id=MODEL_ID)])
81
 
82
  @app.post("/chat/completions")
83
  async def create_chat_completion(request: ChatCompletionRequest):
84
+ """Endpoint principal qui gère la génération de texte en streaming."""
85
+
86
  user_prompt = ""
87
  last_message = request.messages[-1]
88
  if isinstance(last_message.content, list):
 
96
  return {"error": "Prompt non trouvé."}
97
 
98
  messages_for_model = [{'role': 'user', 'content': user_prompt}]
99
+
100
+ # --- LA CORRECTION EST ICI (Partie 2) ---
101
+ # 1. On applique le template pour obtenir le texte brut
102
  text_prompt = tokenizer.apply_chat_template(messages_for_model, tokenize=False, add_generation_prompt=True)
103
+ # 2. On tokenize le texte pour obtenir explicitement input_ids ET attention_mask
104
  inputs = tokenizer(text_prompt, return_tensors="pt", padding=True).to(DEVICE)
105
 
106
+ # 3. On passe les inputs au modèle en utilisant ** pour déballer le dictionnaire (qui contient input_ids et attention_mask)
107
+ outputs = model.generate(**inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
 
 
 
 
 
 
 
 
 
108
 
109
+ # On doit maintenant décoder à partir des bons tokens
110
  response_text = tokenizer.decode(outputs[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
111
 
112
  async def stream_generator():
 
128
 
129
  @app.get("/")
130
  def root():
131
+ return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID}