import os
import re
import time

import gradio as gr
import plotly.graph_objs as go
import torch
from huggingface_hub import login
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langdetect import detect
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from transcription_diarization import process_video
# Reduce CUDA memory fragmentation for long-running generation workloads.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

hf_token = os.environ.get('hf_secret')
if not hf_token:
    raise ValueError("hf_secret not found in environment variables. Please set it in the Space secrets.")
login(token=hf_token)

def load_instructions(file_path):
    with open(file_path, 'r') as file:
        return file.read().strip()

attachments_task = load_instructions("tasks/Attachments_task.txt")
bigfive_task = load_instructions("tasks/BigFive_task.txt")
personalities_task = load_instructions("tasks/Personalities_task.txt")

def load_knowledge(file_path):
    # Split a reference text into 1000-character chunks (no overlap) for indexing.
    loader = TextLoader(file_path)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    return text_splitter.split_documents(documents)

# Build one FAISS index per analysis domain from the bundled knowledge files.
embeddings = HuggingFaceEmbeddings()
attachments_db = FAISS.from_documents(load_knowledge("knowledge/bartholomew_attachments_definitions_no_items_no_in.txt"), embeddings)
bigfive_db = FAISS.from_documents(load_knowledge("knowledge/bigfive_definitions_no_items.txt"), embeddings)
personalities_db = FAISS.from_documents(load_knowledge("knowledge/personalities_definitions.txt"), embeddings)
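
# Illustrative retrieval sanity check (not part of the app flow; the query
# string below is an assumption for demonstration). Uncomment to probe a store:
# for doc in attachments_db.similarity_search("secure attachment style", k=2):
#     print(doc.page_content[:200])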

def detect_language(text):
    # langdetect raises on empty or undecidable input; default to English.
    try:
        return detect(text)
    except Exception:
        return "en"

class SequentialAnalyzer:
    def __init__(self, hf_token):
        self.hf_token = hf_token
        self.model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        # Load the tokenizer once and reuse it for the pipeline and token counting.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_token)
        self.model = self.load_model()
        self.pipe = self.create_pipeline(self.model)

    def load_model(self):
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=self.hf_token,
            use_cache=False
        )
        # Trade compute for memory during generation on a single GPU.
        model.gradient_checkpointing_enable()
        return model

    def create_pipeline(self, model):
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding; sampling knobs like temperature have no effect
            repetition_penalty=1.2,
            truncation=True,
            # Ban the literal "*" token so the model cannot emit markdown bullets.
            bad_words_ids=[[self.tokenizer.encode(char, add_special_tokens=False)[0]] for char in "*"]
        )

    def post_process_output(self, output):
        # Strip any stray asterisks the decoder-level ban did not catch.
        return re.sub(r'[*]', '', output).strip()

    def analyze_task(self, content, task, knowledge_db):
        # Count tokens before truncation so the caller can report the original size.
        input_tokens = len(self.tokenizer.encode(content))
        max_input_length = 1548
        encoded_input = self.tokenizer.encode(content, truncation=True, max_length=max_input_length)
        truncated_content = self.tokenizer.decode(encoded_input)
        if len(encoded_input) == max_input_length:
            print(f"Warning: Input was truncated from {input_tokens} to {max_input_length} tokens.")
        llm = HuggingFacePipeline(pipeline=self.pipe)
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=knowledge_db.as_retriever(),
            chain_type_kwargs={"prompt": PromptTemplate(
                template=task + "\n\n{context}\n\n{question}\n\n-----------\n\nAnswer: ",
                input_variables=["context", "question"]
            )}
        )
        result = chain.invoke({"query": truncated_content})
        # The generated text may echo the prompt; keep only what follows the answer marker.
        output = result['result'].split("-----------\n\nAnswer:")[-1].strip()
        return self.post_process_output(output), input_tokens
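
# Minimal usage sketch (illustrative only; the transcript string is an assumption):
# analyzer = SequentialAnalyzer(hf_token)
# answer, n_tokens = analyzer.analyze_task(
#     "Speaker 1: I often worry that people will leave me...",
#     attachments_task, attachments_db)
# print(answer, n_tokens)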

def process_input(input_file, progress=None):
    start_time = time.time()

    def safe_progress(value, desc=""):
        # Progress callbacks can fail outside a live Gradio request; never crash on them.
        if progress is not None:
            try:
                progress(value, desc=desc)
            except Exception as e:
                print(f"Progress update failed: {e}")

    safe_progress(0, desc="Processing file...")
    # Gradio may hand over either a plain path string or a file object with a .name
    # attribute; resolve the path first to avoid an AttributeError on strings.
    file_path = input_file if isinstance(input_file, str) else input_file.name
    file_extension = os.path.splitext(file_path)[1].lower()
    if file_extension in ['.txt', '.srt']:
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
    elif file_extension == '.pdf':
        loader = PyPDFLoader(file_path)
        pages = loader.load_and_split()
        content = '\n'.join([page.page_content for page in pages])
    elif file_extension in ['.mp4', '.avi', '.mov']:
        safe_progress(0.2, desc="Processing video...")
        srt_path = process_video(file_path, hf_token, "en")
        with open(srt_path, 'r', encoding='utf-8') as file:
            content = file.read()
        os.remove(srt_path)
    else:
        return ("Unsupported file format. Please upload a TXT, SRT, PDF, or video file.",
                None, None, None, None, None, None, None, None, None)
    detected_language = detect_language(content)
    safe_progress(0.4, desc="Analyzing content...")
    analyzer = SequentialAnalyzer(hf_token)
    # Reuse the analyzer's tokenizer instead of loading a second copy.
    original_tokens = len(analyzer.tokenizer.encode(content))
    safe_progress(0.5, desc="Analyzing attachments...")
    attachments_answer, attachments_tokens = analyzer.analyze_task(content, attachments_task, attachments_db)
    print("Attachments output:\n", attachments_answer)
    print(f"Attachments input tokens (before truncation): {attachments_tokens}")

    safe_progress(0.7, desc="Analyzing Big Five traits...")
    bigfive_answer, bigfive_tokens = analyzer.analyze_task(content, bigfive_task, bigfive_db)
    print("Big Five output:\n", bigfive_answer)
    print(f"Big Five input tokens (before truncation): {bigfive_tokens}")

    safe_progress(0.9, desc="Analyzing personalities...")
    personalities_answer, personalities_tokens = analyzer.analyze_task(content, personalities_task, personalities_db)
    print("Personalities output:\n", personalities_answer)
    print(f"Personalities input tokens (before truncation): {personalities_tokens}")

    execution_time = time.time() - start_time
    execution_info = f"{execution_time:.2f} seconds"
    safe_progress(1.0, desc="Analysis complete!")
    return ("Analysis complete!", execution_info, detected_language,
            attachments_answer, bigfive_answer, personalities_answer,
            original_tokens, attachments_tokens, bigfive_tokens, personalities_tokens)
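
# Sketch of how process_input could be wired into the Gradio UI (the component
# layout below is an assumption, not the Space's actual interface definition):
# demo = gr.Interface(
#     fn=process_input,
#     inputs=gr.File(label="Transcript, PDF, or video"),
#     outputs=[gr.Textbox(label=name) for name in (
#         "Status", "Execution time", "Language",
#         "Attachments", "Big Five", "Personalities",
#         "Original tokens", "Attachments tokens", "Big Five tokens", "Personalities tokens")],
# )
# demo.launch()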