Spaces:
Runtime error
Runtime error
# Gradio demo: extract the phrase of a context paragraph that answers a
# question/feature text, using a fine-tuned DeBERTa checkpoint.
#
# NOTE(review): the scraped source lost its indentation; this layout assumes
# the whole script lived under the ``__main__`` guard, as the guard's position
# on the first line suggests.
if __name__ == '__main__':
    import os

    import gradio as gr
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from configuration import CFG
    from dataset import SingleInputDataset
    from model import CustomModel
    from utils import inference_fn, get_char_probs, get_results, get_text

    # Run on the GPU when one is visible, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    config_path = os.path.join('models_file', 'config.pth')
    model_path = os.path.join(
        'models_file',
        'microsoft-deberta-base_0.9449373420387531_8_best.pth',
    )

    tokenizer = AutoTokenizer.from_pretrained('models_file/tokenizer')

    # Build the architecture from the saved config (no pretrained download),
    # then restore the fine-tuned weights stored under the 'model' key.
    model = CustomModel(CFG, config_path=config_path, pretrained=False)
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model'])

    def get_answer(context, feature):
        """Return the text extracted from ``context`` for ``feature``.

        Args:
            context: The paragraph to search (the "Context Para" textbox).
            feature: The question / feature text (the "Question" textbox).

        Returns:
            Whatever ``get_text`` builds from ``context`` and the first entry
            of the thresholded results.
        """
        # Encode the (context, feature) pair, padded to the fixed length the
        # reshape below relies on.
        # NOTE(review): no truncation is requested, so a pair longer than
        # CFG.max_len tokens presumably makes the reshape fail - TODO confirm
        # and decide on a truncation strategy.
        inputs_single = tokenizer(
            context,
            feature,
            add_special_tokens=True,
            max_length=CFG.max_len,
            padding="max_length",
            return_offsets_mapping=False,
        )
        # Convert every encoded field to a long tensor in place.
        for k, v in inputs_single.items():
            inputs_single[k] = torch.tensor(v, dtype=torch.long)

        # Wrap the single sample so the shared batch inference helper can be
        # reused unchanged.
        single_input_dataset = SingleInputDataset(inputs_single)
        single_input_loader = DataLoader(
            single_input_dataset,
            batch_size=1,
            shuffle=False,
            # One sample per request: load it in-process instead of starting
            # worker processes on every call.
            num_workers=0,
        )

        # Perform inference on the single input.
        token_output = inference_fn(single_input_loader, model, device)
        prediction = token_output.reshape((1, CFG.max_len))

        # Map the predictions back onto the characters of the context.
        char_probs = get_char_probs([context], prediction, tokenizer)
        # Averaging over a one-element list leaves the values unchanged; the
        # form is kept so more prediction sets could be averaged in.
        predictions = np.mean([char_probs], axis=0)
        results = get_results(predictions, th=0.5)
        print(results)
        return get_text(context, results[0])

    # NOTE(review): `gr.inputs` / `gr.outputs` and `enable_queue` are legacy
    # Gradio API that newer major releases deprecate or drop - confirm they
    # match the Gradio version pinned for this app.
    inputs = [
        gr.inputs.Textbox(label="Context Para", lines=10),
        gr.inputs.Textbox(label="Question", lines=1),
    ]
    output = gr.outputs.Textbox(label="Answer")
    article = "<p style='text-align: center'><a href='https://www.xelpmoc.in/' target='_blank'>Made by Xelpmoc</a></p>"

    app = gr.Interface(
        fn=get_answer,
        inputs=inputs,
        outputs=output,
        allow_flagging='never',
        title="Phrase Extraction",
        article=article,
        enable_queue=True,
        cache_examples=False,
        css="footer {visibility: hidden}",
    )
    app.launch()