import os
import time
import re

from huggingface_hub import login
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from langdetect import detect
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from transcription_diarization import process_video
import gradio as gr
import plotly.graph_objs as go

# Reduce CUDA memory fragmentation during long-running generation.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

hf_token = os.environ.get('hf_secret')
if not hf_token:
    raise ValueError("'hf_secret' not found in environment variables. Please set it in the Space secrets.")
login(token=hf_token)


def load_instructions(file_path):
    """Read a task-prompt file and return its stripped contents."""
    with open(file_path, 'r') as file:
        return file.read().strip()


attachments_task = load_instructions("tasks/Attachments_task.txt")
bigfive_task = load_instructions("tasks/BigFive_task.txt")
personalities_task = load_instructions("tasks/Personalities_task.txt")


def load_knowledge(file_path):
    """Load a knowledge file and split it into 1000-character chunks for indexing."""
    loader = TextLoader(file_path)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    return text_splitter.split_documents(documents)


# One FAISS index per analysis task, all sharing the same embedding model.
embeddings = HuggingFaceEmbeddings()
attachments_db = FAISS.from_documents(load_knowledge("knowledge/bartholomew_attachments_definitions_no_items_no_in.txt"), embeddings)
bigfive_db = FAISS.from_documents(load_knowledge("knowledge/bigfive_definitions_no_items.txt"), embeddings)
personalities_db = FAISS.from_documents(load_knowledge("knowledge/personalities_definitions.txt"), embeddings)


def detect_language(text):
    """Best-effort language detection; fall back to English on failure."""
    try:
        return detect(text)
    except Exception:
        return "en"


class SequentialAnalyzer:
    """Runs the three RetrievalQA analyses sequentially over one shared Llama pipeline."""

    def __init__(self, hf_token):
        self.hf_token = hf_token
        self.model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        # Load the tokenizer once and reuse it for the pipeline and token counting.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_auth_token=self.hf_token)
        self.model = self.load_model()
        self.pipe = self.create_pipeline(self.model)

    def load_model(self):
        # bfloat16 + gradient checkpointing to fit the 8B model in GPU memory.
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            use_auth_token=self.hf_token,
            use_cache=False,
            load_in_4bit=False
        )
        model.gradient_checkpointing_enable()
        return model

    def create_pipeline(self, model):
        from transformers import pipeline
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
            temperature=0.01,
            repetition_penalty=1.2,
            length_penalty=0.1,
            do_sample=False,  # greedy decoding; temperature is effectively ignored
            truncation=True,
            # Ban the "*" token so the model cannot emit markdown emphasis.
            bad_words_ids=[[self.tokenizer.encode(char, add_special_tokens=False)[0]] for char in "*"]
        )

    def post_process_output(self, output):
        """Strip any stray asterisks the decoding constraint did not catch."""
        return re.sub(r'[*]', '', output).strip()

    def analyze_task(self, content, task, knowledge_db):
        input_tokens = len(self.tokenizer.encode(content))
        # Truncate the transcript so the full prompt (task + retrieved context +
        # transcript) stays inside the model's usable context window.
        max_input_length = 1548
        encoded_input = self.tokenizer.encode(content, truncation=True, max_length=max_input_length)
        truncated_content = self.tokenizer.decode(encoded_input)
        if len(encoded_input) == max_input_length:
            print(f"Warning: Input was truncated from {input_tokens} to {max_input_length} tokens.")
        llm = HuggingFacePipeline(pipeline=self.pipe)
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=knowledge_db.as_retriever(),
            chain_type_kwargs={"prompt": PromptTemplate(
                template=task + "\n\n{context}\n\n{question}\n\n-----------\n\nAnswer: ",
                input_variables=["context", "question"]
            )}
        )
        result = chain({"query": truncated_content})
        # Keep only the text after the answer delimiter.
        output = result['result'].split("-----------\n\nAnswer:")[-1].strip()
        return self.post_process_output(output), input_tokens


def process_input(input_file, progress=None):
    start_time = time.time()

    def safe_progress(value, desc=""):
        # Progress callbacks (e.g. gr.Progress) may fail outside a Gradio request;
        # never let that abort the analysis.
        if progress is not None:
            try:
                progress(value, desc=desc)
            except Exception as e:
                print(f"Progress update failed: {e}")

    safe_progress(0, desc="Processing file...")
    # input_file may be a plain path or an uploaded-file object with a .name attribute.
    file_path = input_file if isinstance(input_file, str) else input_file.name
    file_extension = os.path.splitext(file_path)[1].lower()

    if file_extension in ['.txt', '.srt']:
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
    elif file_extension == '.pdf':
        loader = PyPDFLoader(file_path)
        pages = loader.load_and_split()
        content = '\n'.join(page.page_content for page in pages)
    elif file_extension in ['.mp4', '.avi', '.mov']:
        safe_progress(0.2, desc="Processing video...")
        # Transcribe and diarize the video, then read the resulting subtitles.
        srt_path = process_video(file_path, hf_token, "en")
        with open(srt_path, 'r', encoding='utf-8') as file:
            content = file.read()
        os.remove(srt_path)
    else:
        return ("Unsupported file format. Please upload a TXT, SRT, PDF, or video file.",
                None, None, None, None, None, None, None, None, None)

    detected_language = detect_language(content)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", use_auth_token=hf_token)
    original_tokens = len(tokenizer.encode(content))

    safe_progress(0.4, desc="Analyzing content...")
    analyzer = SequentialAnalyzer(hf_token)

    safe_progress(0.5, desc="Analyzing attachments...")
    attachments_answer, attachments_tokens = analyzer.analyze_task(content, attachments_task, attachments_db)
    print("Attachments output:\n", attachments_answer)
    print(f"Attachments input tokens (before truncation): {attachments_tokens}")

    safe_progress(0.7, desc="Analyzing Big Five traits...")
    bigfive_answer, bigfive_tokens = analyzer.analyze_task(content, bigfive_task, bigfive_db)
    print("Big Five output:\n", bigfive_answer)
    print(f"Big Five input tokens (before truncation): {bigfive_tokens}")

    safe_progress(0.9, desc="Analyzing personalities...")
    personalities_answer, personalities_tokens = analyzer.analyze_task(content, personalities_task, personalities_db)
    print("Personalities output:\n", personalities_answer)
    print(f"Personalities input tokens (before truncation): {personalities_tokens}")

    execution_info = f"{time.time() - start_time:.2f} seconds"

    safe_progress(1.0, desc="Analysis complete!")
    return ("Analysis complete!", execution_info, detected_language,
            attachments_answer, bigfive_answer, personalities_answer,
            original_tokens, attachments_tokens, bigfive_tokens, personalities_tokens)
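

# A minimal invocation sketch, not part of the original app: the Gradio UI imported
# above presumably calls process_input as an event handler. The file name below is
# hypothetical and only illustrates how the ten return values unpack.
if __name__ == "__main__":
    (status, execution_info, detected_language,
     attachments_answer, bigfive_answer, personalities_answer,
     original_tokens, attachments_tokens, bigfive_tokens,
     personalities_tokens) = process_input("sample_interview.txt")
    print(f"{status} ({execution_info}, detected language: {detected_language})")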