update
Browse files
app.py
CHANGED
@@ -54,14 +54,13 @@ logger =logging.getLogger(__name__)
|
|
54 |
class PredictionRequest(BaseModel):
|
55 |
text: str = None # Texte personnalisé ajouté par l'utilisateur
|
56 |
# max_length: int = 2000 # Limite la longueur maximale du texte généré
|
57 |
-
|
58 |
@app.post("/predict/")
|
59 |
async def predict(request: PredictionRequest):
|
60 |
-
# Construire le prompt final
|
61 |
if request.text:
|
62 |
prompt = default_prompt + "\n\n" + request.text
|
63 |
else:
|
64 |
prompt = default_prompt
|
|
|
65 |
# Assurez-vous que le pad_token est défini
|
66 |
if tokenizer.pad_token is None:
|
67 |
tokenizer.pad_token = tokenizer.eos_token
|
@@ -93,7 +92,6 @@ async def predict(request: PredictionRequest):
|
|
93 |
max_length=3000, # Longueur maximale pour la génération
|
94 |
do_sample=True
|
95 |
)
|
96 |
-
|
97 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
98 |
|
99 |
return {"generated_text": generated_text}
|
|
|
54 |
class PredictionRequest(BaseModel):
|
55 |
text: str = None # Texte personnalisé ajouté par l'utilisateur
|
56 |
# max_length: int = 2000 # Limite la longueur maximale du texte généré
|
|
|
57 |
@app.post("/predict/")
|
58 |
async def predict(request: PredictionRequest):
|
|
|
59 |
if request.text:
|
60 |
prompt = default_prompt + "\n\n" + request.text
|
61 |
else:
|
62 |
prompt = default_prompt
|
63 |
+
|
64 |
# Assurez-vous que le pad_token est défini
|
65 |
if tokenizer.pad_token is None:
|
66 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
92 |
max_length=3000, # Longueur maximale pour la génération
|
93 |
do_sample=True
|
94 |
)
|
|
|
95 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
96 |
|
97 |
return {"generated_text": generated_text}
|