Spaces: Runtime error
import gradio as gr
import torch
import torch.nn as nn
import transformers
class Model(nn.Module):
    """BERT encoder followed by a single linear output layer.

    The pretrained encoder id, the encoder hidden size and the number of
    output labels are read from a module-level ``config`` mapping (keys
    ``MODEL_ID``, ``HIDDEN_SIZE`` and ``NUM_LABELS``).

    NOTE(review): ``config`` is not defined in the visible part of this
    file -- confirm it exists before the class is instantiated.
    """

    def __init__(self, model_name: str = 'bert_model'):
        """Build the encoder, the dropout layer and the output layer.

        Args:
            model_name: Label stored on the instance as ``self.model_name``.
                It is not used here to select or load any weights.
        """
        super().__init__()
        # return_dict=False makes the encoder return a plain tuple, which
        # forward() unpacks positionally.
        self.bert = transformers.BertModel.from_pretrained(
            config['MODEL_ID'], return_dict=False
        )
        # A dropout probability of 0.0 leaves its input unchanged.
        self.bert_drop = nn.Dropout(0.0)
        self.out = nn.Linear(config['HIDDEN_SIZE'], config['NUM_LABELS'])
        self.model_name = model_name

    def forward(self, ids, mask, token_type_ids):
        """Run the encoder and project its second output through ``self.out``.

        Args:
            ids: Token ids passed to the encoder as its first argument.
            mask: Passed to the encoder as ``attention_mask``.
            token_type_ids: Passed to the encoder as ``token_type_ids``.

        Returns:
            The result of ``self.out`` applied to the (dropout-wrapped)
            second element of the encoder's output tuple.
        """
        # The encoder returns a 2-tuple; only the second element is used
        # (presumably BERT's pooled output -- see the transformers docs).
        _, pooled = self.bert(
            ids, attention_mask=mask, token_type_ids=token_type_ids
        )
        return self.out(self.bert_drop(pooled))
# Instantiate the network and restore its trained weights.
# Both arguments below are string literals: written bare they would be
# evaluated as expressions over undefined names and raise NameError.
model = Model(model_name='este_si_me_sirvio.bin')
# NOTE(review): torch.load() expects a local file path or file-like object,
# but this value looks like a Hub repository id -- confirm where the
# checkpoint file actually lives.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host.
model.load_state_dict(
    torch.load('juanpasanper/tigo_question_answer', map_location='cpu')
)
def question_answer(context, question):
    """Return the first answer predicted for ``question`` over ``context``.

    A one-item batch is passed to ``model.predict``; the first entry of the
    first prediction's ``'answer'`` list is returned.

    NOTE(review): the ``Model`` class visible in this file defines no
    ``predict`` method -- confirm which object ``model`` is expected to be.
    """
    query = {"question": question, "id": "0"}
    batch = [{"context": context, "qas": [query]}]
    # predict() yields a (predictions, raw_outputs) pair; only the first is used.
    predictions, _raw_outputs = model.predict(batch)
    return predictions[0]['answer'][0]
# Gradio UI: two text inputs (context, question) mapped to one text output.
iface = gr.Interface(
    fn=question_answer,
    inputs=["text", "text"],
    outputs=["text"],
    # Manual flagging with two options; flagged samples go to ./flagged.
    allow_flagging="manual",
    flagging_options=["correcto", "incorrecto"],
    flagging_dir="flagged",
    enable_queue=True,
)
iface.launch()