Spaces:
Runtime error
Runtime error
| import os | |
| os.chdir('naacl-2021-fudge-controlled-generation/') | |
| import gradio as gr | |
| from predict_clickbait import generate_clickbait | |
| from datasets import load_dataset,DatasetDict,Dataset | |
| # from datasets import | |
| from transformers import AutoTokenizer,AutoModelForSeq2SeqLM | |
| import numpy as np | |
| from sklearn.model_selection import train_test_split | |
| import pandas as pd | |
| from sklearn.utils.class_weight import compute_class_weight | |
| import torch | |
| import pandas as pd | |
| from model import Model | |
| import imp | |
| import os | |
| import random | |
| import time | |
| import pickle | |
| import math | |
| from argparse import ArgumentParser | |
| from collections import namedtuple | |
| import mock | |
| from tqdm import tqdm | |
| import numpy as np | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from data import Dataset | |
| from util import save_checkpoint, ProgressMeter, AverageMeter, num_params | |
| from constants import * | |
| from predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer | |
# Relative checkpoint paths below resolve against the working directory set by
# the os.chdir(...) call at the top of the file.

# `device` must be bound before the first `.to(device)`. The original assigned it
# only after loading the generation model, which raises NameError at startup.
# Fall back to CPU so the demo can still start on hardware without CUDA
# (identical behavior to the original on CUDA machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pad_id = 0

# Seq2seq model that produces the headline text.
pretrained_model = "../checkpoint-150/"
generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
generation_model.eval()

# Lightweight attribute bag standing in for the parsed command-line arguments
# that Model receives. Mock returns a placeholder object for any attribute not
# set here. NOTE(review): confirm Model only reads task/device/checkpoint.
model_args = mock.Mock()
model_args.task = 'clickbait'
model_args.device = device
model_args.checkpoint = '../checkpoint-1464/'

# vocab_size=None: no need to get the glove embeddings when reloading since
# they're saved in the model checkpoint anyway.
conditioning_model = Model(model_args, pad_id, vocab_size=None)
conditioning_model = conditioning_model.to(device)
conditioning_model.eval()

# Decoding settings. length_cutoff and precondition_topk are read by
# clickbait_generator; the module-level condition_lambda is shadowed there by
# that function's own parameter (fed from the UI slider).
condition_lambda = 5.0
length_cutoff = 50
precondition_topk = 200

# Tokenizer for the classifier checkpoint. This rebinds the classifier_tokenizer
# imported from predict_clickbait. `load_best_model_at_end` is a Trainer setting,
# not a tokenizer setting, so it is no longer forwarded here.
classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint)
def rate_title(input_text, model, tokenizer, device='cuda'):
    """Score a single headline with a classification model.

    Args:
        input_text: mapping with at least 'postText' (the headline) and
            'truthClass' (its label); both keys are read during preprocessing.
        model: model whose output exposes ``.logits``. The logits are converted
            with ``float()``, so they are assumed to hold exactly one value
            (e.g. a regression-style head). TODO confirm.
        tokenizer: tokenizer matching ``model``.
        device: device the input tensors are moved to; must match ``model``.

    Returns:
        ``{'predicted_class': <float logit>}``.
    """
    tokenized_input = preprocess_function_title_only_classification(input_text, tokenizer=tokenizer)
    # Give each encoded field a batch dimension of 1. 'labels' is excluded from
    # the forward-pass inputs; only the model's logits are used.
    model_inputs = {
        key: torch.tensor([value]).to(device)
        for key, value in tokenized_input.items()
        if key != 'labels'
    }
    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(**model_inputs).logits
    return {'predicted_class': float(logits)}
def preprocess_function_title_only_classification(examples, tokenizer=None):
    """Encode the headline in ``examples`` and attach its gold label.

    Args:
        examples: mapping providing 'postText' (text to encode) and
            'truthClass' (label to attach).
        tokenizer: callable tokenizer. Despite the ``None`` default it is
            required, because it is always called.

    Returns:
        The tokenizer's output with an extra 'labels' entry equal to
        ``examples['truthClass']``.
    """
    # Truncate to at most 25 tokens; pad only up to the longest input in the call.
    encoding_options = {'padding': "longest", 'truncation': True, 'max_length': 25}
    encoded = tokenizer(examples['postText'], **encoding_options)
    encoded['labels'] = examples['truthClass']
    return encoded
def clickbait_generator(article_content, condition_lambda=5.0):
    """Gradio callback: generate a headline for ``article_content``.

    Args:
        article_content: article text entered in the UI.
        condition_lambda: conditioning strength forwarded to
            generate_clickbait (bound to the UI slider).

    Returns:
        The first generated sequence with '</s>' and '<pad>' markers removed.
    """
    # NOTE(review): `dataset_info` is not assigned anywhere in the visible part
    # of this file. Unless `from constants import *` provides it, calling this
    # function raises NameError — confirm.
    generation_kwargs = dict(
        model=generation_model,
        tokenizer=tokenizer,
        conditioning_model=conditioning_model,
        input_text=[None],
        dataset_info=dataset_info,
        precondition_topk=precondition_topk,
        length_cutoff=length_cutoff,
        condition_lambda=condition_lambda,
        article_content=article_content,
        device=device,
    )
    generated = generate_clickbait(**generation_kwargs)

    headline = generated[0]
    # Strip special tokens that would otherwise show up in the UI output.
    for special_token in ('</s>', '<pad>'):
        headline = headline.replace(special_token, '')
    return headline
# Text shown in the Gradio page header (title/description) and footer (article).
title = "Clickbait generator"
description = """
Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine tuned for our purposes to try and create the news headline you are looking for!
"""
article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based off of."

# The label string describes the article input, so it goes on the Textbox
# component (Interface itself has no `label` option). `article` is passed
# explicitly; the original built it but never handed it to the UI.
# The slider feeds clickbait_generator's `condition_lambda` argument.
app = gr.Interface(
    fn=clickbait_generator,
    inputs=[
        gr.Textbox(label='Article content or paragraph'),
        gr.Slider(0, 100, step=0.1, value=5.0),
    ],
    outputs="text",
    title=title,
    description=description,
    article=article,
)
app.launch()