NickyNicky commited on
Commit
d470c63
·
verified ·
1 Parent(s): 0bc8427

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+
4
+
5
+ import gradio as gr
6
+
7
+ # !python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
8
+ import json
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer, StoppingCriteria, StoppingCriteriaList, GenerationConfig
11
+ # from google.colab import userdata
12
+ import os
13
+
14
+ model_id = "somosnlp/Sam_Diagnostic"
15
+ bnb_config = BitsAndBytesConfig(
16
+ load_in_4bit=True,
17
+ bnb_4bit_quant_type="nf4",
18
+ bnb_4bit_compute_dtype=torch.bfloat16
19
+ )
20
+ max_seq_length=2048
21
+
22
+ # if torch.cuda.get_device_capability()[0] >= 8:
23
+ # # print("Flash Attention")
24
+ # attn_implementation="flash_attention_2"
25
+ # else:
26
+ # attn_implementation=None
27
+ attn_implementation=None
28
+
29
+ tokenizer = AutoTokenizer.from_pretrained(model_id,
30
+ max_length = max_seq_length)
31
+ model = AutoModelForCausalLM.from_pretrained(model_id,
32
+ # quantization_config=bnb_config,
33
+ device_map = {"":0},
34
+ attn_implementation = attn_implementation, # A100 o H100
35
+ ).eval()
36
+
37
+
38
+
39
+ class ListOfTokensStoppingCriteria(StoppingCriteria):
40
+ """
41
+ Clase para definir un criterio de parada basado en una lista de tokens específicos.
42
+ """
43
+ def __init__(self, tokenizer, stop_tokens):
44
+ self.tokenizer = tokenizer
45
+ # Codifica cada token de parada y guarda sus IDs en una lista
46
+ self.stop_token_ids_list = [tokenizer.encode(stop_token, add_special_tokens=False) for stop_token in stop_tokens]
47
+
48
+ def __call__(self, input_ids, scores, **kwargs):
49
+ # Verifica si los últimos tokens generados coinciden con alguno de los conjuntos de tokens de parada
50
+ for stop_token_ids in self.stop_token_ids_list:
51
+ len_stop_tokens = len(stop_token_ids)
52
+ if len(input_ids[0]) >= len_stop_tokens:
53
+ if input_ids[0, -len_stop_tokens:].tolist() == stop_token_ids:
54
+ return True
55
+ return False
56
+
57
+ # Uso del criterio de parada personalizado
58
+ stop_tokens = ["<end_of_turn>"] # Lista de tokens de parada
59
+
60
+ # Inicializa tu criterio de parada con el tokenizer y la lista de tokens de parada
61
+ stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)
62
+
63
+ # Añade tu criterio de parada a una StoppingCriteriaList
64
+ stopping_criteria_list = StoppingCriteriaList([stopping_criteria])
65
+
66
+ def generate_text(prompt, idioma_entrada, idioma_salida, max_length=2100):
67
+ prompt=prompt.replace(". ", ".\n").strip()
68
+
69
+ input_text = f'''<bos><start_of_turn>system
70
+ You are a helpful AI assistant.
71
+ Responde en formato json.
72
+ Eres un agente experto en medicina.
73
+ Lista de codigos linguisticos disponibles: ["{idioma_entrada}", "{idioma_salida}"]<end_of_turn>
74
+ <start_of_turn>user
75
+ {prompt}<end_of_turn>
76
+ <start_of_turn>model
77
+ '''
78
+
79
+ inputs = tokenizer.encode(input_text,
80
+ return_tensors="pt",
81
+ add_special_tokens=False).to("cuda:0")
82
+ max_new_tokens=max_length
83
+ generation_config = GenerationConfig(
84
+ max_new_tokens=max_new_tokens,
85
+ temperature=0.35, #55
86
+ #top_p=0.9,
87
+ top_k=50, # 45
88
+ repetition_penalty=1., #1.1
89
+ do_sample=True,
90
+ )
91
+ outputs = base_model.generate(generation_config=generation_config,
92
+ input_ids=inputs,
93
+ stopping_criteria=stopping_criteria_list,)
94
+ return tokenizer.decode(outputs[0], skip_special_tokens=False) #True
95
+
96
+ def mostrar_respuesta(pregunta, idioma_entrada, idioma_salida):
97
+ try:
98
+ lista_codigo_lin = {
99
+ "español": "es",
100
+ "ingles": "en",
101
+ }
102
+ # Utiliza los parámetros de idioma para obtener los códigos de idioma correspondientes.
103
+ codigo_lin_entrada = lista_codigo_lin[idioma_entrada.lower()]
104
+ codigo_lin_salida = lista_codigo_lin[idioma_salida.lower()]
105
+
106
+ res= generate_text(pregunta, codigo_lin_entrada, codigo_lin_salida, max_length=1500)
107
+ inicio_json = res.find('{')
108
+ fin_json = res.rfind('}') + 1
109
+ json_str = res[inicio_json:fin_json]
110
+ json_obj = json.loads(json_str)
111
+ return json_obj["description"], json_obj["medical_specialty"], json_obj["principal_diagnostic"]
112
+ except:
113
+ json_obj={}
114
+ json_obj['description']='Error diagnostico'
115
+ json_obj['medical_specialty']='Error diagnostico'
116
+ json_obj['principal_diagnostic']='Error diagnostico'
117
+ return json_obj
118
+
119
+ # Ejemplos de preguntas
120
+ ejemplos = [
121
+ ["CHIEF COMPLAINT:, Left wrist pain.,HISTORY OF PRESENT PROBLEM"],
122
+ ["INDICATIONS: ,Chest pain.,STRESS TECHNIQUE:,"],
123
+ ["MOTIVO DE CONSULTA: Una niña de 2 meses"],
124
+ ]
125
+
126
+ idiomas = ["español", "ingles"]
127
+
128
+
129
+ iface = gr.Interface(
130
+ fn=mostrar_respuesta,
131
+ inputs=[
132
+ gr.Textbox(label="Pregunta", placeholder="Introduce tu consulta médica aquí..."),
133
+ gr.Dropdown(label="Idioma de Entrada", choices=idiomas),
134
+ gr.Dropdown(label="Idioma de Salida", choices=idiomas),
135
+ ],
136
+ outputs=[
137
+ gr.Textbox(label="Description", lines=2),
138
+ gr.Textbox(label="Medical specialty", lines=1),
139
+ gr.Textbox(label="Principal diagnostic", lines=1)
140
+ ],
141
+ title="Consultas medicas",
142
+ description="Introduce tu diagnostico.",
143
+ examples=ejemplos,
144
+ concurrency_limit=20
145
+ )
146
+
147
+ iface.queue(max_size=14).launch(share=True,debug=True, ) # share=True,debug=True