reab5555 committed
Commit 9204fa6 · verified · 1 Parent(s): 046fb82

Update processing.py

Files changed (1)
  1. processing.py +35 -16
processing.py CHANGED
@@ -1,8 +1,10 @@
 import os
 import time
 import re
+import numpy as np
 from huggingface_hub import login
 import torch
+import random
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from langdetect import detect
 from langchain.chains import RetrievalQA
@@ -28,6 +30,7 @@ def load_instructions(file_path):
     with open(file_path, 'r') as file:
         return file.read().strip()
 
+general_task = load_instructions("tasks/general_task.txt")
 attachments_task = load_instructions("tasks/Attachments_task.txt")
 bigfive_task = load_instructions("tasks/BigFive_task.txt")
 personalities_task = load_instructions("tasks/Personalities_task.txt")
@@ -74,7 +77,6 @@ class SequentialAnalyzer:
             use_cache=False,
             load_in_4bit=False
         )
-        model.gradient_checkpointing_enable()
         return model
 
     def create_pipeline(self, model):
@@ -109,18 +111,30 @@ class SequentialAnalyzer:
            print(f"Warning: Input was truncated from {input_tokens} to {max_input_length} tokens.")
 
        llm = HuggingFacePipeline(pipeline=self.pipe)
-       chain = RetrievalQA.from_chain_type(
-           llm=llm,
-           chain_type="stuff",
-           retriever=knowledge_db.as_retriever(),
-           chain_type_kwargs={"prompt": PromptTemplate(
-               template=task + "\n\n{context}\n\n{question}\n\n-----------\n\nAnswer: ",
-               input_variables=["context", "question"]
-           )}
-       )
-
-       result = chain({"query": truncated_content})
-       output = result['result'].split("-----------\n\nAnswer:")[-1].strip()
+
+       if knowledge_db is None:
+           # For general task without specific knowledge DB
+           prompt = PromptTemplate(
+               template=task + "\n\n{question}\n\n-----------\n\nAnswer: ",
+               input_variables=["question"]
+           )
+           chain = prompt | llm
+           result = chain.invoke({"question": truncated_content})
+           output = result.split("-----------\n\nAnswer:")[-1].strip()
+       else:
+           # For tasks with specific knowledge DB
+           chain = RetrievalQA.from_chain_type(
+               llm=llm,
+               chain_type="stuff",
+               retriever=knowledge_db.as_retriever(),
+               chain_type_kwargs={"prompt": PromptTemplate(
+                   template=task + "\n\n{context}\n\n{question}\n\n-----------\n\nAnswer: ",
+                   input_variables=["context", "question"]
+               )}
+           )
+           result = chain({"query": truncated_content})
+           output = result['result'].split("-----------\n\nAnswer:")[-1].strip()
+
        cleaned_output = self.post_process_output(output)
 
        return cleaned_output, input_tokens
@@ -161,7 +175,7 @@ def process_input(input_file, progress=None):
        transcription = content  # Store the transcription
        os.remove(srt_path)
    else:
-       return "Unsupported file format. Please upload a TXT, SRT, PDF, or video file.", None, None, None, None, None, None, None, None, None, None
+       return "Unsupported file format. Please upload a TXT, SRT, PDF, or video file.", None, None, None, None, None, None, None, None, None, None, None, None
 
    detected_language = detect_language(content)
 
@@ -172,6 +186,11 @@ def process_input(input_file, progress=None):
 
    analyzer = SequentialAnalyzer(hf_token)
 
+   safe_progress(0.3, desc="Performing general analysis...")
+   general_answer, general_tokens = analyzer.analyze_task(content, general_task, None)
+   print("General output:\n", general_answer)
+   print(f"General input tokens (before truncation): {general_tokens}")
+
    safe_progress(0.5, desc="Analyzing attachments...")
    attachments_answer, attachments_tokens = analyzer.analyze_task(content, attachments_task, attachments_db)
    print("Attachments output:\n", attachments_answer)
@@ -195,5 +214,5 @@ def process_input(input_file, progress=None):
    safe_progress(1.0, desc="Analysis complete!")
 
    return ("Analysis complete!", execution_info, detected_language,
-           attachments_answer, bigfive_answer, personalities_answer,
-           original_tokens, attachments_tokens, bigfive_tokens, personalities_tokens, transcription)
+           general_answer, attachments_answer, bigfive_answer, personalities_answer,
+           original_tokens, general_tokens, attachments_tokens, bigfive_tokens, personalities_tokens, transcription)
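
Note on the change: analyze_task now accepts None for knowledge_db. When no knowledge base is passed it skips RetrievalQA and runs a plain prompt-to-LLM chain built with the LCEL pipe operator; process_input uses this for the new general_task, which is why the final result tuple and the unsupported-format early return each grow from 11 to 13 elements (general_answer and general_tokens). Below is a minimal, self-contained sketch of that prompt-only path. It is not the code from processing.py: the model name, demo_pipe, the prompt text, and the sample question are placeholders, and the exact import path for HuggingFacePipeline depends on the installed LangChain version.

# Minimal sketch of the prompt-only branch (knowledge_db is None); names and
# model below are illustrative, not taken from processing.py.
from transformers import pipeline
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline

demo_pipe = pipeline("text-generation", model="gpt2", max_new_tokens=64)
llm = HuggingFacePipeline(pipeline=demo_pipe)

prompt = PromptTemplate(
    template="Summarize the speaker's main points.\n\n{question}\n\n-----------\n\nAnswer: ",
    input_variables=["question"],
)

# LCEL composition: the formatted prompt is fed straight to the LLM.
chain = prompt | llm
raw = chain.invoke({"question": "Some transcript text..."})

# The raw generation may echo the prompt, so keep only the text after the marker.
answer = raw.split("-----------\n\nAnswer:")[-1].strip()
print(answer)

Splitting on the "-----------\n\nAnswer:" marker mirrors the RetrievalQA branch: HuggingFacePipeline returns the raw generated text, so only the part after the marker is treated as the answer.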