"""Gradio demo that compares Spanish question-answering models side by side.

Three extractive QA models (TensorFlow) select an answer span from the
context, and two causal language models (PyTorch) generate a free-form
answer from a "Contexto / Pregunta / Respuesta" prompt. All answers are
rendered together as HTML in a single Gradio interface.
"""

import html
import logging
import os
import re

import gradio as gr
import tensorflow as tf
import torch
# NOTE(review): `models`/`layers` are unused here; presumably imported so
# tf_keras is loaded for the TF transformers models — TODO confirm.
from tf_keras import models, layers
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering, AutoModelForCausalLM

logger = logging.getLogger(__name__)

# Check if GPU is available and use it if possible (PyTorch models only).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Version information (also used as display labels in the output).
confli_version_spanish = 'ConfliBERT-Spanish-Beto-Cased-NewsQA'
beto_version_spanish = 'Beto-Spanish-Cased-NewsQA'
gpt2_spanish_version = 'GPT-2-Small-Spanish'
bloom_spanish_version = 'BLOOM-1.7B'
beto_sqac_version_spanish = 'Beto-Spanish-Cased-SQAC'

# Load Spanish extractive QA models and tokenizers (TensorFlow).
confli_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA'
confli_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(confli_model_spanish)
confli_tokenizer_spanish = AutoTokenizer.from_pretrained(confli_model_spanish)

beto_model_spanish = 'salsarra/Beto-Spanish-Cased-NewsQA'
beto_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_model_spanish)
beto_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_model_spanish)

beto_sqac_model_spanish = 'salsarra/Beto-Spanish-Cased-SQAC'
beto_sqac_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_sqac_model_spanish)
beto_sqac_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_sqac_model_spanish)

# Load Spanish GPT-2 model and tokenizer (PyTorch).
gpt2_spanish_model_name = 'datificate/gpt2-small-spanish'
gpt2_spanish_tokenizer = AutoTokenizer.from_pretrained(gpt2_spanish_model_name)
gpt2_spanish_model = AutoModelForCausalLM.from_pretrained(gpt2_spanish_model_name).to(device)

# Load BLOOM-1.7B model and tokenizer (PyTorch).
bloom_model_name = 'bigscience/bloom-1b7'
bloom_tokenizer = AutoTokenizer.from_pretrained(bloom_model_name)
bloom_model = AutoModelForCausalLM.from_pretrained(bloom_model_name).to(device)

# Number of tokens the generative models may produce for an answer.
_GENERATION_NEW_TOKENS = 50

# Error-message patterns for inputs that exceed a model's size limit:
# a PyTorch tensor-shape mismatch and a TensorFlow out-of-range embedding index.
_TORCH_SIZE_PATTERN = re.compile(
    r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)"
)
_TF_INDEX_PATTERN = re.compile(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)")


def _over_limit_message(size, limit):
    """Format the user-facing 'input too long' message."""
    return (
        f"Error: Text Input is over limit where inserted text size {size} "
        f"is larger than model limits of {limit}"
    )


def handle_error_message(e, default_limit=512):
    """Turn a model exception into a user-facing error string.

    Args:
        e: The exception raised while running a model.
        default_limit: Limit reported when the sizes can't be parsed from ``e``.

    Returns:
        An "over limit" message. When the exception text matches a known
        size-mismatch pattern, the actual sizes are included.
    """
    # Log the real error so failures that are not size-related stay visible.
    logger.error("Model inference failed", exc_info=e)
    error_message = str(e)
    for pattern in (_TORCH_SIZE_PATTERN, _TF_INDEX_PATTERN):
        match = pattern.search(error_message)
        if match:
            return _over_limit_message(*match.groups())
    return (
        "Error: Text Input is over limit where inserted text size is larger "
        f"than model limits of {default_limit}"
    )


def _extractive_answer(model, tokenizer, context, question):
    """Run one TF extractive-QA model and return the decoded answer span.

    Only the context is truncated (``only_second``) so the question is never cut.
    The end index is searched at or after the start index, so the span is never
    inverted (an inverted span would decode to an empty answer).
    Errors are converted to a message via ``handle_error_message``.
    """
    try:
        inputs = tokenizer(question, context, return_tensors='tf', truncation='only_second')
        outputs = model(inputs)
        start_logits = outputs.start_logits[0].numpy()
        end_logits = outputs.end_logits[0].numpy()
        answer_start = int(start_logits.argmax())
        answer_end = answer_start + int(end_logits[answer_start:].argmax()) + 1
        span_ids = inputs['input_ids'].numpy()[0][answer_start:answer_end]
        # skip_special_tokens drops [CLS]/[SEP] if the span lands on them.
        return tokenizer.decode(span_ids, skip_special_tokens=True).strip()
    except Exception as e:
        return handle_error_message(e)


def _generative_answer(model, tokenizer, context, question,
                       max_new_tokens=_GENERATION_NEW_TOKENS):
    """Prompt a causal LM with context and question; return only the new text.

    If the model config exposes ``max_position_embeddings`` and the prompt plus
    ``max_new_tokens`` would exceed it, an over-limit message is returned.
    Without this check the failure would happen inside ``generate``.
    """
    try:
        prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        prompt_length = inputs['input_ids'].shape[1]

        # e.g. GPT-2 exposes its position limit here; BLOOM does not define it.
        position_limit = getattr(model.config, 'max_position_embeddings', None)
        if position_limit is not None and prompt_length + max_new_tokens > position_limit:
            return _over_limit_message(prompt_length, position_limit - max_new_tokens)

        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=40,
            temperature=0.8,
        )
        # Decode only the generated continuation, not the echoed prompt, so a
        # "Respuesta:" produced by the model can't displace the real answer.
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        return handle_error_message(e)


