# summaryT5 / app.py
# Last commit: da80063 by bambadij ("change model to llm write"); file size 4.19 kB.
#load package
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer
)
from typing import List, Tuple
from threading import Thread
import os
from pydantic import BaseModel
import logging
import uvicorn
# Configure the Hugging Face cache directory.
# NOTE: `transformers` / `huggingface_hub` read these environment variables when
# they are first imported, which already happened in the import block above, so
# exporting them here alone does not redirect the cache. They are kept for any
# code that reads them later, and the directory is also passed explicitly below.
CACHE_DIR = '/app/.cache'
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
os.environ['HF_HOME'] = CACHE_DIR

MODEL_ID = "THUDM/longwriter-glm4-9b"

# Load the model and the tokenizer.
# SECURITY NOTE(review): trust_remote_code=True executes Python code shipped in the
# model repository; only load repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, trust_remote_code=True, device_map='auto', cache_dir=CACHE_DIR
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)
# Text passed to FastAPI as the API description (French: input text / summary output).
Informations = """
-text : Texte à resumé
output:
- Text summary : texte resumé
"""

# The FastAPI application that the route decorators below attach to.
app = FastAPI(description=Informations, title='Text Summary')
# Logging: a logger named after this module, with the root logger configured at INFO.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the last token of sequence 0 is a stop id.

    Only ``input_ids[0]`` (the first sequence of the batch) is inspected.

    Args:
        stop_ids: Token id, or iterable of token ids, that ends generation. When
            omitted (the default), the ids are read from the module-level
            ``model.config.eos_token_id`` each time the criterion is called.
    """

    def __init__(self, stop_ids=None):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = self.stop_ids if self.stop_ids is not None else model.config.eos_token_id
        # A model config may hold no eos id, a single int, or a list of ints;
        # iterating the raw value would fail for the first two cases.
        if stop_ids is None:
            return False
        if isinstance(stop_ids, int):
            stop_ids = [stop_ids]
        last_token = int(input_ids[0][-1])
        return last_token in stop_ids
# Built-in instructions (in French). They ask the model, acting as a network-complaint
# expert, to summarise the complaint in 4-5 concise sentences. The summary should cover
# client information, key dates/deadlines and the complaint's context and details, then
# add a recommendation to avoid client dissatisfaction, in a human, engaging tone.
default_prompt = """Bonjour,
En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur :
1. **Informations Client** : Indique des détails pertinents sur le client.
2. **Dates et Délais** : Mentionne les dates clés et les délais (prise en charge, résolution, etc.).
3. **Contexte et Détails** : Inclut les éléments essentiels de la plainte (titre, détails, états d’avancement, qualification, fichiers joints).
Ajoute une recommandation importante pour éviter le mécontentement du client, par exemple, en cas de service non fourni malgré le paiement. Adapte le ton pour qu'il soit humain et engageant.
Merci !
"""
class PredictionRequest(BaseModel):
    # Request body schema for the text-generation endpoint.
    # Conversation turns as (user_message, model_reply) string pairs.
    history: List[Tuple[str, str]] = []
    # Prompt text; the handler sends it as the "system" message when non-empty.
    prompt: str = ""
    # Upper bound on generated tokens (forwarded to generation as max_new_tokens).
    max_length: int = 128000
    # Nucleus-sampling probability mass forwarded to generation.
    top_p: float = 0.8
    # Sampling temperature forwarded to generation.
    temperature: float = 0.6
def _build_messages(history, prompt):
    """Split a chat history into GLM-4 chat messages plus the pending user query.

    Args:
        history: ``(user_msg, model_msg)`` pairs, oldest first. The last pair must
            have an empty ``model_msg``; its ``user_msg`` is the query to answer.
        prompt: Optional system prompt, prepended as a ``system`` message.

    Returns:
        ``(messages, query)``. ``messages`` holds the system prompt and every
        earlier turn. ``query`` is the pending user message, or ``None`` when
        the history is empty or its last turn already has a reply.
    """
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})
    if not history or history[-1][1]:
        return messages, None
    for user_msg, model_msg in history[:-1]:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages, history[-1][0]


def _collect_stream(streamer):
    """Concatenate streamed text chunks, truncating any chunk at a ``<|user|>`` marker."""
    pieces = []
    for new_token in streamer:
        if new_token and '<|user|>' in new_token:
            new_token = new_token.split('<|user|>')[0]
        if new_token:
            pieces.append(new_token)
    return ''.join(pieces)


def _run_generation(streamer, failures, generate_kwargs):
    """Run ``model.generate``; on failure record the error and unblock the streamer."""
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:  # re-raised by the request handler after the thread ends
        failures.append(exc)
        # Without this the consumer would wait for the streamer's full timeout.
        streamer.end()


@app.post("/generate/")
async def predict(request: PredictionRequest):
    """Generate the assistant reply for the last turn of ``request.history``.

    ``request.prompt`` (or ``default_prompt`` when it is empty) is sent as the
    system message. Earlier turns become the chat context, and the last turn's
    user message is the query. Generation runs in a background thread, and its
    streamed text is collected into one reply.

    Returns:
        ``{"history": [...]}``: the request history as ``[user, reply]`` lists,
        with the reply of the last turn filled in.

    Raises:
        HTTPException: 400 when the history is empty or its last turn already
            has a reply, i.e. there is no pending user message to answer.
    """
    import asyncio

    from fastapi import HTTPException

    prompt = request.prompt or default_prompt
    history = request.history
    messages, query = _build_messages(history, prompt)
    if query is None:
        raise HTTPException(
            status_code=400,
            detail="history must end with a (user_message, '') pair awaiting a reply",
        )

    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
        next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": request.max_length,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }
    if request.temperature > 0:
        generate_kwargs.update(do_sample=True, top_p=request.top_p, temperature=request.temperature)
    else:
        # Sampling requires a strictly positive temperature; treat <= 0 as greedy decoding.
        generate_kwargs["do_sample"] = False

    failures = []
    worker = Thread(target=_run_generation, args=(streamer, failures, generate_kwargs))
    worker.start()
    # Draining the streamer blocks for the whole generation, so it runs in an
    # executor to keep the event loop free to serve other requests.
    loop = asyncio.get_running_loop()
    generated_text = await loop.run_in_executor(None, _collect_stream, streamer)
    await loop.run_in_executor(None, worker.join)
    if failures:
        raise failures[0]

    # Request turns are immutable tuples; build a fresh list to fill in the reply.
    updated_history = [list(turn) for turn in history]
    updated_history[-1][1] = generated_text
    return {"history": updated_history}
if __name__ == "__main__":
    # Pass the app object, not the "app:app" import string. An import string
    # makes uvicorn import this file a second time as module "app", and reload=True
    # adds a reloader process, so the model would be loaded twice.
    # Host and port default to uvicorn's own defaults and can be overridden via env.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "127.0.0.1"),
        port=int(os.environ.get("PORT", "8000")),
    )