"""Streamlit app: Italian question answering over a personal context.

Loads a cached HuggingFace QA pipeline, wires the root logger to a remote
syslog endpoint configured via Streamlit secrets, and answers free-text
Italian questions using helpers from ``utils``.
"""

import json
import logging
from logging import Logger
from logging.handlers import SysLogHandler

import streamlit as st
import tokenizers
import torch
from transformers import Pipeline, pipeline

from utils import get_answer, get_context


@st.cache(
    hash_funcs={
        # Model parameters and tokenizer objects are not hashable; map them
        # to None so Streamlit's cache key ignores them and the model is
        # loaded only once per session.
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    },
    allow_output_mutation=True,
    show_spinner=False,
)
def load_engine() -> Pipeline:
    """Return the (cached) Italian question-answering pipeline."""
    return pipeline(
        "question-answering",
        model="mrm8488/bert-italian-finedtuned-squadv1-it-alfa",
        tokenizer="mrm8488/bert-italian-finedtuned-squadv1-it-alfa",
    )


# Remote syslog destination comes from Streamlit secrets; the handler is
# created once at import time and shared via load_logger below.
syslog = SysLogHandler(
    address=(st.secrets["logging_address"], int(st.secrets["logging_port"]))
)


@st.cache(hash_funcs={SysLogHandler: id})
def load_logger(syslog: SysLogHandler) -> Logger:
    """Attach *syslog* to the root logger at INFO level.

    Cached (keyed on the handler's identity) so the handler is added only
    once instead of on every Streamlit rerun.
    """
    logger = logging.getLogger()
    logger.addHandler(syslog)
    logger.setLevel(logging.INFO)
    return logger


with st.spinner(
    text="Sto preparando il necessario per rispondere alle tue domande personali..."
):
    engine = load_engine()
    logger = load_logger(syslog)

st.title("Le risposte alle tue domande personali")
# Renamed from `input` to avoid shadowing the builtin of the same name.
question = st.text_input("Scrivi una domanda in italiano e comparirà la risposta!")
if question:
    try:
        context = get_context()
        logger.info(question)
        answer = get_answer(question, context, engine)
        st.subheader(answer)
    except Exception:
        # The original bare `except:` swallowed every error silently (even
        # SystemExit/KeyboardInterrupt). Log the traceback to syslog, then
        # show the same friendly message to the user.
        logger.exception("Failed to answer question")
        st.error(
            "Qualcosa é andato storto. Prova di nuovo con un'altra domanda magari!"
        )