Sam_Diagnostic / app.py
NickyNicky's picture
Update app.py
6217c50 verified
raw
history blame
5.71 kB
import gradio as gr
import json
import gradio as gr
# !python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig
# from google.colab import userdata
import os
model_id = "somosnlp/Sam_Diagnostic"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
max_seq_length=2048
# if torch.cuda.get_device_capability()[0] >= 8:
# # print("Flash Attention")
# attn_implementation="flash_attention_2"
# else:
# attn_implementation=None
attn_implementation=None
tokenizer = AutoTokenizer.from_pretrained(model_id,
max_length = max_seq_length)
model = AutoModelForCausalLM.from_pretrained(model_id,
# quantization_config=bnb_config,
device_map = {"":0},
attn_implementation = attn_implementation, # A100 o H100
).eval()
class ListOfTokensStoppingCriteria(StoppingCriteria):
"""
Clase para definir un criterio de parada basado en una lista de tokens específicos.
"""
def __init__(self, tokenizer, stop_tokens):
self.tokenizer = tokenizer
# Codifica cada token de parada y guarda sus IDs en una lista
self.stop_token_ids_list = [tokenizer.encode(stop_token, add_special_tokens=False) for stop_token in stop_tokens]
def __call__(self, input_ids, scores, **kwargs):
# Verifica si los últimos tokens generados coinciden con alguno de los conjuntos de tokens de parada
for stop_token_ids in self.stop_token_ids_list:
len_stop_tokens = len(stop_token_ids)
if len(input_ids[0]) >= len_stop_tokens:
if input_ids[0, -len_stop_tokens:].tolist() == stop_token_ids:
return True
return False
# Uso del criterio de parada personalizado
stop_tokens = ["<end_of_turn>"] # Lista de tokens de parada
# Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada
stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)
# Añade tu criterio de parada a una StoppingCriteriaList
stopping_criteria_list = StoppingCriteriaList([stopping_criteria])
def generate_text(prompt, idioma_entrada, idioma_salida, max_length=2100):
prompt=prompt.replace(". ", ".\n").strip()
input_text = f'''<bos><start_of_turn>system
You are a helpful AI assistant.
Responde en formato json.
Eres un agente experto en medicina.
Lista de codigos linguisticos disponibles: ["{idioma_entrada}", "{idioma_salida}"]<end_of_turn>
<start_of_turn>user
{prompt}<end_of_turn>
<start_of_turn>model
'''
inputs = tokenizer.encode(input_text,
return_tensors="pt",
add_special_tokens=False).to("cuda:0")
max_new_tokens=max_length
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
temperature=0.35, #55
#top_p=0.9,
top_k=50, # 45
repetition_penalty=1., #1.1
do_sample=True,
)
outputs = model.generate(generation_config=generation_config,
input_ids=inputs,
stopping_criteria=stopping_criteria_list,)
return tokenizer.decode(outputs[0], skip_special_tokens=False) #True
def mostrar_respuesta(pregunta, idioma_entrada, idioma_salida):
try:
lista_codigo_lin = {
"español": "es",
"ingles": "en",
}
# Utiliza los parámetros de idioma para obtener los códigos de idioma correspondientes.
codigo_lin_entrada = lista_codigo_lin[idioma_entrada.lower()]
codigo_lin_salida = lista_codigo_lin[idioma_salida.lower()]
res= generate_text(pregunta, codigo_lin_entrada, codigo_lin_salida, max_length=1500)
inicio_json = res.find('{')
fin_json = res.rfind('}') + 1
json_str = res[inicio_json:fin_json]
json_obj = json.loads(json_str)
return json_obj["description"], json_obj["medical_specialty"], json_obj["principal_diagnostic"]
except:
json_obj={}
json_obj['description']='Error diagnostico'
json_obj['medical_specialty']='Error diagnostico'
json_obj['principal_diagnostic']='Error diagnostico'
return json_obj["description"], json_obj["medical_specialty"], json_obj["principal_diagnostic"]
# Ejemplos de preguntas
ejemplos = [
["CHIEF COMPLAINT:, Left wrist pain.,HISTORY OF PRESENT PROBLEM"],
["INDICATIONS: ,Chest pain.,STRESS TECHNIQUE:,"],
["MOTIVO DE CONSULTA: Una niña de 2 meses"],
]
idiomas = ["español", "ingles"]
iface = gr.Interface(
fn=mostrar_respuesta,
inputs=[
gr.Textbox(label="Pregunta", placeholder="Introduce tu consulta médica aquí..."),
gr.Dropdown(label="Idioma de Entrada", choices=idiomas, default="español"),
gr.Dropdown(label="Idioma de Salida", choices=idiomas, default="español"),
],
outputs=[
gr.Textbox(label="Description", lines=2),
gr.Textbox(label="Medical specialty", lines=1),
gr.Textbox(label="Principal diagnostic", lines=1)
],
title="Consultas medicas",
description="Introduce tu diagnostico.",
examples=ejemplos,
concurrency_limit=20
)
iface.queue(max_size=14).launch(share=True,debug=True, ) # share=True,debug=True