Spaces:
Runtime error
Runtime error
Update processing.py
Browse files- processing.py +91 -222
processing.py
CHANGED
@@ -1,233 +1,102 @@
|
|
1 |
-
import
|
2 |
-
import time
|
3 |
-
import re
|
4 |
-
import numpy as np
|
5 |
-
from huggingface_hub import login
|
6 |
-
import torch
|
7 |
-
import random
|
8 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
9 |
-
from langdetect import detect
|
10 |
-
from langchain.chains import RetrievalQA
|
11 |
-
from langchain_community.llms import HuggingFacePipeline
|
12 |
from langchain.prompts import PromptTemplate
|
13 |
-
from
|
14 |
-
from
|
15 |
-
from langchain_community.vectorstores import FAISS
|
16 |
-
from langchain_community.embeddings import HuggingFaceEmbeddings
|
17 |
-
from transcription_diarization import process_video
|
18 |
-
from output_parser import get_prompt_template, attachment_parser, bigfive_parser, personality_parser, parse_analysis_output
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
except:
|
51 |
-
return "en"
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
random.seed(seed)
|
63 |
-
np.random.seed(seed)
|
64 |
-
torch.manual_seed(seed)
|
65 |
-
if torch.cuda.is_available():
|
66 |
-
torch.cuda.manual_seed_all(seed)
|
67 |
-
|
68 |
-
def load_model(self):
|
69 |
-
model = AutoModelForCausalLM.from_pretrained(
|
70 |
-
self.model_name,
|
71 |
-
torch_dtype=torch.bfloat16,
|
72 |
-
device_map="auto",
|
73 |
-
use_auth_token=self.hf_token,
|
74 |
-
use_cache=False,
|
75 |
-
load_in_4bit=False
|
76 |
-
)
|
77 |
-
return model
|
78 |
-
|
79 |
-
def create_pipeline(self, model):
|
80 |
-
from transformers import pipeline
|
81 |
-
tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_auth_token=self.hf_token)
|
82 |
-
return pipeline(
|
83 |
-
"text-generation",
|
84 |
-
model=model,
|
85 |
-
top_k=50,
|
86 |
-
top_p=0.8,
|
87 |
-
tokenizer=tokenizer,
|
88 |
-
max_new_tokens=512,
|
89 |
-
temperature=0.3,
|
90 |
-
repetition_penalty=1.2,
|
91 |
-
do_sample=False,
|
92 |
-
truncation=True,
|
93 |
-
bad_words_ids=[[tokenizer.encode(char, add_special_tokens=False)[0]] for char in "*"]
|
94 |
-
)
|
95 |
-
|
96 |
-
def post_process_output(self, output):
|
97 |
-
return re.sub(r'[*]', '', output).strip()
|
98 |
-
|
99 |
-
def analyze_task(self, content, task, knowledge_db, analysis_type):
|
100 |
-
tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_auth_token=self.hf_token)
|
101 |
-
|
102 |
-
input_tokens = len(tokenizer.encode(content))
|
103 |
-
|
104 |
-
max_input_length = 800
|
105 |
-
encoded_input = tokenizer.encode(content, truncation=True, max_length=max_input_length)
|
106 |
-
truncated_content = tokenizer.decode(encoded_input)
|
107 |
-
|
108 |
-
if len(encoded_input) == max_input_length:
|
109 |
-
print(f"Warning: Input was truncated from {input_tokens} to {max_input_length} tokens.")
|
110 |
-
|
111 |
-
llm = HuggingFacePipeline(pipeline=self.pipe)
|
112 |
-
|
113 |
-
if analysis_type == "attachments":
|
114 |
-
parser = attachment_parser
|
115 |
-
elif analysis_type == "bigfive":
|
116 |
-
parser = bigfive_parser
|
117 |
-
elif analysis_type == "personalities":
|
118 |
-
parser = personality_parser
|
119 |
-
else:
|
120 |
-
raise ValueError(f"Unknown analysis type: {analysis_type}")
|
121 |
-
|
122 |
-
prompt = get_prompt_template(task, parser)
|
123 |
-
|
124 |
-
if knowledge_db is None:
|
125 |
-
chain = prompt | llm
|
126 |
-
result = chain.invoke({"text": truncated_content})
|
127 |
-
output = result
|
128 |
-
else:
|
129 |
-
chain = RetrievalQA.from_chain_type(
|
130 |
-
llm=llm,
|
131 |
-
chain_type="stuff",
|
132 |
-
retriever=knowledge_db.as_retriever(),
|
133 |
-
chain_type_kwargs={
|
134 |
-
"prompt": PromptTemplate(
|
135 |
-
template=task + "\n\n{context}\n\n{question}\n\n" + parser.get_format_instructions() + "\n\nAnalysis:",
|
136 |
-
input_variables=["context", "question"]
|
137 |
-
)
|
138 |
-
}
|
139 |
-
)
|
140 |
-
result = chain.run(truncated_content)
|
141 |
-
output = result
|
142 |
-
|
143 |
-
print(f"Raw model output: {output}")
|
144 |
-
|
145 |
-
try:
|
146 |
-
cleaned_output = self.post_process_output(output)
|
147 |
-
parsed_output = parser.parse(cleaned_output)
|
148 |
-
except Exception as e:
|
149 |
-
raise ValueError(f"Error parsing output: {e}")
|
150 |
-
|
151 |
-
# Check if all required keys are present
|
152 |
-
required_keys = {schema.name for schema in parser.response_schemas}
|
153 |
-
missing_keys = required_keys - parsed_output.keys()
|
154 |
-
|
155 |
-
if missing_keys:
|
156 |
-
raise ValueError(f"Missing some input keys: {missing_keys}")
|
157 |
-
|
158 |
-
return cleaned_output, input_tokens
|
159 |
|
160 |
-
def
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
safe_progress(0, desc="Processing file")
|
171 |
-
|
172 |
-
if isinstance(input_file, str):
|
173 |
-
file_path = input_file
|
174 |
else:
|
175 |
-
|
176 |
-
|
177 |
-
file_extension = os.path.splitext(file_path)[1].lower()
|
178 |
-
|
179 |
-
if file_extension in ['.txt', '.srt']:
|
180 |
-
with open(file_path, 'r', encoding='utf-8') as file:
|
181 |
-
content = file.read()
|
182 |
-
transcription = content
|
183 |
-
elif file_extension == '.pdf':
|
184 |
-
loader = PyPDFLoader(file_path)
|
185 |
-
pages = loader.load_and_split()
|
186 |
-
content = '\n'.join([page.page_content for page in pages])
|
187 |
-
transcription = content
|
188 |
-
elif file_extension in ['.mp4', '.avi', '.mov']:
|
189 |
-
safe_progress(0.2, desc="Processing video...")
|
190 |
-
srt_path = process_video(file_path, hf_token, "en", max_speakers)
|
191 |
-
with open(srt_path, 'r', encoding='utf-8') as file:
|
192 |
-
content = file.read()
|
193 |
-
transcription = content
|
194 |
-
os.remove(srt_path)
|
195 |
-
else:
|
196 |
-
return "Unsupported file format. Please upload a TXT, SRT, PDF, or video file.", None, None, None, None, None, None
|
197 |
-
|
198 |
-
detected_language = detect_language(content)
|
199 |
-
|
200 |
-
safe_progress(0.2, desc="Initializing analyzer")
|
201 |
-
analyzer = SequentialAnalyzer(hf_token)
|
202 |
-
|
203 |
-
tasks = [
|
204 |
-
("General + Attachments", general_task + "\n\n" + attachments_task, attachments_db, "attachments"),
|
205 |
-
("General + Big Five", general_task + "\n\n" + bigfive_task, bigfive_db, "bigfive"),
|
206 |
-
("General + Personalities", general_task + "\n\n" + personalities_task, personalities_db, "personalities")
|
207 |
-
]
|
208 |
-
|
209 |
-
results = []
|
210 |
-
tokens = []
|
211 |
-
|
212 |
-
for i, (task_name, task, db, analysis_type) in enumerate(tasks):
|
213 |
-
safe_progress((i + 1) * 0.2, desc=f"Analyzing {task_name}")
|
214 |
-
answer, task_tokens = analyzer.analyze_task(content, task, db, analysis_type)
|
215 |
-
results.append((answer, analysis_type))
|
216 |
-
tokens.append(task_tokens)
|
217 |
-
|
218 |
-
end_time = time.time()
|
219 |
-
execution_time = end_time - start_time
|
220 |
-
|
221 |
-
safe_progress(1.0, desc="Analysis complete!")
|
222 |
-
|
223 |
-
parsed_results = [parse_analysis_output(result, analysis_type) for result, analysis_type in results]
|
224 |
-
|
225 |
-
return (
|
226 |
-
"Analysis complete!",
|
227 |
-
f"{execution_time:.2f} seconds",
|
228 |
-
detected_language,
|
229 |
-
parsed_results[0], # attachments
|
230 |
-
parsed_results[1], # bigfive
|
231 |
-
parsed_results[2], # personalities,
|
232 |
-
transcription
|
233 |
-
)
|
|
|
1 |
+
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from langchain.prompts import PromptTemplate
|
3 |
+
from pydantic import BaseModel
|
4 |
+
from typing import Dict
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
class AttachmentStyle(BaseModel):
|
7 |
+
speaker: str
|
8 |
+
secured: float
|
9 |
+
anxious_preoccupied: float
|
10 |
+
dismissive_avoidant: float
|
11 |
+
fearful_avoidant: float
|
12 |
+
self_rating: int
|
13 |
+
others_rating: int
|
14 |
+
anxiety: int
|
15 |
+
avoidance: int
|
16 |
+
explanation: str
|
17 |
|
18 |
+
class BigFiveTraits(BaseModel):
|
19 |
+
speaker: str
|
20 |
+
extraversion: int
|
21 |
+
agreeableness: int
|
22 |
+
conscientiousness: int
|
23 |
+
neuroticism: int
|
24 |
+
openness: int
|
25 |
+
explanation: str
|
26 |
|
27 |
+
class PersonalityDisorder(BaseModel):
|
28 |
+
speaker: str
|
29 |
+
depressed: int
|
30 |
+
paranoid: int
|
31 |
+
schizoid_schizotypal: int
|
32 |
+
antisocial_psychopathic: int
|
33 |
+
borderline_dysregulated: int
|
34 |
+
narcissistic: int
|
35 |
+
anxious_avoidant: int
|
36 |
+
dependent_victimized: int
|
37 |
+
obsessional: int
|
38 |
+
explanation: str
|
39 |
|
40 |
+
attachment_response_schemas = [
|
41 |
+
ResponseSchema(name="speaker", description="The name or number of the speaker"),
|
42 |
+
ResponseSchema(name="secured", description="Probability of secured attachment style (0-1)"),
|
43 |
+
ResponseSchema(name="anxious_preoccupied", description="Probability of anxious-preoccupied attachment style (0-1)"),
|
44 |
+
ResponseSchema(name="dismissive_avoidant", description="Probability of dismissive-avoidant attachment style (0-1)"),
|
45 |
+
ResponseSchema(name="fearful_avoidant", description="Probability of fearful-avoidant attachment style (0-1)"),
|
46 |
+
ResponseSchema(name="self_rating", description="Self rating (0-10)"),
|
47 |
+
ResponseSchema(name="others_rating", description="Others rating (0-10)"),
|
48 |
+
ResponseSchema(name="anxiety", description="Anxiety rating (0-10)"),
|
49 |
+
ResponseSchema(name="avoidance", description="Avoidance rating (0-10)"),
|
50 |
+
ResponseSchema(name="explanation", description="Brief explanation of the attachment style")
|
51 |
+
]
|
52 |
|
53 |
+
bigfive_response_schemas = [
|
54 |
+
ResponseSchema(name="speaker", description="The name or number of the speaker"),
|
55 |
+
ResponseSchema(name="extraversion", description="Extraversion rating (-10 to 10)"),
|
56 |
+
ResponseSchema(name="agreeableness", description="Agreeableness rating (-10 to 10)"),
|
57 |
+
ResponseSchema(name="conscientiousness", description="Conscientiousness rating (-10 to 10)"),
|
58 |
+
ResponseSchema(name="neuroticism", description="Neuroticism rating (-10 to 10)"),
|
59 |
+
ResponseSchema(name="openness", description="Openness rating (-10 to 10)"),
|
60 |
+
ResponseSchema(name="explanation", description="Brief explanation of the Big Five traits")
|
61 |
+
]
|
62 |
|
63 |
+
personality_response_schemas = [
|
64 |
+
ResponseSchema(name="speaker", description="The name or number of the speaker"),
|
65 |
+
ResponseSchema(name="depressed", description="Depressed rating (0-4)"),
|
66 |
+
ResponseSchema(name="paranoid", description="Paranoid rating (0-4)"),
|
67 |
+
ResponseSchema(name="schizoid_schizotypal", description="Schizoid-Schizotypal rating (0-4)"),
|
68 |
+
ResponseSchema(name="antisocial_psychopathic", description="Antisocial-Psychopathic rating (0-4)"),
|
69 |
+
ResponseSchema(name="borderline_dysregulated", description="Borderline-Dysregulated rating (0-4)"),
|
70 |
+
ResponseSchema(name="narcissistic", description="Narcissistic rating (0-4)"),
|
71 |
+
ResponseSchema(name="anxious_avoidant", description="Anxious-Avoidant rating (0-4)"),
|
72 |
+
ResponseSchema(name="dependent_victimized", description="Dependent-Victimized rating (0-4)"),
|
73 |
+
ResponseSchema(name="obsessional", description="Obsessional rating (0-4)"),
|
74 |
+
ResponseSchema(name="explanation", description="Brief explanation of the personality disorders")
|
75 |
+
]
|
76 |
|
77 |
+
attachment_parser = StructuredOutputParser.from_response_schemas(attachment_response_schemas)
|
78 |
+
bigfive_parser = StructuredOutputParser.from_response_schemas(bigfive_response_schemas)
|
79 |
+
personality_parser = StructuredOutputParser.from_response_schemas(personality_response_schemas)
|
|
|
|
|
80 |
|
81 |
+
def get_prompt_template(task: str, parser: StructuredOutputParser) -> PromptTemplate:
|
82 |
+
return PromptTemplate(
|
83 |
+
template="Analyze the following text according to the given task:\n\n{task}\n\n{format_instructions}\n\nText: {text}\n\nAnalysis:",
|
84 |
+
input_variables=["text"],
|
85 |
+
partial_variables={
|
86 |
+
"task": task,
|
87 |
+
"format_instructions": parser.get_format_instructions()
|
88 |
+
}
|
89 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
+
def parse_analysis_output(output: str, analysis_type: str) -> Dict[str, BaseModel]:
|
92 |
+
if analysis_type == "attachments":
|
93 |
+
parsed = attachment_parser.parse(output)
|
94 |
+
return {parsed['speaker']: AttachmentStyle(**parsed)}
|
95 |
+
elif analysis_type == "bigfive":
|
96 |
+
parsed = bigfive_parser.parse(output)
|
97 |
+
return {parsed['speaker']: BigFiveTraits(**parsed)}
|
98 |
+
elif analysis_type == "personalities":
|
99 |
+
parsed = personality_parser.parse(output)
|
100 |
+
return {parsed['speaker']: PersonalityDisorder(**parsed)}
|
|
|
|
|
|
|
|
|
101 |
else:
|
102 |
+
raise ValueError(f"Unknown analysis type: {analysis_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|