import math
import re
import string

import numpy as np
import pandas as pd
import torch
from pyvi.ViTokenizer import tokenize

class BM25:
    def __init__(self, k1=1.5, b=0.75):
        self.b = b
        self.k1 = k1

    def fit(self, corpus):
        """
        Fit the statistics required to calculate the BM25 ranking
        score on the given corpus.

        Parameters
        ----------
        corpus : list[list[str]]
            Each element in the list represents a document, and each
            document is a list of its terms.

        Returns
        -------
        self
        """
        tf = []
        df = {}
        idf = {}
        doc_len = []
        corpus_size = 0
        for document in corpus:
            corpus_size += 1
            doc_len.append(len(document))

            # compute tf (term frequency) per document
            frequencies = {}
            for term in document:
                frequencies[term] = frequencies.get(term, 0) + 1
            tf.append(frequencies)

            # compute df (document frequency) per term
            for term in frequencies:
                df[term] = df.get(term, 0) + 1

        # smoothed idf; the log(1 + ...) form keeps it non-negative
        for term, freq in df.items():
            idf[term] = math.log(1 + (corpus_size - freq + 0.5) / (freq + 0.5))

        self.tf_ = tf
        self.df_ = df
        self.idf_ = idf
        self.doc_len_ = doc_len
        self.corpus_ = corpus
        self.corpus_size_ = corpus_size
        self.avg_doc_len_ = sum(doc_len) / corpus_size
        return self

    def search(self, query):
        """Return a BM25 score for every document in the fitted corpus."""
        return [self._score(query, index) for index in range(self.corpus_size_)]

    def _score(self, query, index):
        score = 0.0
        doc_len = self.doc_len_[index]
        frequencies = self.tf_[index]
        for term in query:
            if term not in frequencies:
                continue
            freq = frequencies[term]
            # standard BM25 term contribution: idf * saturated tf * length norm
            numerator = self.idf_[term] * freq * (self.k1 + 1)
            denominator = freq + self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_len_)
            score += numerator / denominator
        return score
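
# A minimal usage sketch for BM25 (the token lists below are hypothetical,
# shown only to illustrate the fit/search API):
#
#     bm25 = BM25().fit([['sot', 'cao'], ['chay_mau', 'chan_rang']])
#     bm25.search(['chay_mau'])  # one score per document; higher = better match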

class Retrieval:
    def __init__(
        self, k=8,
        model='retrieval/bm25.pt',
        contexts='retrieval/context.pt',
        stop_words='retrieval/stopwords.csv',
        max_len=400,
        docs=None,
    ) -> None:
        self.k = k
        self.max_len = max_len
        data = pd.read_csv(stop_words, sep='\t', encoding='utf-8')
        # keep stopwords in a set: `word in pandas.Series` tests the index,
        # not the values, so a Series-based membership check never filters
        self.list_stopwords = set(data['stopwords'])
        if docs:
            # build the BM25 index from the raw document(s) passed in
            self.tuning(docs)
        else:
            # load a previously fitted index and its contexts
            # (newer PyTorch versions may need weights_only=False here)
            self.bm25 = torch.load(model)
            self.contexts = torch.load(contexts)

    def get_context(self, query='Chảy máu chân răng là bệnh gì?'):
        def clean_text(text):
            text = re.sub(r'<.*?>', '', text).strip()
            text = re.sub(r'(\s)+', r'\1', text)
            return text

        def normalize_text(text):
            # strip punctuation, except '_' which pyvi uses to join compounds
            listpunctuation = string.punctuation.replace('_', '')
            for i in listpunctuation:
                text = text.replace(i, ' ')
            return text.lower()

        def remove_stopword(text):
            return ' '.join(word for word in text.split()
                            if word not in self.list_stopwords)

        def word_segment(sent):
            # Vietnamese word segmentation (compounds are joined with '_')
            return tokenize(sent)

        query = clean_text(query)
        query = word_segment(query)
        query = remove_stopword(normalize_text(query))
        query = query.split()

        scores = self.bm25.search(query)
        scores_index = np.argsort(scores)
        results = []
        # collect the k highest-scoring contexts, best first
        for k in range(1, self.k + 1):
            index = scores_index[-k]
            results.append({'score': scores[index],
                            'index': index,
                            'context': self.contexts[index]})
        return results

    def split(self, document):
        """Split a document into contexts of at most max_len words,
        carrying the previous sentence over as a one-sentence overlap."""
        document = document.replace('\n', ' ')
        document = re.sub(' +', ' ', document)
        sentences = document.split('. ')

        context_list = []
        context = ''
        length = 0
        prev_sentence = ''
        prev_len = 0
        for sentence in sentences:
            sentence += '. '
            sent_len = len(sentence.split())
            if length + sent_len > self.max_len:
                # close the current context and start the next one with
                # the last sentence, so adjacent contexts overlap slightly
                context_list.append(context)
                context = prev_sentence
                length = prev_len
            length += sent_len
            context += sentence
            prev_sentence = sentence
            prev_len = sent_len
        context_list.append(context)

        self.contexts = context_list
        # never return more contexts than exist
        if len(context_list) < self.k:
            self.k = len(context_list)

    def tuning(self, document):
        def clean_text(text):
            text = re.sub(r'<.*?>', '', text).strip()
            text = re.sub(r'(\s)+', r'\1', text)
            return text

        def normalize_text(text):
            listpunctuation = string.punctuation.replace('_', '')
            for i in listpunctuation:
                text = text.replace(i, ' ')
            return text.lower()

        def remove_stopword(text):
            return ' '.join(word for word in text.split()
                            if word not in self.list_stopwords)

        def word_segment(sent):
            return tokenize(sent)

        # chunk the document into contexts, then preprocess each chunk
        self.split(document)
        docs = []
        for content in self.contexts:
            content = clean_text(content)
            content = word_segment(content)
            content = remove_stopword(normalize_text(content))
            docs.append(content)
        print('There are', len(docs), 'contexts')

        texts = [
            [word for word in doc.lower().split()
             if word not in self.list_stopwords]
            for doc in docs
        ]
        self.bm25 = BM25()
        self.bm25.fit(texts)
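
# Minimal end-to-end sketch, assuming retrieval/stopwords.csv is available;
# the sample document below is hypothetical. Building the index from `docs`
# does not require the pickled bm25.pt/context.pt artifacts.
if __name__ == '__main__':
    sample_doc = (
        'Chảy máu chân răng có thể là dấu hiệu của viêm nướu. '
        'Nên khám nha khoa định kỳ để được chẩn đoán chính xác.'
    )
    retriever = Retrieval(k=1, docs=sample_doc)
    for hit in retriever.get_context('Chảy máu chân răng là bệnh gì?'):
        print(round(hit['score'], 3), '|', hit['context'])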