#!/usr/bin/env python
# coding: utf-8
#### env base_cp
#main_path = "/Users/svjack/temp/gradio_prj/tableQA-Chinese-main"
#main_path = "/User/tableQA-Chinese-main"
#main_path = "/temp/tableQA-Chinese-main"
main_path = "."

import argparse
import ast
import json
import math
import os
import re
import shutil
import sqlite3
import sys
import urllib.request
from ast import literal_eval
from collections import defaultdict, namedtuple
from copy import deepcopy
from functools import partial, reduce
from itertools import combinations, product
from time import sleep
from urllib.parse import quote

import jieba
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import sqlite_utils
import torch
from icecream import ic
from nltk import pos_tag, word_tokenize
from pyarrow.filesystem import LocalFileSystem
from torch import nn
from torch.nn import functional, init
from torch.nn.utils import rnn as rnn_utils
#from keybert import KeyBERT
#from bertopic import BERTopic
#### used for condition extraction during training.
op_sql_dict = {0: ">", 1: "<", 2: "==", 3: "!="}
#### used by the classifier for intent (aggregation) inference
agg_sql_dict = {0: "", 1: "AVG", 2: "MAX", 3: "MIN", 4: "COUNT", 5: "SUM"}
#### finally combined: a single condition for conn 0, multiple conditions joined by 1 ("and") or 2 ("or")
conn_sql_dict = {0: "", 1: "and", 2: "or"}
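# Illustrative sketch (not part of the original pipeline): the three dicts above are meant to be
# assembled into a SQL string roughly as below. `build_sql_sketch` and its argument names are
# assumptions for demonstration only; note the "==" / "!=" ops from op_sql_dict would still need
# mapping to "=" / "<>" for a real SQL engine.
def build_sql_sketch(table_name, question_column, agg, conds, conn):
    """conds: list of (column, op, value) triples; agg / conn: values from the dicts above."""
    select_part = "{}(`{}`)".format(agg, question_column) if agg else "`{}`".format(question_column)
    joiner = " {} ".format(conn if conn else "and")
    where_part = joiner.join("`{}` {} {!r}".format(c, op, v) for c, op, v in conds)
    sql = "SELECT {} FROM `{}`".format(select_part, table_name)
    return sql + (" WHERE " + where_part if where_part else "")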
#### keyword and time-pattern definitions
and_kws = ("且", "而且", "并且", "和", "当中", "同时")
or_kws = ("或", "或者",)
conn_kws = and_kws + or_kws
pattern_list = [u"[年月\.\-\d]+", u"[年月\d]+", u"[年个月\d]+", u"[年月日\d]+"]
time_kws = ("什么时候", "时间", "时候")
sum_count_high_kws = ('多少个', '有几个', '总共') + ('总和', '一共',) + ("总数",)
mean_kws = ('平均数', '均值', '平均值', '平均')
max_kws = ('最大', '最多', '最大值', '最高')
min_kws = ('最少', '最小值', '最小', '最低')
sum_count_low_kws = ('个', '总共') + ('总和', '加', '总', '一共', '和',) + ("哪些", "查", "数量", "数") + ("几",) + ('多少', "多大") + ("总数",)
max_special_kws = ("以上", "大于")
min_special_kws = ("以下", "小于")
qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个")
only_kws_columns = {"城市": "=="}
##### jointbert predict model init start
#jointbert_path = "../../featurize/JointBERT"
#jointbert_path = "/Users/svjack/temp/gradio_prj/tableQA-Chinese-main/JointBERT-master"
jointbert_path = os.path.join(main_path, "JointBERT-master")
sys.path.append(jointbert_path)

from model.modeling_jointbert import JointBERT
from model.modeling_jointbert import *
from trainer import *
from main import *
from data_loader import *

pred_parser = argparse.ArgumentParser()
pred_parser.add_argument("--input_file", default="conds_pred/seq.in", type=str, help="Input file for prediction")
pred_parser.add_argument("--output_file", default="conds_pred/sample_pred_out.txt", type=str, help="Output file for prediction")
#pred_parser.add_argument("--model_dir", default="bert", type=str, help="Path to save, load model")
pred_parser.add_argument("--model_dir", default=os.path.join(main_path, "data/bert"), type=str, help="Path to save, load model")
#pred_parser.add_argument("--model_dir", default=os.path.join(main_path, "JBert_Zh_Condition_Extractor"), type=str, help="Path to save, load model")
pred_parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction")
pred_parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")

pred_parser_config_dict = dict(map(lambda item: (item.option_strings[0].replace("--", ""), item.default), pred_parser.__dict__["_actions"]))
pred_parser_config_dict = dict(filter(lambda t2: t2[0] != "-h", pred_parser_config_dict.items()))
pred_parser_namedtuple = namedtuple("pred_parser_config", pred_parser_config_dict.keys())
for k, v in pred_parser_config_dict.items():
    if type(v) == type(""):
        exec("pred_parser_namedtuple.{}='{}'".format(k, v))
    else:
        exec("pred_parser_namedtuple.{}={}".format(k, v))

from predict import *

pred_config = pred_parser_namedtuple
args = get_args(pred_config)
device = get_device(pred_config)

args_parser_namedtuple = namedtuple("args_config", args.keys())
for k, v in args.items():
    if type(v) == type(""):
        exec("args_parser_namedtuple.{}='{}'".format(k, v))
    else:
        exec("args_parser_namedtuple.{}={}".format(k, v))

args = args_parser_namedtuple
#args.data_dir = "/Users/svjack/temp/gradio_prj/tableQA-Chinese-main/data"
args.data_dir = os.path.join(main_path, "data")

'''
pred_model = MODEL_CLASSES["bert"][1].from_pretrained(args.model_dir,
                                                      args=args,
                                                      intent_label_lst=get_intent_labels(args),
                                                      slot_label_lst=get_slot_labels(args))
'''
pred_model = MODEL_CLASSES["bert"][1].from_pretrained(
    os.path.join(main_path, "data/bert"),
    args=args,
    intent_label_lst=get_intent_labels(args),
    slot_label_lst=get_slot_labels(args))
pred_model.to(device)
pred_model.eval()

intent_label_lst = get_intent_labels(args)
slot_label_lst = get_slot_labels(args)
pad_token_label_id = args.ignore_index
tokenizer = load_tokenizer(args)
## jointbert predict model init end
###### one-sentence condition decomposition start
def predict_single_sent(question):
    text = " ".join(list(question))
    batch = convert_input_file_to_tensor_dataset([text.split(" ")], pred_config, args, tokenizer, pad_token_label_id).tensors
    batch = tuple(t.to(device) for t in batch)
    inputs = {"input_ids": batch[0],
              "attention_mask": batch[1],
              "intent_label_ids": None,
              "slot_labels_ids": None}
    inputs["token_type_ids"] = batch[2]
    outputs = pred_model(**inputs)
    _, (intent_logits, slot_logits) = outputs[:2]
    intent_preds = intent_logits.detach().cpu().numpy()
    slot_preds = slot_logits.detach().cpu().numpy()
    intent_preds = np.argmax(intent_preds, axis=1)
    slot_preds = np.argmax(slot_preds, axis=2)
    all_slot_label_mask = batch[3].detach().cpu().numpy()
    slot_label_map = {i: label for i, label in enumerate(slot_label_lst)}
    slot_preds_list = [[] for _ in range(slot_preds.shape[0])]
    for i in range(slot_preds.shape[0]):
        for j in range(slot_preds.shape[1]):
            if all_slot_label_mask[i, j] != pad_token_label_id:
                slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
    pred_l = []
    for words, slot_pred, intent_pred in zip([text.split(" ")], slot_preds_list, intent_preds):
        line = ""
        for word, pred in zip(words, slot_pred):
            if pred == 'O':
                line = line + word + " "
            else:
                line = line + "[{}:{}] ".format(word, pred)
        pred_l.append((line, intent_label_lst[intent_pred]))
    return pred_l[0]
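# Note: predict_single_sent returns a (tagged_text, intent_label) pair, where tagged_text is the
# character-spaced question with slot tags attached to non-"O" characters (e.g. "[收:B-HEADER]"),
# and intent_label comes from intent_label_lst (used downstream as the condition relation).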
###@@ conn_kws = ["且", "或", "或者", "和"]
'''
and_kws = ("且", "而且", "并且", "和", "当中", "同时")
or_kws = ("或", "或者",)
conn_kws = and_kws + or_kws
'''
#conn_kws = ("且", "或", "或者", "和") + ("而且", "并且", "当中")
#### helper algorithms used by the decomposition below.
def recurrent_extract(question):
    def filter_relation(text):
        #kws = ["且", "或", "或者", "和"]
        kws = conn_kws
        req = text
        for kw in sorted(kws, key=lambda x: len(x))[::-1]:
            req = req.replace(kw, "")
        return req

    def produce_plain_text(text):
        ##### strip tag markup from the tagged text
        kws = ["[", "]", " ", ":B-HEADER", ":I-HEADER", ":B-VALUE", ":I-VALUE"]
        plain_text = text
        for kw in kws:
            plain_text = plain_text.replace(kw, "")
        return plain_text

    def find_min_commmon_strings(c):
        ##### {"jack", "ja", "ss", "sss", "ps", ""} -> {"ja", "ss", "ps"}
        common_strings = list(filter(lambda x: type(x) == type(""),
                                     map(lambda t2: t2[0]
                                         if t2[0] in t2[1]
                                         else (t2[1]
                                               if t2[1] in t2[0]
                                               else (t2[0], t2[1])), combinations(c, 2))))
        req = set([])
        while c:
            ele = c.pop()
            if all(map(lambda cc: cc not in ele, common_strings)):
                req.add(ele)
        req = req.union(set(common_strings))
        return set(filter(lambda x: x, req))

    def extract_scope(scope_text):
        def find_max_in(plain_text, b_chars, i_chars):
            chars = "".join(b_chars + i_chars)
            while chars and chars not in plain_text:
                chars = chars[:-1]
            return chars
        b_header_chars = re.findall(r"([\w\W]):B-HEADER", scope_text)
        i_header_chars = re.findall(r"([\w\W]):I-HEADER", scope_text)
        b_value_chars = re.findall(r"([\w\W]):B-VALUE", scope_text)
        i_value_chars = re.findall(r"([\w\W]):I-VALUE", scope_text)
        if len(b_header_chars) != 1 or len(b_value_chars) != 1:
            return None
        plain_text = produce_plain_text(scope_text)
        header = find_max_in(plain_text, b_header_chars, i_header_chars)
        value = find_max_in(plain_text, b_value_chars, i_value_chars)
        if (not header) or (not value):
            return None
        return (header, value)

    def find_scope(text):
        start_index = text.find("[")
        end_index = text.rfind("]")
        if start_index == -1 or end_index == -1:
            return text
        scope_text = text[start_index: end_index + 1]
        res_text = filter_relation(text.replace(scope_text, "")).replace(" ", "").strip()
        return (scope_text, res_text)

    def produce_all_attribute_remove(req):
        if not req:
            return None
        string_or_t2 = find_scope(req[-1][0])
        assert type(string_or_t2) in [type(""), type((1,))]
        if type(string_or_t2) == type(""):
            return string_or_t2
        else:
            return string_or_t2[-1]

    def extract_all_attribute(req):
        extract_list = list(map(lambda t2: (t2[0][0], t2[1], t2[0][1]),
                                filter(lambda x: x[0],
                                       map(lambda tt2_t2: (extract_scope(tt2_t2[0][0]), tt2_t2[1]),
                                           filter(lambda t2_t2: "HEADER" in t2_t2[0][0] and "VALUE" in t2_t2[0][0],
                                                  filter(lambda string_or_t2_t2: type(string_or_t2_t2[0]) == type((1,)),
                                                         map(lambda tttt2: (find_scope(tttt2[0]), tttt2[1]),
                                                             req)))))))
        return extract_list

    def extract_attributes_relation_string(plain_text, all_attributes, res):
        if not all_attributes:
            return plain_text.replace(res if res else "", "")

        def replace_by_one_l_r(text, t3):
            l, _, r = t3
            ##### build several l/r variants to satisfy the substring constraint
            l0, l1 = l, l
            r0, r1 = r, r
            while l0 and l0 not in text:
                l0 = l0[:-1]
            while l1 and l1 not in text:
                l1 = l1[1:]
            while r0 and r0 not in text:
                r0 = r0[:-1]
            while r1 and r1 not in text:
                r1 = r1[1:]
            if not l or not r:
                return text
            conclusion = set([])
            for l_, r_ in product([l0, l1], [r0, r1]):
                l_r_conclusion = re.findall("({}.*?{})".format(l_, r_), text)
                r_l_conclusion = re.findall("({}.*?{})".format(r_, l_), text)
                conclusion = conclusion.union(set(l_r_conclusion + r_l_conclusion))
            ##### because multiple variants are produced, keep only the shortest matches
            ##    to prevent the "relation word" from being replaced as well.
            conclusion_filtered = find_min_commmon_strings(conclusion)
            conclusion = conclusion_filtered
            req_text = text
            for c in conclusion:
                req_text = req_text.replace(c, "")
            return req_text

        req_text_ = plain_text
        for t3 in all_attributes:
            req_text_ = replace_by_one_l_r(req_text_, t3)
        return req_text_.replace(res, "")

    req = []
    t2 = predict_single_sent(question)
    req.append(t2)
    while "[" in t2[0]:
        scope = find_scope(t2[0])
        if type(scope) == type(""):
            break
        else:
            assert type(scope) == type((1,))
            scope_text, res_text = scope
            #ic(req)
            t2 = predict_single_sent(res_text)
            req.append(t2)
    req = list(filter(lambda tt2: "HEADER" in tt2[0] and "VALUE" in tt2[0], req))
    res = produce_all_attribute_remove(req)
    #ic(req)
    all_attributes = extract_all_attribute(req)
    # plain_text = produce_plain_text(scope_text)
    return all_attributes, res, extract_attributes_relation_string(produce_plain_text(req[0][0] if req else ""), all_attributes, res)
def rec_more_time(decomp):
    assert type(decomp) == type((1,)) and len(decomp) == 3
    assert not decomp[0]
    res, relation_string = decomp[1:]
    new_decomp = recurrent_extract(relation_string)
    #### keep recursing only while it still helps, i.e. while new_decomp[1] != decomp[1]
    if not new_decomp[0] and new_decomp[1] != decomp[1]:
        return rec_more_time(new_decomp)
    return (new_decomp[0], res, new_decomp[1])
### one-sentence condition decomposition end
##### data source start
#train_path = "../TableQA/TableQA/train"
#train_path = "/Users/svjack/temp/gradio_prj/tableQA-Chinese-main/data/TableQA-master/train"
train_path = os.path.join(main_path, "data/TableQA-master/train")

def data_loader(table_json_path=os.path.join(train_path, "train.tables.json"),
                json_path=os.path.join(train_path, "train.json"),
                req_table_num=1):
    assert os.path.exists(table_json_path)
    assert os.path.exists(json_path)
    json_df = pd.read_json(json_path, lines=True)
    all_tables = pd.read_json(table_json_path, lines=True)
    if req_table_num is not None:
        assert type(req_table_num) == type(0) and req_table_num > 0 and req_table_num <= all_tables.shape[0]
    else:
        req_table_num = all_tables.shape[0]
    for i in range(req_table_num):
        #one_table = all_tables.iloc[i]["table"]
        #one_table_df = pd.read_sql("select * from `{}`".format(one_table), train_tables_dump_engine)
        one_table_s = all_tables.iloc[i]
        one_table_df = pd.DataFrame(one_table_s["rows"], columns=one_table_s["header"])
        yield one_table_df, json_df[json_df["table_id"] == one_table_s["id"]]
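# Usage sketch (assuming the TableQA-master files exist under train_path):
#   table_df, questions_df = next(data_loader(req_table_num=1))
# yields one table as a DataFrame together with the questions that reference it.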
## data source end
###### string toolkit start
def findMaxSubString(str1, str2):
    """
    Return the longest common contiguous substring of str1 and str2
    (the returned slice is taken from str2).
    """
    maxSub = 0
    maxSubString = ""
    str1_len = len(str1)
    str2_len = len(str2)
    for i in range(str1_len):
        str1_pos = i
        for j in range(str2_len):
            str2_pos = j
            str1_pos = i
            if str1[str1_pos] != str2[str2_pos]:
                continue
            else:
                while (str1_pos < str1_len) and (str2_pos < str2_len):
                    if str1[str1_pos] == str2[str2_pos]:
                        str1_pos = str1_pos + 1
                        str2_pos = str2_pos + 1
                    else:
                        break
                sub_len = str2_pos - j
                if maxSub < sub_len:
                    maxSub = sub_len
                    maxSubString = str2[j:str2_pos]
    return maxSubString
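# Example: findMaxSubString("周涨幅(%)", "周涨幅大于7") returns "周涨幅",
# the longest common contiguous substring, sliced out of str2.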
def find_min_commmon_strings(c):
    ##### {"jack", "ja", "ss", "sss", "ps", ""} -> {"ja", "ss", "ps"}
    common_strings = list(filter(lambda x: type(x) == type(""),
                                 map(lambda t2: t2[0]
                                     if t2[0] in t2[1]
                                     else (t2[1]
                                           if t2[1] in t2[0]
                                           else (t2[0], t2[1])), combinations(c, 2))))
    req = set([])
    while c:
        ele = c.pop()
        if all(map(lambda cc: cc not in ele, common_strings)):
            req.add(ele)
    req = req.union(set(common_strings))
    return set(filter(lambda x: x, req))
## string toolkit end
###### datetime column match start
#### only object-dtype columns are examined for date-like content
def time_template_extractor(rows_filtered, pattern=u"[年月\.\-\d]+"):
    #re_words = re.compile(u"[年月\.\-\d]+")
    re_words = re.compile(pattern)
    nest_collection = pd.DataFrame(rows_filtered).applymap(lambda x: tuple(sorted(list(re.findall(re_words, x))))).values.tolist()

    def flatten_collection(c):
        if not c:
            return c
        if type(c[0]) == type(""):
            return c
        else:
            c = list(c)
            return flatten_collection(reduce(lambda a, b: a + b, map(list, c)))
    return flatten_collection(nest_collection)

###@@ pattern_list
#pattern_list = [u"[年月\.\-\d]+", u"[年月\d]+", u"[年个月\d]+", u"[年月日\d]+"]
def justify_column_as_datetime(df, threshold=0.8, time_template_extractor=lambda x: x):
    object_columns = list(map(lambda tt2: tt2[0], filter(lambda t2: t2[1].name == "object", dict(df.dtypes).items())))
    time_columns = []
    for col in object_columns:
        input_ = df[[col]].applymap(lambda x: "~" if type(x) != type("") else x)
        output_ = time_template_extractor(input_.values.tolist())
        input_ = input_.iloc[:, 0].values.tolist()
        time_evidence_cnt = sum(map(lambda t2: t2[0].strip() == t2[1].strip() and t2[0] and t2[1] and t2[0] != "~" and t2[1] != "~", zip(input_, output_)))
        if time_evidence_cnt > 0 and time_evidence_cnt / df.shape[0] >= threshold:
            #### use an evidence ratio because the data may contain some noise
            time_columns.append(col)
    return time_columns

def justify_column_as_datetime_reduce(df, threshold=0.8, time_template_extractor_list=list(map(lambda p: partial(time_template_extractor, pattern=p), pattern_list))):
    return sorted(reduce(lambda a, b: a.union(b), map(lambda func: set(justify_column_as_datetime(df, threshold, func)), time_template_extractor_list)))
## datetime column match end
##### question-column selection; a keyword-reduce wrapper (choose_res_by_kws) defined below feeds this
##### this is a tiny first version
###@@ time_kws = ("什么时候", "时间", "时候")
#time_kws = ("什么时候", "时间", "时候")
#####
def choose_question_column(decomp, header, df):
    assert type(decomp) == type((1,)) and type(header) == type([])
    time_columns = justify_column_as_datetime_reduce(df)
    _, res, _ = decomp
    if type(res) != type(""):
        return None
    #ic(res)
    ##### time keywords are checked first
    #time_kws = ("什么时候", "时间", "时候")
    if any(map(lambda t_kw: t_kw in res, time_kws)):
        if len(time_columns) == 1:
            return time_columns[0]
        else:
            '''
            return sorted(map(lambda t_col: (t_col, len(findMaxSubString(t_col, res)) / len(t_col)), time_columns),
                          key=lambda t2: t2[1])[::-1][0][0]
            '''
            sort_list = sorted(map(lambda t_col: (t_col, len(findMaxSubString(t_col, res)) / len(t_col)), time_columns),
                               key=lambda t2: t2[1])[::-1]
            if sort_list:
                if sort_list[0]:
                    return sort_list[0][0]
            return None
    c_res_common_dict = dict(filter(lambda t2: t2[1], map(lambda c: (c, findMaxSubString(c, res)), header)))
    common_ratio_c_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(t2[0])), c_res_common_dict.items()))
    common_ratio_res_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(res)), c_res_common_dict.items()))
    #ic(decomp)
    #ic(common_ratio_c_dict)
    #ic(common_ratio_res_dict)
    if not common_ratio_c_dict or not common_ratio_res_dict:
        return None
    dict_0_max_key = sorted(common_ratio_c_dict.items(), key=lambda t2: t2[1])[::-1][0][0]
    dict_1_max_key = sorted(common_ratio_res_dict.items(), key=lambda t2: t2[1])[::-1][0][0]
    return dict_0_max_key if dict_0_max_key == dict_1_max_key else None
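# choose_question_column picks the header whose longest common substring with the residual question
# text ranks first by both the column-side and the question-side overlap ratio; if the two rankings
# disagree (or nothing overlaps) it returns None.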
##### agg-classifier start
'''
sum_count_high_kws = ('多少个', '有几个', '总共') + ('总和','一共',) + ("总数",)
mean_kws = ('平均数', '均值', '平均值', '平均')
max_kws = ('最大', '最多', '最大值', '最高')
min_kws = ('最少', '最小值', '最小', '最低')
sum_count_low_kws = ('个', '总共') + ('总和','加','总','一共','和',) + ("哪些", "查", "数量", "数") + ("几",) + ('多少', "多大") + ("总数",)
max_special_kws = ("以上", "大于")
min_special_kws = ("以下", "小于")
'''
###@@ sum_count_high_kws = ('多少个', '有几个', '总共') + ('总和','一共',) + ("总数",)
###@@ mean_kws = ('平均数', '均值', '平均值', '平均')
###@@ max_kws = ('最大', '最多', '最大值', '最高')
###@@ min_kws = ('最少', '最小值', '最小', '最低')
###@@ sum_count_low_kws = ('个', '总共') + ('总和','加','总','一共','和',) + ("哪些", "查", "数量", "数") + ("几",) + ('多少', "多大") + ("总数",)
###@@ max_special_kws = ("以上", "大于")
###@@ min_special_kws = ("以下", "小于")
def simple_label_func(s, drop_header=True):
    text_tokens = s.question_cut
    header = list(map(lambda x: x[:x.find("(")] if (not x.startswith("(") and x.endswith(")")) else x, s.header.split(",")))
    #### samples that do not contain the keywords may fail the fuzzy match, e.g. special
    ### finance column names, or "3" written as "三"
    '''
    fit_collection = ('多少个', '有几个', '总共') + ('总和','一共',) + ('平均数', '均值', '平均值', '平均') + ('最大', '最多', '最大值', '最高') + ('最少', '最小值', '最小', '最低')
    '''
    fit_collection = sum_count_high_kws + mean_kws + max_kws + min_kws
    fit_header = []
    for c in header:
        for kw in fit_collection:
            if kw in c:
                start_idx = c.find(kw)
                end_idx = start_idx + len(kw)
                fit_header.append(c[start_idx: end_idx])
    if not drop_header:
        header = []
        fit_header = []
    input_ = "".join(text_tokens)
    for c in header + fit_header:
        if c in fit_collection:
            continue
        input_ = input_.replace(c, "")
        c0, c1 = c, c
        while c0 and c0 not in fit_collection and len(c0) >= 4:
            c0 = c0[1:]
            if c0 in fit_collection:
                break
            input_ = input_.replace(c0, "")
        while c1 and c1 not in fit_collection and len(c1) >= 4:
            c1 = c1[:-1]
            if c1 in fit_collection:
                break
            input_ = input_.replace(c1, "")
    #ic(input_)
    text_tokens = list(jieba.cut(input_))
    #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',) + ("哪些", "查", "数量")
    #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',)
    ##### high-confidence part (also the criterion for whether the special rules apply)
    #### case 2: high-confidence valid match
    #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',)
    #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',) + ("总数",)
    cat_6_collection_high_level = sum_count_high_kws
    if any(map(lambda high_level_token: high_level_token in "".join(text_tokens), cat_6_collection_high_level)):
        return 6
    #### "deep and wide" rule part; the order could be changed based on the header: if the header contains the keywords, lower the priority
    if any(map(lambda kw: kw in text_tokens, mean_kws)):
        return 1
    if any(map(lambda kw: kw in text_tokens, max_kws)):
        return 2
    if any(map(lambda kw: kw in text_tokens, min_kws)):
        return 3
    ##### low-confidence part
    #### case 2: low-confidence tail match
    cat_6_collection = sum_count_low_kws
    if any(map(lambda kw: kw in text_tokens, cat_6_collection)):
        return 6
    if any(map(lambda token: "几" in token, text_tokens)):
        return 6
    #### special cases
    if any(map(lambda kw: kw in text_tokens, max_special_kws)):
        return 2
    if any(map(lambda kw: kw in text_tokens, min_special_kws)):
        return 3
    #### no valid match
    return 0
def simple_special_func(s, drop_header=True):
    text_tokens = s.question_cut
    header = list(map(lambda x: x[:x.find("(")] if (not x.startswith("(") and x.endswith(")")) else x, s.header.split(",")))
    #### samples that do not contain the keywords may fail the fuzzy match, e.g. special
    ### finance column names, or "3" written as "三"
    fit_collection = sum_count_high_kws + mean_kws + max_kws + min_kws
    fit_header = []
    for c in header:
        for kw in fit_collection:
            if kw in c:
                start_idx = c.find(kw)
                end_idx = start_idx + len(kw)
                fit_header.append(c[start_idx: end_idx])
    input_ = "".join(text_tokens)
    if not drop_header:
        header = []
        fit_header = []
    for c in header + fit_header:
        if c in fit_collection:
            continue
        input_ = input_.replace(c, "")
        c0, c1 = c, c
        while c0 and c0 not in fit_collection and len(c0) >= 4:
            c0 = c0[1:]
            if c0 in fit_collection:
                break
            input_ = input_.replace(c0, "")
        while c1 and c1 not in fit_collection and len(c1) >= 4:
            c1 = c1[:-1]
            if c1 in fit_collection:
                break
            input_ = input_.replace(c1, "")
    #ic(input_)
    text_tokens = list(jieba.cut(input_))
    #ic(text_tokens)
    #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',) + ("哪些", "查", "数量")
    #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',)
    #### case 2: high-confidence valid match
    cat_6_collection_high_level = sum_count_high_kws
    if any(map(lambda high_level_token: high_level_token in "".join(text_tokens), cat_6_collection_high_level)):
        return 6
    #### "deep and wide" rule part; the order could be changed based on the header: if the header contains the keywords, lower the priority
    if any(map(lambda kw: kw in text_tokens, mean_kws)):
        return 1
    if any(map(lambda kw: kw in text_tokens, max_kws)):
        return 2
    if any(map(lambda kw: kw in text_tokens, min_kws)):
        return 3
    return 0

def simple_total_label_func(s):
    is_special = False if simple_special_func(s) == 0 else True
    if not is_special:
        return 0
    return simple_label_func(s)
## agg-classifier end
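# Illustrative call (inputs assumed for demonstration, output not from a real run):
#   s = pd.Series({"question_cut": list(jieba.cut("周涨幅大于7的涨股总共有多少个?")),
#                  "header": "涨股,周涨幅(%)"})
#   simple_total_label_func(s)  # expected to yield 6, since "总共" / "多少个" hit sum_count_high_kws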
##### main processing block start
def split_by_cond(s, extract_return=True):
    def recurrent_extract_cond(text):
        #return np.asarray(recurrent_extract(text)[0])
        #return recurrent_extract(text)[0]
        return np.asarray(list(recurrent_extract(text)[0]))

    question = s["question"]
    res = s["rec_decomp"][1]
    if question is None:
        question = ""
    if res is None:
        res = ""
    common_res = findMaxSubString(question, res)
    #cond_kws = ("或", "而且", "并且", "当中")
    cond_kws = conn_kws
    condition_part = question.replace(common_res, "")
    fit_kws = set([])
    for kw in cond_kws:
        if kw in condition_part and not condition_part.startswith(kw) and not condition_part.endswith(kw):
            fit_kws.add(kw)
    if not fit_kws:
        will_return = ([condition_part.replace(" ", "") + " " + common_res], "")
        if extract_return:
            #return (list(map(recurrent_extract_cond, will_return[0])), will_return[1])
            will_return = (np.asarray(list(map(recurrent_extract_cond, will_return[0]))), will_return[1])
            #will_return = (np.concatenate(list(filter(lambda x: x.size, map(np.asarray, will_return[0].tolist()))), axis=0), will_return[1])
            #will_return = (np.concatenate(list(map(np.asarray, will_return[0].tolist())), axis=0), will_return[1])
            input_ = list(filter(lambda x: x.size, map(np.asarray, will_return[0].tolist())))
            if input_:
                will_return = (np.concatenate(input_, axis=0), will_return[1])
            else:
                will_return = (np.empty((0, 3)), will_return[1])
            will_return = will_return[0].reshape((-1, 3)) if will_return[0].size else np.empty((0, 3))
        return will_return
    fit_kw = sorted(fit_kws, key=len)[::-1][0]
    condition_part_splits = condition_part.split(fit_kw)
    #if fit_kw in ("或",):
    if fit_kw in or_kws:
        fit_kw = "or"
    #elif fit_kw in ("而且", "并且", "当中",):
    elif fit_kw in and_kws:
        fit_kw = "and"
    else:
        fit_kw = ""
    will_return = (list(map(lambda cond_: cond_.replace(" ", "") + " " + common_res, condition_part_splits)), fit_kw)
    if extract_return:
        #return (list(map(recurrent_extract_cond, will_return[0])), will_return[1])
        will_return = (np.asarray(list(map(recurrent_extract_cond, will_return[0]))), will_return[1])
        #ic(will_return[0])
        #will_return = (np.concatenate(list(map(np.asarray, will_return[0].tolist())), axis=0), will_return[1])
        input_ = list(filter(lambda x: x.size, map(np.asarray, will_return[0].tolist())))
        if input_:
            will_return = (np.concatenate(input_, axis=0), will_return[1])
        else:
            will_return = (np.empty((0, 3)), will_return[1])
        #ic(will_return[0])
        will_return = will_return[0].reshape((-1, 3)) if will_return[0].size else np.empty((0, 3))
    return will_return
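# split_by_cond(s) returns an (N, 3) numpy array of (header, relation, value) triples, one per
# extracted condition clause; with extract_return=False it instead returns
# (clause_string_list, conn) where conn is "", "and" or "or".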
def filter_total_conds(s, df, condition_fit_num=0):
    assert condition_fit_num >= 0 and type(condition_fit_num) == type(0)
    df = df.copy()
    #### some columns contain only "None" cells and are not typed as float; transform them into float as well
    df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x)

    def justify_column_as_float(s):
        if "float" in str(s.dtype):
            return True
        if all(s.map(type).map(lambda tx: "float" in str(tx))):
            return True
        return False

    float_cols = list(map(lambda tt2: tt2[0], filter(lambda t2: t2[1], df.apply(justify_column_as_float, axis=0).to_dict().items())))
    for f_col in float_cols:
        df[f_col] = df[f_col].astype(np.float64)
    ###
    header = df.columns.tolist()
    units_cols = filter(lambda c: "(" in c and c.endswith(")"), df.columns.tolist())
    if not float_cols:
        float_discribe_df = pd.DataFrame()
    else:
        float_discribe_df = df[float_cols].describe()

    def call_eval(val):
        try:
            return literal_eval(val)
        except:
            return val

    #### find the condition column in the same way as the question column
    def find_cond_col(res, header):
        #ic(res, header)
        c_res_common_dict = dict(filter(lambda t2: t2[1], map(lambda c: (c, findMaxSubString(c, res)), header)))
        #ic(c_res_common_dict)
        common_ratio_c_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(t2[0])), c_res_common_dict.items()))
        common_ratio_res_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(res)), c_res_common_dict.items()))
        if not common_ratio_c_dict or not common_ratio_res_dict:
            return None
        dict_0_max_key = sorted(common_ratio_c_dict.items(), key=lambda t2: t2[1])[::-1][0][0]
        dict_1_max_key = sorted(common_ratio_res_dict.items(), key=lambda t2: t2[1])[::-1][0][0]
        return dict_0_max_key if dict_0_max_key == dict_1_max_key else None
    ###

    #### check type compatibility between the column and the value, and that fit_num matches
    def filter_cond_col(cond_t3):
        assert type(cond_t3) == type((1,)) and len(cond_t3) == 3
        col, _, value = cond_t3
        if type(value) == type(""):
            value = call_eval(value)
        if col not in df.columns.tolist():
            return False
        #### column / value type compatibility
        if col in float_cols and type(value) not in (type(0), type(0.0)):
            return False
        if col not in float_cols and type(value) in (type(0), type(0.0)):
            return False
        #### reject a string value that does not occur in the corresponding column
        if col not in float_cols and type(value) not in (type(0), type(0.0)):
            if type(value) == type(""):
                if value not in df[col].tolist():
                    return False
        if type(value) in (type(0), type(0.0)):
            if col in float_discribe_df.columns.tolist():
                if condition_fit_num > 0:
                    if value >= float_discribe_df[col].loc["min"] and value <= float_discribe_df[col].loc["max"]:
                        return True
                    else:
                        return False
                else:
                    assert condition_fit_num == 0
                    return True
            if condition_fit_num > 0:
                if value in df[col].tolist():
                    return True
                else:
                    return False
            else:
                assert condition_fit_num == 0
                return True
        return True
    ###

    #### conditions on the same column may conflict; choose the nearest one, by stats for float columns or
    ### by the common substring as find_cond_col does.
    def same_column_cond_filter(cond_list, sort_stats="mean"):
        #ic(cond_list)
        if len(cond_list) == len(set(map(lambda t3: t3[0], cond_list))):
            return cond_list
        req = defaultdict(list)
        for t3 in cond_list:
            req[t3[0]].append(t3[1:])

        def t2_list_sort(col_name, t2_list):
            if not t2_list:
                return None
            t2 = None
            if col_name in float_cols:
                t2 = sorted(t2_list, key=lambda t2: np.abs(t2[1] - float_discribe_df[col_name].loc[sort_stats]))[0]
            else:
                if all(map(lambda t2: type(t2[1]) == type(""), t2_list)):
                    col_val_cnt_df = df[col_name].value_counts().reset_index()
                    col_val_cnt_df.columns = ["val", "cnt"]
                    #col_val_cnt_df["val"].map(lambda x: sorted(filter(lambda tt2: tt2[-1], map(lambda t2: (t2, len(findMaxSubString(x, t2[1]))), t2_list)),
                    #                                           key=lambda ttt2: -1 * ttt2[-1])[0])
                    t2_list_map_to_column_val = list(filter(lambda x: x[1], map(lambda t2: (t2[0], find_cond_col(t2[1], list(set(col_val_cnt_df["val"].values.tolist())))), t2_list)))
                    if t2_list_map_to_column_val:
                        #### return the value with the longest fit in the column
                        t2 = sorted(t2_list_map_to_column_val, key=lambda t2: -1 * len(t2[1]))[0]
            if t2 is None and t2_list:
                t2 = t2_list[0]
            return t2

        cond_list_filtered = list(map(lambda ttt2: (ttt2[0], ttt2[1][0], ttt2[1][1]),
                                      filter(lambda tt2: tt2[1], map(lambda t2: (t2[0], t2_list_sort(t2[0], t2[1])), req.items()))))
        return cond_list_filtered
    ###

    total_conds_map_to_column = list(map(lambda t3: (find_cond_col(t3[0], header), t3[1], t3[2]), s["total_conds"]))
    total_conds_map_to_column_filtered = list(filter(filter_cond_col, total_conds_map_to_column))
    total_conds_map_to_column_filtered = sorted(set(map(lambda t3: (t3[0], t3[1], call_eval(t3[2]) if type(t3[2]) == type("") else t3[2]), total_conds_map_to_column_filtered)))
    #ic(total_conds_map_to_column_filtered)
    cp_cond_list = list(filter(lambda t3: t3[1] in (">", "<"), total_conds_map_to_column_filtered))
    eq_cond_list = list(filter(lambda t3: t3[1] in ("==", "!="), total_conds_map_to_column_filtered))
    cp_cond_list_filtered = same_column_cond_filter(cp_cond_list)
    #total_conds_map_to_column_filtered = same_column_cond_filter(total_conds_map_to_column_filtered)
    return cp_cond_list_filtered + eq_cond_list
    #return total_conds_map_to_column_filtered
###@@ only_kws_columns = {"城市": "=="}
#### this function is only used for conditions that JointBert cannot extract,
### e.g. because the question does not contain the column string (such as "城市") or the keyword is hard to extract.
### a keyword column is defined as a column whose cells are all strings.
### the relation for each column is configurable; if in the future the relation should be
### extracted automatically, it could be read from the head or tail string by editing the pattern "\w?{}\w?"
### ("是" or "不是" can be extracted in this manner).
def augment_kw_in_question(question_df, df, only_kws_in_string=[]):
    #### keep only_kws_in_string empty to retain every condition
    question_df = question_df.copy()
    #df = df.copy()

    def call_eval(val):
        try:
            return literal_eval(val)
        except:
            return val

    df = df.copy()
    df = df.applymap(call_eval)
    #### some columns contain only "None" cells and are not typed as float; transform them into float as well
    df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x)

    def justify_column_as_float(s):
        if "float" in str(s.dtype):
            return True
        if all(s.map(type).map(lambda tx: "float" in str(tx))):
            return True
        return False

    float_cols = list(map(lambda tt2: tt2[0], filter(lambda t2: t2[1], df.apply(justify_column_as_float, axis=0).to_dict().items())))
    #obj_cols = set(df.columns.tolist()).difference(set(float_cols))

    def justify_column_as_kw(s):
        if all(s.map(type).map(lambda tx: "str" in str(tx))):
            return True
        return False

    obj_cols = list(map(lambda tt2: tt2[0], filter(lambda t2: t2[1], df.apply(justify_column_as_kw, axis=0).to_dict().items())))
    obj_cols = list(set(obj_cols).difference(set(float_cols)))
    if only_kws_columns:
        obj_cols = list(set(obj_cols).intersection(set(only_kws_columns.keys())))
    #replace_format = "{}是{}"
    #kw_augmented_df = pd.DataFrame(df[obj_cols].apply(lambda s: list(map(lambda val: (val, replace_format.format(s.name, val)), s.tolist())), axis=0).values.tolist())
    #kw_augmented_df.columns = obj_cols
    kw_augmented_df = df[obj_cols].copy()
    #ic(kw_augmented_df)

    def extract_question_kws(question):
        if not kw_augmented_df.size:
            return []
        req = defaultdict(set)
        for ridx, r in kw_augmented_df.iterrows():
            for k, v in dict(r).items():
                if v in question:
                    pattern = "\w?{}\w?".format(v)
                    all_match = re.findall(pattern, question)
                    #req = req.union(set(all_match))
                    #req[v] = req[v].union(set(all_match))
                    key = "{}~{}".format(k, v)
                    req[key] = req[key].union(set(all_match))
                    #ic(k, v)
                    #question = question.replace(v[0], v[1])
        #return question.replace(replace_format.format("", "") * 2, replace_format.format("", ""))
        #req = list(req)
        if only_kws_in_string:
            req = list(map(lambda tt2: tt2[0], filter(lambda t2: sum(map(lambda kw: sum(map(lambda t: kw in t, t2[1])), only_kws_in_string)), req.items())))
        else:
            req = list(set(req.keys()))

        def req_to_t3(req_string, relation="=="):
            assert "~" in req_string
            left, right = req_string.split("~")
            if left in only_kws_columns:
                relation = only_kws_columns[left]
            return (left, relation, right)
        if not req:
            return []
        return list(map(req_to_t3, req))
        #return req

    question_df["question_kw_conds"] = question_df["question"].map(extract_question_kws)
    return question_df
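# augment_kw_in_question adds a "question_kw_conds" column to question_df: for every keyword column
# (see only_kws_columns) whose cell value appears verbatim in the question, a (column, relation, value)
# triple is recorded, e.g. ("城市", "==", "宁波").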
def choose_question_column_by_rm_conds(s, df):
    question = s.question
    total_conds_filtered = s.total_conds_filtered
    #cond_kws = ("或", "而且", "并且", "当中")
    cond_kws = conn_kws
    stopwords = ("是",)
    #ic(total_conds_filtered)

    def construct_res(question):
        for k, _, v in total_conds_filtered:
            if "(" in k:
                k = k[:k.find("(")]
            #ic(k)
            question = question.replace(str(k), "")
            question = question.replace(str(v), "")
        for w in cond_kws + stopwords:
            question = question.replace(w, "")
        return question

    res = construct_res(question)
    decomp = (None, res, None)
    return choose_question_column(decomp, df.columns.tolist(), df)
def split_qst_by_kw(question, kw="的"):
    return question.split(kw)

#qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个")
###@@ qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个")
def choose_res_by_kws(question):
    #kws = ["的","多少", '是']
    question = question.replace(" ", "")
    #kws = ["的","或者","或", "且","并且","同时"]
    kws = ("的",) + conn_kws
    kws = list(kws)

    def qst_kw_filter(text):
        #qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个")
        if any(map(lambda kw: kw in text, qst_kws)):
            return True
        return False

    kws_cp = deepcopy(kws)
    qst_c = set(question.split(","))
    while kws:
        kw = kws.pop()
        qst_c = qst_c.union(set(filter(qst_kw_filter, reduce(lambda a, b: a + b, map(lambda q: split_qst_by_kw(q, kw), qst_c)))))
        #print("-" * 10)
        #print(sorted(filter(lambda x: x and (x not in kws_cp), qst_c), key=len))
        #print(sorted(filter(lambda x: x and (x not in kws_cp) and qst_kw_filter(x), qst_c), key=len))
    #### final filter: keep only candidates that contain a question keyword
    return sorted(filter(lambda x: x and (x not in kws_cp) and qst_kw_filter(x), qst_c), key=len)
    #return sorted(filter(lambda x: x and (x not in kws_cp) and True, qst_c), key=len)
def cat6_to_45_by_column_type(s, df):
    agg_pred = s.agg_pred
    if agg_pred != 6:
        return agg_pred
    question_column = s.question_column

    def call_eval(val):
        try:
            return literal_eval(val)
        except:
            return val

    df = df.copy()
    df = df.applymap(call_eval)
    #### some columns contain only "None" cells and are not typed as float; transform them into float as well
    df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x)

    def justify_column_as_float(s):
        if "float" in str(s.dtype):
            return True
        if all(s.map(type).map(lambda tx: "float" in str(tx))):
            return True
        return False

    float_cols = list(map(lambda tt2: tt2[0], filter(lambda t2: t2[1], df.apply(justify_column_as_float, axis=0).to_dict().items())))
    #obj_cols = set(df.columns.tolist()).difference(set(float_cols))

    def justify_column_as_kw(s):
        if all(s.map(type).map(lambda tx: "str" in str(tx))):
            return True
        return False

    #obj_cols = list(map(lambda tt2: tt2[0], filter(lambda t2: t2[1], df.apply(justify_column_as_kw, axis=0).to_dict().items())))
    obj_cols = df.columns.tolist()
    obj_cols = list(set(obj_cols).difference(set(float_cols)))
    #ic(obj_cols, float_cols, df.columns.tolist())
    assert len(obj_cols) + len(float_cols) == df.shape[1]
    if question_column in obj_cols:
        return 4
    elif question_column in float_cols:
        return 5
    else:
        return 0
def full_before_cat_decomp(df, question_df, only_req_columns=False):
    df, question_df = df.copy(), question_df.copy()
    first_train_question_extract_df = pd.DataFrame(question_df["question"].map(lambda question: (question, recurrent_extract(question))).tolist())
    first_train_question_extract_df.columns = ["question", "decomp"]
    first_train_question_extract_df = augment_kw_in_question(first_train_question_extract_df, df)
    first_train_question_extract_df["rec_decomp"] = first_train_question_extract_df["decomp"].map(lambda decomp: decomp if decomp[0] else rec_more_time(decomp))
    #return first_train_question_extract_df.copy()
    first_train_question_extract_df["question_cut"] = first_train_question_extract_df["rec_decomp"].map(lambda t3: jieba.cut(t3[1]) if t3[1] is not None else []).map(list)
    first_train_question_extract_df["header"] = ",".join(df.columns.tolist())
    first_train_question_extract_df["question_column_res"] = first_train_question_extract_df["rec_decomp"].map(lambda decomp: choose_question_column(decomp, df.columns.tolist(), df))
    #### agg
    first_train_question_extract_df["agg_res_pred"] = first_train_question_extract_df.apply(simple_total_label_func, axis=1)
    first_train_question_extract_df["question_cut"] = first_train_question_extract_df["question"].map(jieba.cut).map(list)
    first_train_question_extract_df["agg_qst_pred"] = first_train_question_extract_df.apply(simple_total_label_func, axis=1)
    ### if agg_res_pred and agg_qst_pred conflict, take the max, to prevent an empty agg caused by an erroneous question column.
    ### this "max" could also be replaced by measuring the performance of the decomposition part and choosing the better one,
    ### or agg_qst_pred alone could be used, given its high balanced score of about 0.86+ on the imbalanced dataset;
    ### which operation to use needs some discussion.
    ### (balanced_accuracy_score(lookup_df["sql"], lookup_df["agg_pred"]),
    ###  balanced_accuracy_score(lookup_df["sql"], lookup_df["agg_res_pred"]),
    ###  balanced_accuracy_score(lookup_df["sql"], lookup_df["agg_qst_pred"]))
    ### (0.9444444444444445, 0.861111111111111, 0.9444444444444445) first_train_df conclusion
    ### (1.0, 0.8333333333333333, 1.0) cat6 conclusion
    ### this shows that the res variant is worse in the cat6 situation, but the agg-classifier construction is sufficient for a
    ### good conclusion on the full question (because cat6 is the most accurate part in the Tupledire tree sense),
    ### so taking the max is used as the best option.
    first_train_question_extract_df["agg_pred"] = first_train_question_extract_df.apply(lambda s: max(s.agg_res_pred, s.agg_qst_pred), axis=1)
    #### conn and conds
    first_train_question_extract_df["conds"] = first_train_question_extract_df["rec_decomp"].map(lambda x: x[0])
    first_train_question_extract_df["split_conds"] = first_train_question_extract_df.apply(split_by_cond, axis=1).values.tolist()
    first_train_question_extract_df["conn_pred"] = first_train_question_extract_df.apply(lambda s: split_by_cond(s, extract_return=False), axis=1).map(lambda x: x[-1]).values.tolist()
    #first_train_question_extract_df["total_conds"] = first_train_question_extract_df.apply(lambda s: list(set(map(tuple, s["conds"] + s["split_conds"].tolist()))), axis=1).values.tolist()
    first_train_question_extract_df["total_conds"] = first_train_question_extract_df.apply(lambda s: list(set(map(tuple, s["question_kw_conds"] + s["conds"] + s["split_conds"].tolist()))), axis=1).values.tolist()
    first_train_question_extract_df["total_conds_filtered"] = first_train_question_extract_df.apply(lambda s: filter_total_conds(s, df), axis=1).values.tolist()
    #### question_column_res is more accurate; when it does not fit, fall back to the full-question question_column_qst,
    ### which cannot handle multi-part questions, fuzzy descriptions, or questions that need keyword replacement.
    #first_train_question_extract_df["question_column_qst"] = first_train_question_extract_df.apply(lambda s: choose_question_column_by_rm_conds(s, df), axis=1)
    first_train_question_extract_df["question_column_qst"] = first_train_question_extract_df["question"].map(choose_res_by_kws).map(lambda res_list: list(filter(lambda x: x, map(lambda res: choose_question_column((None, res, None), df.columns.tolist(), df), res_list)))).map(lambda x: x[0] if x else None)
    first_train_question_extract_df["question_column"] = first_train_question_extract_df.apply(lambda s: s.question_column_res if s.question_column_res is not None else s.question_column_qst, axis=1)
    #### map category 6 to 4 or 5 based on question_column and the column dtype;
    #### this may perform badly if question_column is wrong,
    #### and is almost 100% accurate when question_column is correctly provided and the question is well posed.
    agg_sql_dict = {0: "", 1: "AVG", 2: "MAX", 3: "MIN", 4: "COUNT", 5: "SUM"}
    first_train_question_extract_df["agg_pred"] = first_train_question_extract_df.apply(lambda s: cat6_to_45_by_column_type(s, df), axis=1).map(lambda x: agg_sql_dict[x])
    if only_req_columns:
        return first_train_question_extract_df[["question",
                                                "total_conds_filtered",
                                                "conn_pred",
                                                "question_column",
                                                "agg_pred"]].copy()
    return first_train_question_extract_df.copy()
if __name__ == "__main__":
    ###### validation block
    req = list(data_loader(req_table_num=None))

    train_df, _ = req[2]
    train_df

    question = "哪些股票的收盘价大于20?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    #### nested SQL such as
    #### select 股票 from table where 市值 = (select max(市值) from table)
    #### is not supported.
    question = "哪个股票代码市值最高?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "市值的最大值是多少?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "EPS大于0的股票有哪些?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "EPS大于0且周涨跌大于5的平均市值是多少?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    train_df, _ = req[5]
    train_df

    question = "产能小于20、销量大于40而且市场占有率超过1的公司有哪些?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    #### special separator "、"
    question = "产能小于20而且销量大于40而且市场占有率超过1的公司有哪些?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    train_df, _ = req[6]
    train_df

    #### to add a column alias, simply copy the column
    question = "投资评级为维持的名称有哪些?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    train_df["公司"] = train_df["名称"]
    question = "投资评级为维持的公司有哪些?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "投资评级为维持而且变动为增持的公司有哪些?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "投资评级为维持或者变动为增持的公司有哪些?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "投资评级为维持或者变动为增持的平均收盘价是多少?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    train_df, _ = req[7]
    train_df

    question = "宁波的一手房每周交易数据上周成交量是多少?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "一手房每周交易数据为宁波上周成交量是多少?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    #### columns addressed via keywords can also be handled this way;
    ### see function augment_kw_in_question
    train_df["城市"] = train_df["一手房每周交易数据"]
    question = "一手房每周交易数据为宁波上周成交量是多少?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "王翔知道宁波一手房的当月累计成交量是多少吗?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "王翔知道上周成交量大于50的最大同比当月是多少吗?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    train_df, _ = req[9]
    #### the last column should be "周跌幅"; duplicated column names cannot be handled
    train_df
    cols = train_df.columns.tolist()
    cols[-1] = "周跌幅(%)"
    train_df.columns = cols

    question = "周涨幅大于7的涨股有哪些?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    #### not recognized as category 6 by the agg-classifier
    question = "周涨幅大于7的涨股总数是多少?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))

    question = "周涨幅大于7的涨股总共有多少个?"
    qs_df = pd.DataFrame([[question]], columns=["question"])
    ic(question)
    ic(full_before_cat_decomp(train_df, qs_df, only_req_columns=True))