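"""Streamlit app: automatic Korean Standard Industrial Classification (KSIC)
coding service. A fine-tuned XLM-RoBERTa classifier maps free-text job
descriptions to one of 493 industry sub-class codes."""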
import os

import numpy as np
import pandas as pd
import sentencepiece  # required by the XLM-R tokenizer; imported to fail fast if missing
import streamlit as st
import torch
from torch.utils.data import DataLoader, Dataset

# Prepare the model
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer

# [theme]
# base="dark"
# primaryColor="purple"
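# The [theme] block above is Streamlit configuration; it only takes effect
# from .streamlit/config.toml, which is presumably why it is commented out here.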
# Title
st.header('KSIC (Korean Standard Industrial Classification) automatic coding service')
# Cache the loader so the model is not re-created on every Streamlit rerun.
# (The caching decorator is an assumption inferred from the original comment.)
@st.cache(allow_output_mutation=True)
def md_loading():
    ## cpu
    # device = torch.device('cpu')
    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
    model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-large', num_labels=493)

    model_checkpoint = 'base3_44_last.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host
    model.load_state_dict(torch.load(output_model_file, map_location='cpu'))
    # ckpt = torch.load(output_model_file)
    # model.load_state_dict(ckpt['model_state_dict'])

    # device = torch.device("cuda" if torch.cuda.is_available() and not False else "cpu")
    device = torch.device("cpu")
    model.to(device)

    # Look-up tables: label index -> KSIC code, and code -> class name
    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')

    print('ready')

    return tokenizer, model, label_tbl, loc_tbl, device

# Load the model
tokenizer, model, label_tbl, loc_tbl, device = md_loading()
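# Note: base3_44_last.bin, label_table.npy and kisc_table.csv are expected to
# sit next to this script; none of them ship with the base xlm-roberta-large model.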
# Dataset preparation
max_len = 64
# Dataset wrapper that turns the five text fields of each row into model inputs
class TVT_Dataset(Dataset):

    def __init__(self, df):
        self.df_data = df

    def __getitem__(self, index):
        # Pull the input columns for this row
        # sentence = self.df_data.loc[index, 'text']
        sentence = self.df_data.loc[index, ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']]

        # Join the five fields into a single sequence separated by the <s> token
        encoded_dict = tokenizer(
            ' <s> '.join(sentence.to_list()),
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')

        padded_token_list = encoded_dict['input_ids'][0]
        att_mask = encoded_dict['attention_mask'][0]

        # Convert the numeric label to a tensor (training only; unused at inference)
        # target = torch.tensor(self.df_data.loc[index, 'NEW_CD'])

        # Bundle input_ids and attention_mask (plus the label, when training) into one sample
        # sample = (padded_token_list, att_mask, target)
        sample = (padded_token_list, att_mask)

        return sample

    def __len__(self):
        return len(self.df_data)
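# Illustration (hypothetical values, not from the original source): a row like
# ('ACME Co', 'retail', 'sales', 'manager', 'HQ') is encoded by __getitem__ as
# the single string 'ACME Co <s> retail <s> sales <s> manager <s> HQ'.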
# Text input boxes
business = st.text_input('Company name')
business_work = st.text_input('Main business activity')
work_department = st.text_input('Department')
work_position = st.text_input('Job title')
what_do_i = st.text_input('What I do')
# Prepare the data
# Build the test dataset.
# NEW_CD (the label column) does not exist at inference time, so only the five
# input columns are selected.
input_col_type = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']

def preprocess_dataset(dataset):
    dataset.reset_index(drop=True, inplace=True)
    dataset = dataset.fillna('')  # fillna returns a copy; reassign so it takes effect
    return dataset[input_col_type]
## Check the assembled input
# st.write(md_input)

# Button
if st.button('Submit'):
    ## Actions performed when the button is clicked

    ### Data preparation
    # md_input: the values fed to the model
    # md_input = '|'.join([business, business_work, what_do_i, work_position, work_department])
    md_input = [str(business), str(business_work), str(what_do_i), str(work_position), str(work_department)]

    # Each value is wrapped in a list: a dict of bare scalars raises
    # "If using all scalar values, you must pass an index" in pandas
    test_dataset = pd.DataFrame({
        input_col_type[0]: [md_input[0]],
        input_col_type[1]: [md_input[1]],
        input_col_type[2]: [md_input[2]],
        input_col_type[3]: [md_input[3]],
        input_col_type[4]: [md_input[4]]
    })

    # test_dataset = pd.read_csv(DATA_IN_PATH + test_set_name, sep='|', na_filter=False)
    test_dataset = preprocess_dataset(test_dataset)
    print(len(test_dataset))
    print(test_dataset)
    print('about to use base_data_loader')

    test_data = TVT_Dataset(test_dataset)

    train_batch_size = 48

    # Split the data into batches of batch_size
    test_dataloader = DataLoader(test_data,
                                 batch_size=train_batch_size,
                                 shuffle=False)
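    # With a single input row the DataLoader yields exactly one batch,
    # so the prediction loop below runs once.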
    ### Run the model
    # Put model in evaluation mode
    model.eval()
    model.zero_grad()

    # Tracking variables
    predictions, true_labels = [], []

    # Predict
    for batch in test_dataloader:  # iterate over the DataLoader itself, not range()
        # Move the batch to the target device (CPU here)
        batch = tuple(t.to(device) for t in batch)

        # Unpack the inputs from our dataloader
        test_input_ids, test_attention_mask = batch

        # Telling the model not to compute or store gradients, saving memory and
        # speeding up prediction
        with torch.no_grad():
            # Forward pass, calculate logit predictions
            outputs = model(test_input_ids, token_type_ids=None, attention_mask=test_attention_mask)

        logits = outputs.logits
        # Move logits and labels to CPU
        # logits = logits.detach().cpu().numpy()

        # # Single (top-1) prediction
        # arg_idx = torch.argmax(logits, dim=1)
        # print('arg_idx:', arg_idx)
        # num_ans = label_tbl[arg_idx]
        # str_ans = loc_tbl['항목명'][loc_tbl['코드'] == num_ans].values

        # Top-k prediction; flatten() relies on there being one example per batch
        k = 10
        topk_idx = torch.topk(logits.flatten(), k).indices
        num_ans_topk = label_tbl[topk_idx.numpy()]
        # Map each predicted code to its class name
        # ('코드' = code and '항목명' = item name are columns of kisc_table.csv)
        str_ans_topk = [loc_tbl['항목명'][loc_tbl['코드'] == code] for code in num_ans_topk]
        # print(num_ans, str_ans)
        # print(num_ans_topk)

        # print('Company name:', query_tokens[0])
        # print('Main business activity:', query_tokens[1])
        # print('Department:', query_tokens[2])
        # print('Job title:', query_tokens[3])
        # print('What I do:', query_tokens[4])
        # print('Industry code and class:', num_ans, str_ans)

    # ans = ''
    # ans1, ans2, ans3 = '', '', ''

    ## Display the model's results
    # st.write("Industry code and class:", num_ans, str_ans[0])
    # st.write("Sub-class codes")
    # for i in range(k):
    #     st.write(str(i+1) + ' rank:', num_ans_topk[i], str_ans_topk[i].iloc[0])

    # print(num_ans)
    # print(str_ans, type(str_ans))
    # Collect the top-k class names
    str_ans_topk_list = []
    for i in range(k):
        str_ans_topk_list.append(str_ans_topk[i].iloc[0])

    # print(str_ans_topk_list)

    # Render the top-k predictions as a table
    ans_topk_df = pd.DataFrame({
        'NO': range(1, k+1),
        'Sub-class code': num_ans_topk,
        'Sub-class name': str_ans_topk_list
    })
    ans_topk_df = ans_topk_df.set_index('NO')

    st.dataframe(ans_topk_df)
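# To try the app locally (assuming this script is saved as app.py and the
# checkpoint and lookup tables are present):
#   streamlit run app.py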