File size: 5,226 Bytes
5a798cc
 
4251465
 
 
 
 
 
820a0dd
 
c3a5bd0
a58d64d
820a0dd
 
 
 
 
 
 
 
 
 
 
 
5a798cc
b956a25
 
 
 
 
a58d64d
 
 
 
 
 
165628b
b956a25
 
 
165628b
38e08fc
165628b
 
a58d64d
38e08fc
b956a25
165628b
 
 
b956a25
165628b
b956a25
 
4251465
 
165628b
4251465
 
 
 
 
 
 
165628b
820a0dd
165628b
5a798cc
4251465
5a798cc
c3a5bd0
3d55aee
165628b
3d55aee
c3a5bd0
165628b
c3a5bd0
 
 
 
165628b
 
c3a5bd0
 
 
 
165628b
 
e0ac11d
 
165628b
4251465
 
 
 
 
 
 
 
 
c3a5bd0
165628b
c3a5bd0
165628b
 
5a798cc
 
165628b
 
 
4251465
165628b
 
5a798cc
165628b
 
4251465
165628b
 
 
4251465
165628b
5a798cc
 
 
 
4251465
5a798cc
 
4251465
5a798cc
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    MarianMTModel,
    MarianTokenizer
)
import time
from functools import wraps
import sys
import os

# Decorador para medir el tiempo de ejecución
def medir_tiempo(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        inicio = time.time()
        resultado = func(*args, **kwargs)
        fin = time.time()
        tiempo_transcurrido = fin - inicio
        print(f"Tiempo de ejecución de '{func.__name__}': {tiempo_transcurrido:.2f} segundos")
        return resultado
    return wrapper

# Verificar si CUDA está disponible
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")

# Obtener el token de Hugging Face desde las variables de entorno
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if not hf_token:
    print("Error: El token de Hugging Face no está configurado en los secretos.")
    sys.exit(1)

# Cargar el tokenizador y el modelo de generación desde HuggingFace
model_name = "dmis-lab/selfbiorag_7b"

try:
    print("Cargando el tokenizador y el modelo de generación desde HuggingFace...")
    tokenizer_gen = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
    model_gen = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        use_auth_token=hf_token
    ).to(device)
except ValueError as e:
    print(f"Error al cargar el tokenizador de generación: {e}")
    sys.exit(1)
except Exception as e:
    print(f"Error al cargar el modelo de generación: {e}")
    sys.exit(1)

# Definir el modelo de traducción al Español
translation_model_name = "Helsinki-NLP/opus-mt-en-es"

try:
    print(f"Cargando el tokenizador y el modelo de traducción para Español desde HuggingFace...")
    tokenizer_tr_es = MarianTokenizer.from_pretrained(translation_model_name)
    model_tr_es = MarianMTModel.from_pretrained(translation_model_name).to(device)
except Exception as e:
    print(f"Error al cargar el modelo de traducción para Español: {e}")
    sys.exit(1)

@medir_tiempo
def generar_y_traducir_respuesta(consulta, idioma_destino):
    """
    Función que genera una respuesta a partir de una consulta dada y la traduce al Español.
    """
    try:
        if not consulta.strip():
            return "Por favor, ingresa una consulta válida.", ""
        
        # Tokenizar la consulta
        inputs = tokenizer_gen.encode(consulta, return_tensors="pt").to(device)
        
        # Configurar los parámetros de generación
        generation_kwargs = {
            "max_new_tokens": 100,  # Ajustado a 100
            "do_sample": False       # Generación determinista
            # Puedes añadir otros parámetros como 'num_beams' si lo deseas
        }
        
        # Generar la respuesta
        with torch.no_grad():
            outputs = model_gen.generate(input_ids=inputs, **generation_kwargs)
        
        # Decodificar la respuesta en inglés con limpieza de espacios
        respuesta_en = tokenizer_gen.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        
        # Traducir la respuesta al Español
        traducir_inputs = tokenizer_tr_es.encode(respuesta_en, return_tensors="pt").to(device)
        
        # Realizar la traducción
        with torch.no_grad():
            traduccion_outputs = model_tr_es.generate(input_ids=traducir_inputs, max_length=512)
        
        # Decodificar la traducción con limpieza de espacios
        respuesta_traducida = tokenizer_tr_es.decode(traduccion_outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        
        return respuesta_en, respuesta_traducida
    except Exception as e:
        print(f"Error durante la generación o traducción de la respuesta: {e}")
        return f"Error al generar la respuesta: {e}", ""

# Definir la interfaz de Gradio
titulo = "Generador y Traductor de Respuestas con SelfBioRAG 7B"
descripcion = (
    "Ingresa una consulta y el modelo generará una respuesta en inglés. "
    "Luego, la respuesta se traducirá automáticamente al Español."
)

iface = gr.Interface(
    fn=generar_y_traducir_respuesta,
    inputs=[
        gr.Textbox(lines=5, placeholder="Escribe tu consulta aquí...", label="Consulta")
    ],
    outputs=[
        gr.Textbox(label="Respuesta en Inglés"),
        gr.Textbox(label="Respuesta Traducida al Español")
    ],
    title=titulo,
    description=descripcion,
    examples=[
        [
            "Clasifica el siguiente informe de radiología según la parte del cuerpo a la que se refiere (por ejemplo, pecho, abdomen, cerebro, etc.): Los discos intervertebrales en L4-L5 y L5-S1 muestran signos de degeneración con leve abultamiento que comprime la raíz nerviosa adyacente."
        ],
        [
            "Resume los puntos clave sobre el papel de las mutaciones en los genes BRCA1 y BRCA2 en el aumento del riesgo de cáncer de mama."
        ]
    ],
    cache_examples=False
)

# Ejecutar la interfaz
if __name__ == "__main__":
    iface.launch()