BioRAG / app.py
C2MV's picture
Update app.py
c7756cc verified
raw
history blame
5.5 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, MarianMTModel, MarianTokenizer
import time
from functools import wraps
import sys
import spaces # Aseg煤rate de que este m贸dulo est茅 disponible y correctamente instalado
# Decorador para medir el tiempo de ejecuci贸n
def medir_tiempo(func):
@wraps(func)
def wrapper(*args, **kwargs):
inicio = time.time()
resultado = func(*args, **kwargs)
fin = time.time()
tiempo_transcurrido = fin - inicio
print(f"Tiempo de ejecuci贸n de '{func.__name__}': {tiempo_transcurrido:.2f} segundos")
return resultado
return wrapper
# Verificar si CUDA est谩 disponible para el modelo principal
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
print("Advertencia: CUDA no est谩 disponible. Se usar谩 la CPU, lo que puede ser lento.")
# Cargar el tokenizador y el modelo principal desde HuggingFace
model_name = "dmis-lab/selfbiorag_7b"
try:
print("Cargando el tokenizador y el modelo desde HuggingFace...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
except ValueError as e:
print(f"Error al cargar el tokenizador: {e}")
sys.exit(1)
try:
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
except Exception as e:
print(f"Error al cargar el modelo: {e}")
sys.exit(1)
# Cargar el modelo de traducci贸n en CPU
try:
print("Cargando el tokenizador y el modelo de traducci贸n en CPU...")
translation_model_name = "Helsinki-NLP/opus-mt-en-es"
translator_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translator_model = MarianMTModel.from_pretrained(translation_model_name).to("cpu") # Forzar a CPU
except Exception as e:
print(f"Error al cargar el modelo de traducci贸n: {e}")
sys.exit(1)
@spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
@medir_tiempo
def generar_respuesta(consulta):
"""
Funci贸n que genera una respuesta a partir de una consulta dada.
"""
try:
if not consulta.strip():
return "Por favor, ingresa una consulta v谩lida."
# Tokenizar la consulta
inputs = tokenizer.encode(consulta, return_tensors="pt").to(device)
# Configurar los par谩metros de generaci贸n
generation_kwargs = {
"max_new_tokens": 100, # Ajustado a 100
"do_sample": False # No usar sampling
# "temperature": 0.6, # Eliminado para evitar advertencias
# "top_p": 0.9 # Eliminado para evitar advertencias
}
# Generar la respuesta
with torch.no_grad():
outputs = model.generate(input_ids=inputs, **generation_kwargs)
# Decodificar la respuesta
respuesta = tokenizer.decode(outputs[0], skip_special_tokens=True)
return respuesta
except Exception as e:
print(f"Error durante la generaci贸n de respuesta: {e}")
return f"Error al generar la respuesta: {e}"
def traducir_texto(texto):
"""
Funci贸n que traduce un texto de ingl茅s a espa帽ol.
"""
try:
if not texto.strip():
return "No hay texto para traducir."
# Tokenizar el texto a traducir
translated = translator_model.generate(**translator_tokenizer(texto, return_tensors="pt", padding=True))
# Decodificar la traducci贸n
traduccion = translator_tokenizer.decode(translated[0], skip_special_tokens=True)
return traduccion
except Exception as e:
print(f"Error durante la traducci贸n: {e}")
return f"Error al traducir el texto: {e}"
def procesar_consulta(consulta, idioma):
"""
Funci贸n que procesa la consulta y devuelve la respuesta original y/o traducida seg煤n el idioma seleccionado.
"""
respuesta_original = generar_respuesta(consulta)
if idioma == "Espa帽ol":
traduccion = traducir_texto(respuesta_original)
else:
traduccion = ""
return respuesta_original, traduccion
# Definir la interfaz de Gradio
titulo = "Generador de Respuestas con SelfBioRAG 7B"
descripcion = "Ingresa una consulta y selecciona el idioma de salida. El modelo generar谩 una respuesta basada en el contenido proporcionado."
iface = gr.Interface(
fn=procesar_consulta,
inputs=[
gr.Textbox(lines=5, placeholder="Escribe tu consulta aqu铆...", label="Consulta"),
gr.Dropdown(choices=["Ingl茅s", "Espa帽ol"], value="Espa帽ol", label="Idioma de Salida")
],
outputs=[
gr.Textbox(label="Respuesta Original (Ingl茅s)"),
gr.Textbox(label="Traducci贸n al Espa帽ol")
],
title=titulo,
description=descripcion,
examples=[
[
"Clasifica el siguiente informe de radiolog铆a seg煤n la parte del cuerpo a la que se refiere (por ejemplo, pecho, abdomen, cerebro, etc.): Los discos intervertebrales en L4-L5 y L5-S1 muestran signos de degeneraci贸n con leve abultamiento que comprime la ra铆z nerviosa adyacente."
],
[
"Resume los puntos clave sobre el papel de las mutaciones en los genes BRCA1 y BRCA2 en el aumento del riesgo de c谩ncer de mama."
]
],
cache_examples=False,
allow_flagging="never"
)
# Ejecutar la interfaz
if __name__ == "__main__":
iface.launch()