Spaces:
Runtime error
Runtime error
| from langchain.document_loaders.unstructured import UnstructuredFileLoader | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain.embeddings import OpenAIEmbeddings | |
| from langchain.vectorstores import Chroma | |
| from langchain.chains import RetrievalQA | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.schema import AIMessage, HumanMessage, SystemMessage, Document | |
| from langchain.document_loaders import PyPDFLoader | |
| from transformers import AutoTokenizer, T5ForConditionalGeneration | |
| from retrieval.retrieval import Retrieval, BM25 | |
| import os, time, torch | |
| from torch.nn import Softmax | |
class Agent:
    """Question-answering agent: retrieve passages, then answer with a T5 model.

    Each question is paired with every retrieved passage, the seq2seq model
    generates one candidate answer per passage, and the candidates are ranked
    by (model confidence ** 2) * the passage's ``score_bm`` from the retriever.
    Passages come from the built-in corpus (``self.corpus``) unless the user
    has supplied context via ``load_context`` / ``get_context``.
    """

    def __init__(self, args=None) -> None:
        """Load the replacement table, the corpus retriever and the model.

        Args:
            args: configuration namespace. Attributes read by this class:
                ``choices``, ``model``, ``tokenizer``, ``device``, ``seq_len``
                and ``out_len``. Despite the ``None`` default it is required.
        """
        self.args = args
        self.choices = args.choices
        # Retriever over the built-in corpus; used while use_context is False.
        self.corpus = Retrieval(k=args.choices)
        # Retriever over user-supplied context; built by load_context/get_context.
        self.retrieval = None
        self.context_value = ""
        self.use_context = False
        self.softmax = Softmax(dim=1)
        # Ranked candidate answers from the most recent ``asking`` call.
        self.temp = []
        # NOTE(review): torch.load unpickles the file - only ever load a trusted file.
        self.replace_list = torch.load('retrieval/replace.pt')
        print("Model is loading...")
        self.model = T5ForConditionalGeneration.from_pretrained(args.model).to(args.device)
        self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        print("Model loaded!")

    @staticmethod
    def _read_text(path):
        """Return the full contents of the UTF-8 text file at *path*."""
        with open(path, encoding='utf-8') as f:
            return f.read()

    def _use_docs(self, docs):
        """Build a retriever over *docs* and switch answering to use it."""
        # Always start from the configured k: ``self.choices`` is read back from
        # ``retrieval.k`` and may differ from the configured value, so reusing
        # it would carry a previous context's k over to the next one.
        self.retrieval = Retrieval(k=self.args.choices, docs=docs)
        self.choices = self.retrieval.k
        self.use_context = True

    def load_context(self, doc_path):
        """Load a PDF or UTF-8 text file and answer later questions from it.

        Args:
            doc_path: uploaded-file object; only its ``name`` (a filesystem
                path) is used.

        Returns:
            A status message naming the loaded file.
        """
        path = doc_path.name
        print('Loading file:', path)
        # Case-insensitive so ".PDF" uploads are parsed as PDFs too.
        if path.lower().endswith('.pdf'):
            context = self.read_pdf(path)
        else:
            context = self._read_text(path)
        self._use_docs(context)
        return f"Using file from {path}"

    def asking(self, question):
        """Answer *question* from the active context (user context or corpus).

        Returns:
            The best-ranked answer. If its confidence (highest softmax
            probability of the first generated token, as a percentage) is not
            above 50, the answer is wrapped in a hedging message. All ranked
            candidates are stored in ``self.temp``.
        """
        s_query = time.time()
        if self.use_context:
            print("Answering with your context:", question)
            contexts = self.retrieval.get_context(question)
        else:
            print("Answering without your context:", question)
            contexts = self.corpus.get_context(question)
        # Answer at most ``self.choices`` passages, but never index past what the
        # retriever actually returned.
        contexts = list(contexts)[:self.choices]
        if not contexts:
            self.temp = []
            return "Tôi không tìm thấy nội dung liên quan để trả lời câu hỏi này."

        prompts = [
            f"Trả lời câu hỏi: {question} Trong nội dung: {context['context']}"
            for context in contexts
        ]
        s_token = time.time()
        tokens = self.tokenizer(prompts, max_length=self.args.seq_len, truncation=True,
                                padding='max_length', return_tensors='pt')
        s_gen = time.time()
        outputs = self.model.generate(
            input_ids=tokens.input_ids.to(self.args.device),
            attention_mask=tokens.attention_mask.to(self.args.device),
            max_new_tokens=self.args.out_len,
            output_scores=True,
            return_dict_in_generate=True
        )
        s_de = time.time()
        # Confidence: highest softmax probability of the first generated token, in %.
        scores = self.softmax(outputs.scores[0]).max(dim=1).values * 100

        results = []
        for i, context in enumerate(contexts):
            # Copy so the retriever's own dicts are not mutated.
            result = dict(context)
            result['score'] = round(scores[i].item())
            result['answer'] = self.tokenizer.decode(outputs.sequences[i], skip_special_tokens=True)
            results.append(result)
        # Rank by squared model confidence times the retriever's score_bm.
        results.sort(key=lambda r: r['score'] ** 2 * r['score_bm'], reverse=True)
        self.temp = results

        t_mess = "t_query: {:.2f}\t t_token: {:.2f}\t t_gen: {:.2f}\t t_decode: {:.2f}\t".format(
            s_token-s_query, s_gen-s_token, s_de-s_gen, time.time()-s_de
        )
        print(t_mess, len(self.temp))
        best = results[0]
        if best['score'] > 50:
            return best['answer']
        return f"Tôi không chắc nhưng câu trả lời có thể là: {best['answer']}\nBạn có thể tham khảo các câu trả lời bên cạnh!"

    def get_context(self, context):
        """Use the given *context* text as the document source for answering.

        Stores the text in ``context_value``, builds a retriever over it and
        enables context answering.

        Returns:
            *context*, unchanged.
        """
        self.context_value = context
        self._use_docs(context)
        return context

    def load_context_file(self, file):
        """Read a UTF-8 text file into ``context_value`` and return its text.

        Unlike ``load_context``, this does not build a retriever or enable
        context answering.

        Args:
            file: uploaded-file object; only its ``name`` (a path) is used.
        """
        print('Loading file:', file.name)
        text = self._read_text(file.name)
        self.context_value = text
        return text

    def clear_context(self):
        """Drop user context and go back to answering from the built-in corpus.

        Returns:
            An empty string.
        """
        self.context_value = ""
        self.use_context = False
        self.choices = self.args.choices
        return ""

    def replace(self, text):
        """Apply each ``(old, new)`` pair of ``self.replace_list`` to *text*, in order."""
        for old, new in self.replace_list:
            text = text.replace(old, new)
        return text

    def read_pdf(self, file_path):
        """Extract the text of the PDF at *file_path*.

        Each page is cleaned with ``replace`` and pages are joined by newlines.
        """
        # Use load() (one Document per page) rather than load_and_split(): the
        # latter re-chunks pages with an overlapping text splitter, so
        # concatenating its chunks would duplicate text at chunk boundaries.
        pages = PyPDFLoader(file_path).load()
        return "\n".join(self.replace(page.page_content) for page in pages)