# processing.py (copied from the Hugging Face file viewer)
# Last commit: d6bf13f "Update processing.py" by reab5555 (verified); file size 7.62 kB.
import os
import time
import re
from huggingface_hub import login
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from langdetect import detect
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from transcription_diarization import process_video
import gradio as gr
import plotly.graph_objs as go
# Ask PyTorch's CUDA caching allocator to use expandable segments (documented
# as a mitigation for memory fragmentation). Set before any model is loaded.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Hub token: used for login below, for model/tokenizer downloads, and passed to
# process_video. The Space secret is named "hf_secret"; the standard HF_TOKEN
# variable is accepted as a fallback.
hf_token = os.environ.get('hf_secret') or os.environ.get('HF_TOKEN')
if not hf_token:
    raise ValueError(
        "Neither hf_secret nor HF_TOKEN found in environment variables. "
        "Please set it in the Space secrets."
    )
login(token=hf_token)
def load_instructions(file_path):
    """Return the contents of a task-instruction file, stripped of surrounding whitespace.

    The file is decoded as UTF-8 explicitly, the same way process_input reads
    transcripts, so the result does not depend on the platform's locale encoding.

    Args:
        file_path: Path to the instruction text file.

    Returns:
        The file's text with leading/trailing whitespace removed.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read().strip()
# Instruction text for each of the three analyses, read once at import time.
attachments_task, bigfive_task, personalities_task = (
    load_instructions(task_path)
    for task_path in (
        "tasks/Attachments_task.txt",
        "tasks/BigFive_task.txt",
        "tasks/Personalities_task.txt",
    )
)
def load_knowledge(file_path, chunk_size=1000, chunk_overlap=0):
    """Load a UTF-8 knowledge text file and split it into documents for indexing.

    Args:
        file_path: Path to the knowledge text file.
        chunk_size: Chunk size passed to CharacterTextSplitter (default 1000).
        chunk_overlap: Overlap between consecutive chunks (default 0).

    Returns:
        The list of documents produced by ``CharacterTextSplitter.split_documents``.
    """
    # Explicit UTF-8, matching how the rest of this module reads text files,
    # instead of relying on the locale's default encoding.
    documents = TextLoader(file_path, encoding="utf-8").load()
    splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(documents)
# A single embedding model is shared by all three knowledge indexes.
embeddings = HuggingFaceEmbeddings()

# FAISS indexes, one per analysis task, built once at import time.
attachments_db, bigfive_db, personalities_db = (
    FAISS.from_documents(load_knowledge(knowledge_path), embeddings)
    for knowledge_path in (
        "knowledge/bartholomew_attachments_definitions_no_items_no_in.txt",
        "knowledge/bigfive_definitions_no_items.txt",
        "knowledge/personalities_definitions.txt",
    )
)
def detect_language(text):
    """Return the langdetect language code for *text*, falling back to "en".

    Detection is best-effort. Empty, whitespace-only or non-string input, or any
    error raised by ``detect``, yields "en".

    NOTE(review): langdetect is documented as non-deterministic unless
    ``DetectorFactory.seed`` is set; consider seeding it once at startup.
    """
    if not isinstance(text, str) or not text.strip():
        # Nothing to classify; skip the detector call entirely.
        return "en"
    try:
        return detect(text)
    except Exception:
        # Keep the best-effort fallback, but unlike a bare ``except:`` do not
        # swallow KeyboardInterrupt / SystemExit.
        return "en"
class SequentialAnalyzer:
    """Run retrieval-augmented task prompts through a locally loaded Llama 3.1 model.

    The model, tokenizer and text-generation pipeline are created once in
    ``__init__`` and reused by every ``analyze_task`` call.
    """

    def __init__(self, hf_token):
        """Load the model and build the generation pipeline.

        Args:
            hf_token: Hugging Face access token used for model/tokenizer downloads.
        """
        self.hf_token = hf_token
        self.model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        self.model = self.load_model()
        self.pipe = self.create_pipeline(self.model)

    def load_model(self):
        """Load ``self.model_name`` in bfloat16 with automatic device placement."""
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # ``token`` replaces the deprecated ``use_auth_token`` argument.
            token=self.hf_token,
            # NOTE(review): this sets the model config's use_cache to False, which
            # may disable KV caching during generation; confirm it is intended.
            use_cache=False,
        )
        # NOTE(review): gradient checkpointing is a training-time memory saving;
        # this class only generates, so it is likely a no-op here. Confirm
        # before removing.
        model.gradient_checkpointing_enable()
        return model

    def create_pipeline(self, model):
        """Build a greedy text-generation pipeline around *model*.

        Args:
            model: The causal LM returned by ``load_model``.

        Returns:
            A transformers ``text-generation`` pipeline.
        """
        from transformers import pipeline

        tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_token)
        # Ban the first token each listed character encodes to. Other tokens
        # containing "*" can still appear, which is why post_process_output
        # also strips asterisks.
        banned_chars = "*"
        bad_words_ids = [
            [tokenizer.encode(char, add_special_tokens=False)[0]] for char in banned_chars
        ]
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            repetition_penalty=1.2,
            # Greedy decoding: sampling-only knobs such as temperature (and the
            # beam-search-only length_penalty) would have no effect, so they
            # are not passed.
            do_sample=False,
            truncation=True,
            bad_words_ids=bad_words_ids,
        )

    def post_process_output(self, output):
        """Remove every asterisk from *output* and strip surrounding whitespace."""
        return output.replace("*", "").strip()

    def analyze_task(self, content, task, knowledge_db, max_input_length=1548):
        """Answer *task* about *content*, grounding the prompt with *knowledge_db*.

        Args:
            content: Text to analyze. Truncated to ``max_input_length`` tokens
                before being sent to the model.
            task: Instruction text placed at the top of the prompt template.
                NOTE(review): PromptTemplate treats ``{...}`` in this text as
                template variables; confirm task files contain no literal braces.
            knowledge_db: Vector store whose ``as_retriever()`` fills ``{context}``.
            max_input_length: Maximum number of content tokens passed to the model.

        Returns:
            ``(answer, input_tokens)``: the cleaned model answer and the token
            count of the untruncated content (including special tokens).
        """
        # Reuse the pipeline's tokenizer instead of reloading it on every call.
        tokenizer = self.pipe.tokenizer
        input_tokens = len(tokenizer.encode(content))

        # Encode without special tokens: decoding ids that include the BOS token
        # would paste its special-token string into the prompt text.
        content_ids = tokenizer.encode(content, add_special_tokens=False)
        if len(content_ids) > max_input_length:
            print(f"Warning: Input was truncated from {input_tokens} to {max_input_length} tokens.")
            truncated_content = tokenizer.decode(content_ids[:max_input_length])
        else:
            # No truncation needed: pass the text through unchanged.
            truncated_content = content

        llm = HuggingFacePipeline(pipeline=self.pipe)
        prompt = PromptTemplate(
            template=task + "\n\n{context}\n\n{question}\n\n-----------\n\nAnswer: ",
            input_variables=["context", "question"],
        )
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=knowledge_db.as_retriever(),
            chain_type_kwargs={"prompt": prompt},
        )
        result = chain.invoke({"query": truncated_content})
        # The generated text may echo the prompt; keep only what follows the
        # last answer marker.
        output = result['result'].split("-----------\n\nAnswer:")[-1].strip()
        return self.post_process_output(output), input_tokens
# Lazily created, process-wide analyzer (see _get_analyzer).
_analyzer = None


def _get_analyzer():
    """Return the shared SequentialAnalyzer, creating it on first use.

    Constructing it loads the full model, so it is done once per process rather
    than once per uploaded file.
    """
    global _analyzer
    if _analyzer is None:
        _analyzer = SequentialAnalyzer(hf_token)
    return _analyzer


def _read_content(file_path, file_extension, safe_progress):
    """Return the text of *file_path* by type, or None for an unsupported extension."""
    if file_extension in ('.txt', '.srt'):
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    if file_extension == '.pdf':
        # load() yields one document per page. load_and_split() would re-chunk
        # long pages with its default overlapping splitter, and joining those
        # chunks would duplicate text.
        pages = PyPDFLoader(file_path).load()
        return '\n'.join(page.page_content for page in pages)
    if file_extension in ('.mp4', '.avi', '.mov'):
        safe_progress(0.2, desc="Processing video...")
        srt_path = process_video(file_path, hf_token, "en")
        try:
            with open(srt_path, 'r', encoding='utf-8') as file:
                return file.read()
        finally:
            # Remove the intermediate subtitle file even if reading it fails.
            os.remove(srt_path)
    return None


def process_input(input_file, progress=None):
    """Analyze an uploaded transcript, PDF or video and return the UI outputs.

    Args:
        input_file: A path string, or an upload object whose ``.name`` is the
            file path.
        progress: Optional callable invoked as ``progress(value, desc=...)``.
            Exceptions it raises are printed and otherwise ignored.

    Returns:
        A 10-tuple ``(status, execution_time, detected_language,
        attachments_answer, bigfive_answer, personalities_answer,
        original_tokens, attachments_tokens, bigfive_tokens,
        personalities_tokens)``. For an unsupported extension, ``status`` is an
        error message and the other nine entries are None.
    """
    start_time = time.time()

    def safe_progress(value, desc=""):
        # Progress reporting is best-effort and must never abort the analysis.
        if progress is not None:
            try:
                progress(value, desc=desc)
            except Exception as e:
                print(f"Progress update failed: {e}")

    safe_progress(0, desc="Processing file...")
    # Resolve the path before touching ``.name`` so plain path strings work too.
    file_path = input_file if isinstance(input_file, str) else input_file.name
    file_extension = os.path.splitext(file_path)[1].lower()

    content = _read_content(file_path, file_extension, safe_progress)
    if content is None:
        return ("Unsupported file format. Please upload a TXT, SRT, PDF, or video file.",
                None, None, None, None, None, None, None, None, None)

    detected_language = detect_language(content)

    safe_progress(0.4, desc="Analyzing content...")
    analyzer = _get_analyzer()
    # Same model tokenizer as the analyzer's pipeline; no second download.
    original_tokens = len(analyzer.pipe.tokenizer.encode(content))

    safe_progress(0.5, desc="Analyzing attachments...")
    attachments_answer, attachments_tokens = analyzer.analyze_task(content, attachments_task, attachments_db)
    print("Attachments output:\n", attachments_answer)
    print(f"Attachments input tokens (before truncation): {attachments_tokens}")

    safe_progress(0.7, desc="Analyzing Big Five traits...")
    bigfive_answer, bigfive_tokens = analyzer.analyze_task(content, bigfive_task, bigfive_db)
    print("Big Five output:\n", bigfive_answer)
    print(f"Big Five input tokens (before truncation): {bigfive_tokens}")

    safe_progress(0.9, desc="Analyzing personalities...")
    personalities_answer, personalities_tokens = analyzer.analyze_task(content, personalities_task, personalities_db)
    print("Personalities output:\n", personalities_answer)
    print(f"Personalities input tokens (before truncation): {personalities_tokens}")

    execution_info = f"{time.time() - start_time:.2f} seconds"
    safe_progress(1.0, desc="Analysis complete!")
    return ("Analysis complete!", execution_info, detected_language,
            attachments_answer, bigfive_answer, personalities_answer,
            original_tokens, attachments_tokens, bigfive_tokens, personalities_tokens)