import gradio as gr
import torch
from minicons import cwe
from huggingface_hub import hf_hub_download
import os
import pandas as pd

from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
def predict(Sentence, Word, LLM, Norm, Layer):
    # Map the UI choices to their Hugging Face model identifiers.
    models = {'BERT': 'bert-base-uncased',
              'ALBERT': 'albert-xxlarge-v2',
              'RoBERTa': 'roberta-base'}
    if Word not in Sentence:
        return "Invalid input: word not in sentence"
    model_name_hf = LLM.lower()
    norm_name_hf = Norm.lower()
    # Contextual word embedding extractor from minicons.
    lm = cwe.CWE(models[LLM])
    # Download the trained feature-norm predictor checkpoint and its label file.
    repo_id = "jwalanthi/semantic-feature-classifiers"
    subfolder = f"{model_name_hf}_models_all"
    name_hf = f"{model_name_hf}_to_{norm_name_hf}_layer{Layer}"
    model_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename=f"{name_hf}.ckpt", use_auth_token=os.environ['TOKEN'])
    label_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename=f"{name_hf}.txt", use_auth_token=os.environ['TOKEN'])
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None
    )
    model.eval()
    with open(label_path, "r") as file:
        labels = [line.rstrip() for line in file.readlines()]
    # Extract the contextual embedding for the target word and predict feature norms.
    data = (Sentence, Word)
    emb = lm.extract_representation(data, layer=Layer)
    pred = torch.nn.functional.relu(model(emb))
    pred_sq = pred.squeeze(0)
    pred_round = torch.round(pred_sq, decimals=2)
    pred_list = pred_round.detach().numpy().tolist()
    # Keep only positively predicted features, sorted from strongest to weakest.
    df = pd.DataFrame({'feature': labels, 'value': pred_list})
    df = df[df['value'] > 0]
    df_sorted = df.sort_values(by='value', ascending=False)
    df_sorted = df_sorted.reset_index(drop=True)
    Output = [row['feature'] + '\t\t\t\t\t\t\t' + str(row['value']) for _, row in df_sorted.iterrows()]
    return "All Positive Predicted Values:\n" + "\n".join(Output)
# Gradio UI: inputs are listed in the same order as predict's parameters.
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["BERT", "ALBERT", "RoBERTa"]),
        gr.Radio(["Binder", "McRae", "Buchanan"]),
        gr.Slider(0, 12, step=1)
    ],
    outputs=["text"],
)
if __name__ == "__main__":
    demo.launch()