EnzGamers commited on
Commit
16578c7
·
verified ·
1 Parent(s): 297c4f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -22
app.py CHANGED
@@ -22,8 +22,7 @@ model = AutoModelForCausalLM.from_pretrained(
22
  )
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
24
 
25
- # --- LA CORRECTION EST ICI (Partie 1) ---
26
- # On s'assure que le tokenizer a un token de padding. S'il n'en a pas, on utilise le token de fin de phrase.
27
  if tokenizer.pad_token is None:
28
  tokenizer.pad_token = tokenizer.eos_token
29
  print("Le pad_token a été défini sur eos_token.")
@@ -50,7 +49,6 @@ class ChatCompletionRequest(BaseModel):
50
  class Config:
51
  extra = Extra.ignore
52
 
53
- # ... (le reste des modèles de données est inchangé) ...
54
  class ChatCompletionResponseChoice(BaseModel):
55
  index: int = 0
56
  message: ChatMessage
@@ -83,30 +81,20 @@ async def list_models():
83
  async def create_chat_completion(request: ChatCompletionRequest):
84
  """Endpoint principal qui gère la génération de texte en streaming."""
85
 
86
- user_prompt = ""
87
- last_message = request.messages[-1]
88
- if isinstance(last_message.content, list):
89
- for part in last_message.content:
90
- if part.type == 'text':
91
- user_prompt += part.text + "\n"
92
- elif isinstance(last_message.content, str):
93
- user_prompt = last_message.content
94
-
95
- if not user_prompt:
96
- return {"error": "Prompt non trouvé."}
97
-
98
- messages_for_model = [{'role': 'user', 'content': user_prompt}]
99
-
100
- # --- LA CORRECTION EST ICI (Partie 2) ---
101
- # 1. On applique le template pour obtenir le texte brut
102
  text_prompt = tokenizer.apply_chat_template(messages_for_model, tokenize=False, add_generation_prompt=True)
103
- # 2. On tokenize le texte pour obtenir explicitement input_ids ET attention_mask
 
104
  inputs = tokenizer(text_prompt, return_tensors="pt", padding=True).to(DEVICE)
105
 
106
- # 3. On passe les inputs au modèle en utilisant ** pour déballer le dictionnaire (qui contient input_ids et attention_mask)
107
  outputs = model.generate(**inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
108
 
109
- # On doit maintenant décoder à partir des bons tokens
110
  response_text = tokenizer.decode(outputs[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
111
 
112
  async def stream_generator():
 
22
  )
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
24
 
25
+ # On s'assure que le tokenizer a un token de padding.
 
26
  if tokenizer.pad_token is None:
27
  tokenizer.pad_token = tokenizer.eos_token
28
  print("Le pad_token a été défini sur eos_token.")
 
49
  class Config:
50
  extra = Extra.ignore
51
 
 
52
  class ChatCompletionResponseChoice(BaseModel):
53
  index: int = 0
54
  message: ChatMessage
 
81
  async def create_chat_completion(request: ChatCompletionRequest):
82
  """Endpoint principal qui gère la génération de texte en streaming."""
83
 
84
+ # --- LA CORRECTION EST ICI ---
85
+ # On convertit les messages de la requête en un format que le tokenizer peut utiliser.
86
+ # C'est plus simple et plus robuste que de chercher le prompt manuellement.
87
+ messages_for_model = [msg.dict() for msg in request.messages]
88
+
89
+ # On applique le template. Le tokenizer de Qwen sait comment gérer cette structure.
 
 
 
 
 
 
 
 
 
 
90
  text_prompt = tokenizer.apply_chat_template(messages_for_model, tokenize=False, add_generation_prompt=True)
91
+
92
+ # On tokenize le texte pour obtenir explicitement input_ids ET attention_mask
93
  inputs = tokenizer(text_prompt, return_tensors="pt", padding=True).to(DEVICE)
94
 
95
+ # On passe les inputs au modèle en utilisant ** pour déballer le dictionnaire
96
  outputs = model.generate(**inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
97
 
 
98
  response_text = tokenizer.decode(outputs[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
99
 
100
  async def stream_generator():