reab5555 committed
Commit 76c5624 · verified · Parent(s): 4ca1014

Update processing.py

Files changed (1):
  processing.py +91 −222
processing.py CHANGED
@@ -1,233 +1,102 @@
-import os
-import time
-import re
-import numpy as np
-from huggingface_hub import login
-import torch
-import random
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from langdetect import detect
-from langchain.chains import RetrievalQA
-from langchain_community.llms import HuggingFacePipeline
 from langchain.prompts import PromptTemplate
-from langchain_community.document_loaders import TextLoader, PyPDFLoader
-from langchain.text_splitter import CharacterTextSplitter
-from langchain_community.vectorstores import FAISS
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from transcription_diarization import process_video
-from output_parser import get_prompt_template, attachment_parser, bigfive_parser, personality_parser, parse_analysis_output
-
-hf_token = os.environ.get('hf_secret')
-if not hf_token:
-    raise ValueError("HF_TOKEN not found in environment variables. Please set it in the Space secrets.")
-
-login(token=hf_token)
-
-def load_instructions(file_path):
-    with open(file_path, 'r') as file:
-        return file.read().strip()
-
-def load_knowledge(file_path):
-    loader = TextLoader(file_path)
-    documents = loader.load()
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    texts = text_splitter.split_documents(documents)
-    return texts
-
-general_task = load_instructions("tasks/general_task.txt")
-attachments_task = load_instructions("tasks/Attachments_task.txt")
-bigfive_task = load_instructions("tasks/BigFive_task.txt")
-personalities_task = load_instructions("tasks/Personalities_task.txt")
-
-embeddings = HuggingFaceEmbeddings()
-attachments_db = FAISS.from_documents(load_knowledge("knowledge/bartholomew_attachments_definitions_no_items_no_in.txt"), embeddings)
-bigfive_db = FAISS.from_documents(load_knowledge("knowledge/bigfive_definitions_no_items.txt"), embeddings)
-personalities_db = FAISS.from_documents(load_knowledge("knowledge/personalities_definitions.txt"), embeddings)
-
-def detect_language(text):
-    try:
-        return detect(text)
-    except:
-        return "en"
-
-class SequentialAnalyzer:
-    def __init__(self, hf_token, seed=42):
-        self.hf_token = hf_token
-        self.model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-        self.set_seed(seed)
-        self.model = self.load_model()
-        self.pipe = self.create_pipeline(self.model)
-
-    def set_seed(self, seed):
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(seed)
-
-    def load_model(self):
-        model = AutoModelForCausalLM.from_pretrained(
-            self.model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            use_auth_token=self.hf_token,
-            use_cache=False,
-            load_in_4bit=False
-        )
-        return model
-
-    def create_pipeline(self, model):
-        from transformers import pipeline
-        tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_auth_token=self.hf_token)
-        return pipeline(
-            "text-generation",
-            model=model,
-            top_k=50,
-            top_p=0.8,
-            tokenizer=tokenizer,
-            max_new_tokens=512,
-            temperature=0.3,
-            repetition_penalty=1.2,
-            do_sample=False,
-            truncation=True,
-            bad_words_ids=[[tokenizer.encode(char, add_special_tokens=False)[0]] for char in "*"]
-        )
-
-    def post_process_output(self, output):
-        return re.sub(r'[*]', '', output).strip()
-
-    def analyze_task(self, content, task, knowledge_db, analysis_type):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_auth_token=self.hf_token)
-
-        input_tokens = len(tokenizer.encode(content))
-
-        max_input_length = 800
-        encoded_input = tokenizer.encode(content, truncation=True, max_length=max_input_length)
-        truncated_content = tokenizer.decode(encoded_input)
-
-        if len(encoded_input) == max_input_length:
-            print(f"Warning: Input was truncated from {input_tokens} to {max_input_length} tokens.")
-
-        llm = HuggingFacePipeline(pipeline=self.pipe)
-
-        if analysis_type == "attachments":
-            parser = attachment_parser
-        elif analysis_type == "bigfive":
-            parser = bigfive_parser
-        elif analysis_type == "personalities":
-            parser = personality_parser
-        else:
-            raise ValueError(f"Unknown analysis type: {analysis_type}")
-
-        prompt = get_prompt_template(task, parser)
-
-        if knowledge_db is None:
-            chain = prompt | llm
-            result = chain.invoke({"text": truncated_content})
-            output = result
-        else:
-            chain = RetrievalQA.from_chain_type(
-                llm=llm,
-                chain_type="stuff",
-                retriever=knowledge_db.as_retriever(),
-                chain_type_kwargs={
-                    "prompt": PromptTemplate(
-                        template=task + "\n\n{context}\n\n{question}\n\n" + parser.get_format_instructions() + "\n\nAnalysis:",
-                        input_variables=["context", "question"]
-                    )
-                }
-            )
-            result = chain.run(truncated_content)
-            output = result
-
-        print(f"Raw model output: {output}")
-
-        try:
-            cleaned_output = self.post_process_output(output)
-            parsed_output = parser.parse(cleaned_output)
-        except Exception as e:
-            raise ValueError(f"Error parsing output: {e}")
-
-        # Check if all required keys are present
-        required_keys = {schema.name for schema in parser.response_schemas}
-        missing_keys = required_keys - parsed_output.keys()
-
-        if missing_keys:
-            raise ValueError(f"Missing some input keys: {missing_keys}")
-
-        return cleaned_output, input_tokens
-
-def process_input(input_file, max_speakers, progress=None):
-    start_time = time.time()
-
-    def safe_progress(value, desc=""):
-        if progress is not None:
-            try:
-                progress(value, desc=desc)
-            except Exception as e:
-                print(f"Progress update failed: {e}")
-
-    safe_progress(0, desc="Processing file")
-
-    if isinstance(input_file, str):
-        file_path = input_file
     else:
-        file_path = input_file.name
-
-    file_extension = os.path.splitext(file_path)[1].lower()
-
-    if file_extension in ['.txt', '.srt']:
-        with open(file_path, 'r', encoding='utf-8') as file:
-            content = file.read()
-        transcription = content
-    elif file_extension == '.pdf':
-        loader = PyPDFLoader(file_path)
-        pages = loader.load_and_split()
-        content = '\n'.join([page.page_content for page in pages])
-        transcription = content
-    elif file_extension in ['.mp4', '.avi', '.mov']:
-        safe_progress(0.2, desc="Processing video...")
-        srt_path = process_video(file_path, hf_token, "en", max_speakers)
-        with open(srt_path, 'r', encoding='utf-8') as file:
-            content = file.read()
-        transcription = content
-        os.remove(srt_path)
-    else:
-        return "Unsupported file format. Please upload a TXT, SRT, PDF, or video file.", None, None, None, None, None, None
-
-    detected_language = detect_language(content)
-
-    safe_progress(0.2, desc="Initializing analyzer")
-    analyzer = SequentialAnalyzer(hf_token)
-
-    tasks = [
-        ("General + Attachments", general_task + "\n\n" + attachments_task, attachments_db, "attachments"),
-        ("General + Big Five", general_task + "\n\n" + bigfive_task, bigfive_db, "bigfive"),
-        ("General + Personalities", general_task + "\n\n" + personalities_task, personalities_db, "personalities")
-    ]
-
-    results = []
-    tokens = []
-
-    for i, (task_name, task, db, analysis_type) in enumerate(tasks):
-        safe_progress((i + 1) * 0.2, desc=f"Analyzing {task_name}")
-        answer, task_tokens = analyzer.analyze_task(content, task, db, analysis_type)
-        results.append((answer, analysis_type))
-        tokens.append(task_tokens)
-
-    end_time = time.time()
-    execution_time = end_time - start_time
-
-    safe_progress(1.0, desc="Analysis complete!")
-
-    parsed_results = [parse_analysis_output(result, analysis_type) for result, analysis_type in results]
-
-    return (
-        "Analysis complete!",
-        f"{execution_time:.2f} seconds",
-        detected_language,
-        parsed_results[0],  # attachments
-        parsed_results[1],  # bigfive
-        parsed_results[2],  # personalities,
-        transcription
-    )
 
+from langchain.output_parsers import StructuredOutputParser, ResponseSchema
 from langchain.prompts import PromptTemplate
+from pydantic import BaseModel
+from typing import Dict
+
+class AttachmentStyle(BaseModel):
+    speaker: str
+    secured: float
+    anxious_preoccupied: float
+    dismissive_avoidant: float
+    fearful_avoidant: float
+    self_rating: int
+    others_rating: int
+    anxiety: int
+    avoidance: int
+    explanation: str
+
+class BigFiveTraits(BaseModel):
+    speaker: str
+    extraversion: int
+    agreeableness: int
+    conscientiousness: int
+    neuroticism: int
+    openness: int
+    explanation: str
+
+class PersonalityDisorder(BaseModel):
+    speaker: str
+    depressed: int
+    paranoid: int
+    schizoid_schizotypal: int
+    antisocial_psychopathic: int
+    borderline_dysregulated: int
+    narcissistic: int
+    anxious_avoidant: int
+    dependent_victimized: int
+    obsessional: int
+    explanation: str
+
+attachment_response_schemas = [
+    ResponseSchema(name="speaker", description="The name or number of the speaker"),
+    ResponseSchema(name="secured", description="Probability of secured attachment style (0-1)"),
+    ResponseSchema(name="anxious_preoccupied", description="Probability of anxious-preoccupied attachment style (0-1)"),
+    ResponseSchema(name="dismissive_avoidant", description="Probability of dismissive-avoidant attachment style (0-1)"),
+    ResponseSchema(name="fearful_avoidant", description="Probability of fearful-avoidant attachment style (0-1)"),
+    ResponseSchema(name="self_rating", description="Self rating (0-10)"),
+    ResponseSchema(name="others_rating", description="Others rating (0-10)"),
+    ResponseSchema(name="anxiety", description="Anxiety rating (0-10)"),
+    ResponseSchema(name="avoidance", description="Avoidance rating (0-10)"),
+    ResponseSchema(name="explanation", description="Brief explanation of the attachment style")
+]
+
+bigfive_response_schemas = [
+    ResponseSchema(name="speaker", description="The name or number of the speaker"),
+    ResponseSchema(name="extraversion", description="Extraversion rating (-10 to 10)"),
+    ResponseSchema(name="agreeableness", description="Agreeableness rating (-10 to 10)"),
+    ResponseSchema(name="conscientiousness", description="Conscientiousness rating (-10 to 10)"),
+    ResponseSchema(name="neuroticism", description="Neuroticism rating (-10 to 10)"),
+    ResponseSchema(name="openness", description="Openness rating (-10 to 10)"),
+    ResponseSchema(name="explanation", description="Brief explanation of the Big Five traits")
+]
+
+personality_response_schemas = [
+    ResponseSchema(name="speaker", description="The name or number of the speaker"),
+    ResponseSchema(name="depressed", description="Depressed rating (0-4)"),
+    ResponseSchema(name="paranoid", description="Paranoid rating (0-4)"),
+    ResponseSchema(name="schizoid_schizotypal", description="Schizoid-Schizotypal rating (0-4)"),
+    ResponseSchema(name="antisocial_psychopathic", description="Antisocial-Psychopathic rating (0-4)"),
+    ResponseSchema(name="borderline_dysregulated", description="Borderline-Dysregulated rating (0-4)"),
+    ResponseSchema(name="narcissistic", description="Narcissistic rating (0-4)"),
+    ResponseSchema(name="anxious_avoidant", description="Anxious-Avoidant rating (0-4)"),
+    ResponseSchema(name="dependent_victimized", description="Dependent-Victimized rating (0-4)"),
+    ResponseSchema(name="obsessional", description="Obsessional rating (0-4)"),
+    ResponseSchema(name="explanation", description="Brief explanation of the personality disorders")
+]
+
+attachment_parser = StructuredOutputParser.from_response_schemas(attachment_response_schemas)
+bigfive_parser = StructuredOutputParser.from_response_schemas(bigfive_response_schemas)
+personality_parser = StructuredOutputParser.from_response_schemas(personality_response_schemas)
+
+def get_prompt_template(task: str, parser: StructuredOutputParser) -> PromptTemplate:
+    return PromptTemplate(
+        template="Analyze the following text according to the given task:\n\n{task}\n\n{format_instructions}\n\nText: {text}\n\nAnalysis:",
+        input_variables=["text"],
+        partial_variables={
+            "task": task,
+            "format_instructions": parser.get_format_instructions()
+        }
+    )
+
+def parse_analysis_output(output: str, analysis_type: str) -> Dict[str, BaseModel]:
+    if analysis_type == "attachments":
+        parsed = attachment_parser.parse(output)
+        return {parsed['speaker']: AttachmentStyle(**parsed)}
+    elif analysis_type == "bigfive":
+        parsed = bigfive_parser.parse(output)
+        return {parsed['speaker']: BigFiveTraits(**parsed)}
+    elif analysis_type == "personalities":
+        parsed = personality_parser.parse(output)
+        return {parsed['speaker']: PersonalityDisorder(**parsed)}
     else:
+        raise ValueError(f"Unknown analysis type: {analysis_type}")
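
For reference, a minimal usage sketch of the parsing module introduced by this commit (not part of the diff itself). The task string, transcript line, and model reply below are invented for illustration; only get_prompt_template, bigfive_parser, parse_analysis_output, and BigFiveTraits come from the code above.

# Illustrative only: the task text and model reply are made-up examples.
from output_parser import get_prompt_template, bigfive_parser, parse_analysis_output

# Build the prompt sent to the LLM. "task" and "format_instructions" are
# pre-filled as partial variables, so only {text} remains to fill in.
prompt = get_prompt_template("Rate the speaker on each Big Five trait.", bigfive_parser)
prompt_str = prompt.format(text="Speaker 1: I love meeting new people at work.")

# StructuredOutputParser expects the model to reply with a fenced JSON block
# containing every field declared in bigfive_response_schemas, e.g.:
sample_reply = '''```json
{
    "speaker": "Speaker 1",
    "extraversion": 7,
    "agreeableness": 4,
    "conscientiousness": 2,
    "neuroticism": -3,
    "openness": 5,
    "explanation": "Outgoing and curious in conversation."
}
```'''

# parse_analysis_output feeds the parsed fields into the BigFiveTraits
# pydantic model and keys the result by speaker.
result = parse_analysis_output(sample_reply, "bigfive")
print(result["Speaker 1"].extraversion)  # 7

If the reply omits a declared field, StructuredOutputParser.parse raises an exception; the removed analyze_task in processing.py caught that case and re-raised it as a ValueError.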