# Spanish extractive QA functions
def question_answering_spanish(context, question):
    """Answer with ConfliBERT-Spanish-Beto-Cased-NewsQA."""
    return _extractive_answer(confli_model_spanish_qa, confli_tokenizer_spanish, context, question)


def beto_question_answering_spanish(context, question):
    """Answer with Beto-Spanish-Cased-NewsQA."""
    return _extractive_answer(beto_model_spanish_qa, beto_tokenizer_spanish, context, question)


def beto_sqac_question_answering_spanish(context, question):
    """Answer with Beto-Spanish-Cased-SQAC."""
    return _extractive_answer(beto_sqac_model_spanish_qa, beto_sqac_tokenizer_spanish, context, question)


# Functions for Spanish GPT-2 and BLOOM-1.7B models
def gpt2_spanish_question_answering(context, question):
    """Generate an answer with GPT-2 Small Spanish."""
    return _generative_answer(gpt2_spanish_model, gpt2_spanish_tokenizer, context, question)


def bloom_question_answering(context, question):
    """Generate an answer with BLOOM-1.7B."""
    return _generative_answer(bloom_model, bloom_tokenizer, context, question)


# (display label, source description) pairs for the "model info" section.
_MODEL_SOURCES = (
    (confli_version_spanish, confli_model_spanish),
    (beto_version_spanish, beto_model_spanish),
    (beto_sqac_version_spanish, beto_sqac_model_spanish),
    (gpt2_spanish_version, 'datificate GPT-2 Small Spanish'),
    (bloom_spanish_version, 'bigscience BLOOM-1.7B'),
)


def _to_html(text):
    """Escape text for HTML output and keep its line breaks visible."""
    return html.escape(text).replace("\n", "<br>")


# Main function for Spanish QA
def compare_question_answering_spanish(context, question):
    """Query every model and return an HTML report of their answers.

    The output component is ``gr.HTML``, and answers can echo user-supplied
    text, so every answer is HTML-escaped before being embedded.
    """
    answers = [
        (confli_version_spanish, question_answering_spanish(context, question)),
        (beto_version_spanish, beto_question_answering_spanish(context, question)),
        (beto_sqac_version_spanish, beto_sqac_question_answering_spanish(context, question)),
        (gpt2_spanish_version, gpt2_spanish_question_answering(context, question)),
        (bloom_spanish_version, bloom_question_answering(context, question)),
    ]
    answers_html = "".join(
        f"<p><strong>{html.escape(label)}:</strong><br>{_to_html(answer)}</p>"
        for label, answer in answers
    )
    sources_html = "<br>".join(
        f"{html.escape(label)}: {html.escape(source)}" for label, source in _MODEL_SOURCES
    )
    return (
        "<h2>Respuestas:</h2>"
        f"{answers_html}"
        "<h2>Información del modelo:</h2>"
        f"<p>{sources_html}</p>"
    )


# Define the CSS for the Gradio interface.
css_styles = """
body {
    background-color: #f0f8ff;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
}
h1 a {
    color: #2e8b57;
    text-align: center;
    font-size: 2em;
    text-decoration: none;
}
h1 a:hover {
    color: #ff8c00;
}
h2 {
    color: #ff8c00;
    text-align: center;
    font-size: 1.5em;
}
.gradio-container {
    max-width: 100%;
    margin: 10px auto;
    padding: 10px;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-input, .gr-output {
    background-color: #ffffff;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
    font-size: 1em;
}
.gr-title {
    font-size: 1.5em;
    font-weight: bold;
    color: #2e8b57;
    margin-bottom: 10px;
    text-align: center;
}
.gr-description {
    font-size: 1.2em;
    color: #ff8c00;
    margin-bottom: 10px;
    text-align: center;
}
.header-title-center a {
    font-size: 4em;
    font-weight: bold;
    color: darkorange;
    text-align: center;
    display: block;
}
.gr-button {
    background-color: #ff8c00;
    color: white;
    border: none;
    padding: 10px 20px;
    font-size: 1em;
    border-radius: 5px;
    cursor: pointer;
}
.gr-button:hover {
    background-color: #ff4500;
}
.footer {
    text-align: center;
    margin-top: 10px;
    font-size: 0.9em;
    color: #666;
    width: 100%;
}
.footer a {
    color: #2e8b57;
    font-weight: bold;
    text-decoration: none;
}
.footer a:hover {
    text-decoration: underline;
}
"""

# Define the Gradio interface.
demo = gr.Interface(
    fn=compare_question_answering_spanish,
    inputs=[
        gr.Textbox(lines=5, placeholder="Ingrese el contexto aquí...", label="Contexto"),
        gr.Textbox(lines=2, placeholder="Ingrese su pregunta aquí...", label="Pregunta"),
    ],
    outputs=gr.HTML(label="Salida"),
    title="ConfliBERT-Spanish-QA",
    description="Compare respuestas entre los modelos ConfliBERT, BETO, Beto SQAC, GPT-2 Small Spanish y BLOOM-1.7B para preguntas en español.",
    css=css_styles,
    allow_flagging="never",
)

# Launch the Gradio demo only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch(share=True)