Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -33,32 +33,87 @@ class SyntheticDataFactory:
|
|
33 |
|
34 |
def init_session_state(self):
|
35 |
if 'qa_data' not in st.session_state:
|
36 |
-
st.session_state.qa_data =
|
37 |
-
'pairs': [],
|
38 |
-
'metadata': {},
|
39 |
-
'exports': {}
|
40 |
-
}
|
41 |
if 'processing' not in st.session_state:
|
42 |
st.session_state.processing = {
|
43 |
'stage': 'idle',
|
44 |
-
'errors': []
|
|
|
45 |
}
|
46 |
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
def
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
def main():
|
64 |
st.set_page_config(
|
@@ -67,21 +122,57 @@ def main():
|
|
67 |
layout="wide"
|
68 |
)
|
69 |
|
70 |
-
# Initialize factory instance
|
71 |
factory = SyntheticDataFactory()
|
72 |
|
73 |
-
#
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
st.title("🚀 Enterprise Synthetic Data Factory")
|
77 |
|
78 |
-
# File upload and processing logic
|
79 |
uploaded_file = st.file_uploader("Upload Financial PDF", type=["pdf"])
|
80 |
|
81 |
-
if uploaded_file and api_key:
|
82 |
-
|
83 |
-
# Process
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
if __name__ == "__main__":
|
87 |
main()
|
|
|
33 |
|
34 |
def init_session_state(self):
|
35 |
if 'qa_data' not in st.session_state:
|
36 |
+
st.session_state.qa_data = []
|
|
|
|
|
|
|
|
|
37 |
if 'processing' not in st.session_state:
|
38 |
st.session_state.processing = {
|
39 |
'stage': 'idle',
|
40 |
+
'errors': [],
|
41 |
+
'progress': 0
|
42 |
}
|
43 |
|
44 |
+
def process_pdf(self, file):
|
45 |
+
"""Process PDF with error handling"""
|
46 |
+
try:
|
47 |
+
with pdfplumber.open(file) as pdf:
|
48 |
+
pages = pdf.pages
|
49 |
+
for i, page in enumerate(pages):
|
50 |
+
# Update progress
|
51 |
+
st.session_state.processing['progress'] = (i+1)/len(pages)
|
52 |
+
|
53 |
+
# Process page content
|
54 |
+
text = page.extract_text() or ""
|
55 |
+
images = self.process_images(page)
|
56 |
+
|
57 |
+
# Store in session state
|
58 |
+
st.session_state.qa_data.append({
|
59 |
+
"page": i+1,
|
60 |
+
"text": text,
|
61 |
+
"images": images
|
62 |
+
})
|
63 |
+
time.sleep(0.1) # Simulate processing
|
64 |
+
return True
|
65 |
+
except Exception as e:
|
66 |
+
st.error(f"PDF processing failed: {str(e)}")
|
67 |
+
return False
|
68 |
|
69 |
+
def process_images(self, page):
|
70 |
+
"""Robust image processing"""
|
71 |
+
images = []
|
72 |
+
for img in page.images:
|
73 |
+
try:
|
74 |
+
# Handle different PDF image formats
|
75 |
+
stream = img['stream']
|
76 |
+
width = int(stream.get('Width', stream.get('W', 0)))
|
77 |
+
height = int(stream.get('Height', stream.get('H', 0)))
|
78 |
+
|
79 |
+
if width > 0 and height > 0:
|
80 |
+
image = Image.frombytes(
|
81 |
+
"RGB" if 'ColorSpace' in stream else "L",
|
82 |
+
(width, height),
|
83 |
+
stream.get_data()
|
84 |
+
)
|
85 |
+
images.append(image)
|
86 |
+
except Exception as e:
|
87 |
+
st.warning(f"Image processing error: {str(e)[:100]}")
|
88 |
+
return images
|
89 |
+
|
90 |
+
def generate_qa(self, provider, api_key, model, temp):
|
91 |
+
"""Generate Q&A pairs with selected provider"""
|
92 |
+
try:
|
93 |
+
client = self.PROVIDER_CONFIG[provider]["client"](api_key)
|
94 |
+
|
95 |
+
for item in st.session_state.qa_data:
|
96 |
+
prompt = f"Generate 3 Q&A pairs from this financial content:\n{item['text']}\nOutput JSON format with keys: question, answer_1, answer_2"
|
97 |
+
|
98 |
+
response = client.chat.completions.create(
|
99 |
+
model=model,
|
100 |
+
messages=[{"role": "user", "content": prompt}],
|
101 |
+
temperature=temp,
|
102 |
+
response_format={"type": "json_object"}
|
103 |
+
)
|
104 |
+
|
105 |
+
try:
|
106 |
+
result = json.loads(response.choices[0].message.content)
|
107 |
+
item["qa_pairs"] = result.get("qa_pairs", [])
|
108 |
+
except json.JSONDecodeError:
|
109 |
+
st.error("Failed to parse AI response")
|
110 |
+
|
111 |
+
st.session_state.processing['stage'] = 'complete'
|
112 |
+
return True
|
113 |
+
|
114 |
+
except Exception as e:
|
115 |
+
st.error(f"Generation failed: {str(e)}")
|
116 |
+
return False
|
117 |
|
118 |
def main():
|
119 |
st.set_page_config(
|
|
|
122 |
layout="wide"
|
123 |
)
|
124 |
|
|
|
125 |
factory = SyntheticDataFactory()
|
126 |
|
127 |
+
# Sidebar Configuration
|
128 |
+
with st.sidebar:
|
129 |
+
st.header("⚙️ AI Configuration")
|
130 |
+
provider = st.selectbox("Provider", list(factory.PROVIDER_CONFIG.keys()))
|
131 |
+
config = factory.PROVIDER_CONFIG[provider]
|
132 |
+
api_key = st.text_input(config["key_label"], type="password")
|
133 |
+
model = st.selectbox("Model", config["models"])
|
134 |
+
temp = st.slider("Temperature", 0.0, 1.0, 0.3)
|
135 |
+
|
136 |
+
# Main Interface
|
137 |
st.title("🚀 Enterprise Synthetic Data Factory")
|
138 |
|
|
|
139 |
uploaded_file = st.file_uploader("Upload Financial PDF", type=["pdf"])
|
140 |
|
141 |
+
if uploaded_file and api_key and st.button("Start Synthetic Generation"):
|
142 |
+
with st.status("Processing document...", expanded=True) as status:
|
143 |
+
# Process PDF
|
144 |
+
st.write("Extracting text and images...")
|
145 |
+
if factory.process_pdf(uploaded_file):
|
146 |
+
# Generate Q&A pairs
|
147 |
+
st.write("Generating synthetic data...")
|
148 |
+
if factory.generate_qa(provider, api_key, model, temp):
|
149 |
+
status.update(label="Processing complete!", state="complete", expanded=False)
|
150 |
+
|
151 |
+
# Display Results
|
152 |
+
if st.session_state.processing.get('stage') == 'complete':
|
153 |
+
st.subheader("Generated Q&A Pairs")
|
154 |
+
|
155 |
+
# Convert to DataFrame
|
156 |
+
all_qa = []
|
157 |
+
for item in st.session_state.qa_data:
|
158 |
+
for qa in item.get("qa_pairs", []):
|
159 |
+
qa["page"] = item["page"]
|
160 |
+
all_qa.append(qa)
|
161 |
+
|
162 |
+
if len(all_qa) > 0:
|
163 |
+
df = pd.DataFrame(all_qa)
|
164 |
+
st.dataframe(df)
|
165 |
+
|
166 |
+
# Export options
|
167 |
+
csv = df.to_csv(index=False).encode('utf-8')
|
168 |
+
st.download_button(
|
169 |
+
label="Download as CSV",
|
170 |
+
data=csv,
|
171 |
+
file_name="synthetic_data.csv",
|
172 |
+
mime="text/csv"
|
173 |
+
)
|
174 |
+
else:
|
175 |
+
st.warning("No Q&A pairs generated. Check your document content and API settings.")
|
176 |
|
177 |
if __name__ == "__main__":
|
178 |
main()
|