Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,156 +1,145 @@
|
|
1 |
import streamlit as st
|
2 |
import pdfplumber
|
3 |
-
import pytesseract
|
4 |
import pandas as pd
|
5 |
import requests
|
6 |
import json
|
7 |
from PIL import Image
|
8 |
-
from io import BytesIO
|
9 |
from openai import OpenAI
|
10 |
-
import google.
|
11 |
import groq
|
12 |
import sqlalchemy
|
13 |
from typing import Dict, Any
|
14 |
|
15 |
-
#
|
16 |
HF_API_URL = "https://api-inference.huggingface.co/models/"
|
17 |
-
DEFAULT_TEMPERATURE = 0.1
|
18 |
-
MODEL = "mixtral-8x7b-32768"
|
|
|
|
|
19 |
|
20 |
class SyntheticDataGenerator:
|
21 |
-
"""
|
22 |
-
A class to generate synthetic Q&A data from various input sources using different LLM providers.
|
23 |
-
"""
|
24 |
|
25 |
def __init__(self):
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
27 |
self.providers = {
|
28 |
"Deepseek": {
|
29 |
"client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
|
30 |
-
"models": ["deepseek-chat"]
|
31 |
},
|
32 |
"OpenAI": {
|
33 |
"client": lambda key: OpenAI(api_key=key),
|
34 |
-
"models": ["gpt-4-turbo"]
|
35 |
},
|
36 |
"Groq": {
|
37 |
"client": lambda key: groq.Groq(api_key=key),
|
38 |
-
"models": [MODEL]
|
39 |
},
|
40 |
"HuggingFace": {
|
41 |
"client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
|
42 |
-
"models": ["gpt2", "llama-2"]
|
43 |
},
|
44 |
-
|
45 |
-
"client": lambda key: self._configure_google_genai(key),
|
46 |
-
"models": ["gemini-pro"]
|
47 |
},
|
48 |
}
|
49 |
|
|
|
|
|
50 |
self.input_handlers = {
|
51 |
"pdf": self.handle_pdf,
|
52 |
"text": self.handle_text,
|
53 |
"csv": self.handle_csv,
|
54 |
"api": self.handle_api,
|
55 |
-
"db": self.handle_db
|
56 |
}
|
57 |
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
def _configure_google_genai(self, api_key: str):
|
61 |
"""Configures the Google Generative AI client."""
|
62 |
try:
|
63 |
genai.configure(api_key=api_key)
|
64 |
-
return genai.GenerativeModel
|
65 |
except Exception as e:
|
66 |
st.error(f"Error configuring Google GenAI: {e}")
|
67 |
-
return None
|
68 |
-
|
69 |
-
def init_session(self):
|
70 |
-
"""Initializes the Streamlit session state with default values."""
|
71 |
-
session_defaults = {
|
72 |
-
'inputs': [],
|
73 |
-
'qa_data': [],
|
74 |
-
'processing': {
|
75 |
-
'stage': 'idle',
|
76 |
-
'progress': 0,
|
77 |
-
'errors': []
|
78 |
-
},
|
79 |
-
'config': {
|
80 |
-
'provider': "Groq",
|
81 |
-
'model': MODEL,
|
82 |
-
'temperature': DEFAULT_TEMPERATURE
|
83 |
-
}
|
84 |
-
}
|
85 |
-
|
86 |
-
for key, val in session_defaults.items():
|
87 |
-
if key not in st.session_state:
|
88 |
-
st.session_state[key] = val
|
89 |
|
90 |
-
#
|
91 |
def handle_pdf(self, file):
|
92 |
-
|
93 |
-
|
94 |
with pdfplumber.open(file) as pdf:
|
95 |
extracted_data = []
|
96 |
for i, page in enumerate(pdf.pages):
|
97 |
page_text = page.extract_text() or ""
|
98 |
page_images = self.process_images(page)
|
99 |
-
extracted_data.append(
|
100 |
-
"text": page_text,
|
101 |
-
|
102 |
-
"meta": {"type": "pdf", "page": i + 1}
|
103 |
-
})
|
104 |
return extracted_data
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
|
109 |
def handle_text(self, text):
|
110 |
"""Handles manual text input."""
|
111 |
-
return [{
|
112 |
-
"text": text,
|
113 |
-
"meta": {"type": "domain", "source": "manual"}
|
114 |
-
}]
|
115 |
|
116 |
def handle_csv(self, file):
|
117 |
"""Reads a CSV file and prepares data for Q&A generation."""
|
118 |
try:
|
119 |
df = pd.read_csv(file)
|
120 |
-
return [
|
121 |
-
"text": "\n".join([f"{col}: {row[col]}" for col in df.columns]),
|
122 |
-
|
123 |
-
|
124 |
except Exception as e:
|
125 |
-
self.
|
126 |
return []
|
127 |
|
128 |
def handle_api(self, config):
|
129 |
"""Fetches data from an API endpoint."""
|
130 |
try:
|
131 |
-
response = requests.get(config[
|
132 |
-
response.raise_for_status() # Raise HTTPError for bad responses
|
133 |
-
return [{
|
134 |
-
"text": json.dumps(response.json()),
|
135 |
-
"meta": {"type": "api", "endpoint": config['url']}
|
136 |
-
}]
|
137 |
except requests.exceptions.RequestException as e:
|
138 |
-
self.
|
139 |
return []
|
140 |
|
141 |
-
|
142 |
def handle_db(self, config):
|
143 |
"""Connects to a database and executes a query."""
|
144 |
try:
|
145 |
-
engine = sqlalchemy.create_engine(config[
|
146 |
with engine.connect() as conn:
|
147 |
-
result = conn.execute(sqlalchemy.text(config[
|
148 |
-
return [
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
152 |
except Exception as e:
|
153 |
-
self.
|
154 |
return []
|
155 |
|
156 |
def process_images(self, page):
|
@@ -158,211 +147,206 @@ class SyntheticDataGenerator:
|
|
158 |
images = []
|
159 |
for img in page.images:
|
160 |
try:
|
161 |
-
stream = img[
|
162 |
-
width = int(stream.get(
|
163 |
-
height = int(stream.get(
|
164 |
-
image_data = stream.get_data()
|
165 |
-
|
|
|
166 |
try:
|
167 |
image = Image.frombytes("RGB", (width, height), image_data)
|
168 |
-
images.append({
|
169 |
-
"data": image,
|
170 |
-
"meta": {"dims": (width, height)}
|
171 |
-
})
|
172 |
except Exception as e:
|
173 |
-
self.
|
174 |
else:
|
175 |
-
self.
|
176 |
-
|
|
|
177 |
|
178 |
except Exception as e:
|
179 |
-
self.
|
180 |
return images
|
181 |
|
182 |
-
#
|
183 |
def generate(self, api_key: str) -> bool:
|
184 |
-
"""
|
185 |
-
Generates Q&A pairs using the selected LLM provider.
|
186 |
-
|
187 |
-
Args:
|
188 |
-
api_key (str): The API key for the selected LLM provider.
|
189 |
-
|
190 |
-
Returns:
|
191 |
-
bool: True if generation was successful, False otherwise.
|
192 |
-
"""
|
193 |
try:
|
194 |
-
provider_cfg = self.providers[st.session_state.config['provider']]
|
195 |
-
client_initializer = provider_cfg["client"] #Get the client init function.
|
196 |
-
|
197 |
-
# Check that the key is not an empty string
|
198 |
if not api_key:
|
199 |
st.error("API Key cannot be empty.")
|
200 |
return False
|
201 |
|
202 |
-
|
203 |
-
|
204 |
-
|
|
|
|
|
205 |
if not client:
|
206 |
return False # Google config failed
|
207 |
else:
|
208 |
client = client_initializer(api_key)
|
209 |
|
210 |
for i, input_data in enumerate(st.session_state.inputs):
|
|
|
211 |
|
212 |
-
|
|
|
|
|
213 |
|
214 |
-
if st.session_state.config[
|
215 |
response = self._huggingface_inference(client, input_data)
|
216 |
-
elif st.session_state.config[
|
217 |
-
|
218 |
else:
|
219 |
response = self._standard_inference(client, input_data)
|
220 |
|
221 |
if response:
|
222 |
-
#
|
223 |
-
st.
|
|
|
|
|
|
|
224 |
|
225 |
return True
|
|
|
226 |
except Exception as e:
|
227 |
-
self.
|
228 |
return False
|
229 |
|
230 |
def _standard_inference(self, client, input_data):
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
#st.write(input_data['text']) # debugging data
|
235 |
return client.chat.completions.create(
|
236 |
-
model=st.session_state.config[
|
237 |
-
messages=[{
|
238 |
-
|
239 |
-
"content": self._build_prompt(input_data)
|
240 |
-
}],
|
241 |
-
temperature=st.session_state.config['temperature'],
|
242 |
-
response_format={"type": "json_object"} #Request json
|
243 |
)
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
|
248 |
def _huggingface_inference(self, client, input_data):
|
249 |
"""Performs inference using Hugging Face Inference API."""
|
250 |
try:
|
251 |
response = requests.post(
|
252 |
-
HF_API_URL + st.session_state.config[
|
253 |
headers=client["headers"],
|
254 |
-
json={"inputs": self._build_prompt(input_data)}
|
255 |
)
|
256 |
-
response.raise_for_status()
|
257 |
return response.json()
|
258 |
except requests.exceptions.RequestException as e:
|
259 |
-
self.
|
260 |
return None
|
261 |
|
262 |
def _google_inference(self, client, input_data):
|
263 |
"""Performs inference using Google Generative AI API."""
|
264 |
try:
|
265 |
-
model = client(st.session_state.config[
|
266 |
response = model.generate_content(
|
267 |
self._build_prompt(input_data),
|
268 |
-
generation_config
|
269 |
-
|
270 |
)
|
271 |
-
|
272 |
-
st.write("Google API Response:") # Debugging: Print the raw response
|
273 |
-
st.write(response.text)
|
274 |
-
|
275 |
return response
|
276 |
except Exception as e:
|
277 |
-
self.
|
278 |
return None
|
279 |
|
|
|
280 |
def _build_prompt(self, input_data):
|
281 |
"""Builds the prompt for the LLM based on the input data type."""
|
282 |
-
base =
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
try:
|
|
|
|
|
296 |
if provider == "HuggingFace":
|
297 |
-
|
|
|
298 |
elif provider == "Google":
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
qa_pairs = json.loads(json_string).get("qa_pairs", []) # Extract the qa_pairs
|
303 |
-
|
304 |
-
# Validate the structure of qa_pairs
|
305 |
-
if not isinstance(qa_pairs, list):
|
306 |
-
raise ValueError("Expected a list of QA pairs.")
|
307 |
-
|
308 |
-
for pair in qa_pairs:
|
309 |
-
if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
|
310 |
-
raise ValueError("Each item in the list must be a dictionary with 'question' and 'answer' keys.")
|
311 |
-
return qa_pairs # Return the extracted and validated list
|
312 |
-
except (json.JSONDecodeError, ValueError) as e:
|
313 |
-
self.log_error(f"Google JSON Parse Error: {e}. Raw Response: {response.text}")
|
314 |
-
return [] # Return empty in case of parsing failure
|
315 |
-
else:
|
316 |
-
# Assuming JSON response from other providers (OpenAI, Deepseek, Groq)
|
317 |
if not response or not response.choices or not response.choices[0].message.content:
|
318 |
-
self.
|
319 |
return []
|
320 |
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
except Exception as e:
|
328 |
-
self.
|
329 |
return []
|
330 |
|
331 |
-
def
|
332 |
-
"""Logs an error message to
|
333 |
-
st.session_state.processing[
|
334 |
st.error(message)
|
335 |
|
336 |
-
# Streamlit UI Components
|
337 |
-
def input_sidebar(gen: SyntheticDataGenerator):
|
338 |
-
"""
|
339 |
-
Creates the input sidebar in the Streamlit UI.
|
340 |
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
Returns:
|
345 |
-
str: The API key entered by the user.
|
346 |
-
"""
|
347 |
with st.sidebar:
|
348 |
st.header("⚙️ Configuration")
|
349 |
|
350 |
-
# AI Provider Settings
|
351 |
provider = st.selectbox("Provider", list(gen.providers.keys()))
|
|
|
352 |
provider_cfg = gen.providers[provider]
|
353 |
|
354 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
355 |
-
st.session_state[
|
356 |
|
357 |
model = st.selectbox("Model", provider_cfg["models"])
|
358 |
-
|
359 |
|
360 |
-
|
361 |
-
st.session_state.config
|
362 |
-
"provider": provider,
|
363 |
-
"model": model,
|
364 |
-
"temperature": temp
|
365 |
-
})
|
366 |
|
367 |
# Input Source Selection
|
368 |
st.header("🔗 Data Sources")
|
@@ -376,11 +360,11 @@ def input_sidebar(gen: SyntheticDataGenerator):
|
|
376 |
elif input_type == "csv":
|
377 |
csv_file = st.file_uploader("Upload CSV", type=["csv"])
|
378 |
if csv_file:
|
379 |
-
|
380 |
|
381 |
elif input_type == "api":
|
382 |
api_url = st.text_input("API Endpoint")
|
383 |
-
api_headers = st.text_area("API Headers (JSON format, optional)", height=
|
384 |
headers = {}
|
385 |
try:
|
386 |
if api_headers:
|
@@ -395,47 +379,38 @@ def input_sidebar(gen: SyntheticDataGenerator):
|
|
395 |
db_query = st.text_area("Database Query")
|
396 |
db_table = st.text_input("Table Name (optional)")
|
397 |
if st.button("Add DB Input"):
|
398 |
-
|
|
|
|
|
399 |
|
400 |
return api_key
|
401 |
|
402 |
-
def main_display(gen: SyntheticDataGenerator):
|
403 |
-
"""
|
404 |
-
Creates the main display area in the Streamlit UI.
|
405 |
|
406 |
-
|
407 |
-
|
408 |
-
"""
|
409 |
st.title("🚀 Enterprise Synthetic Data Factory")
|
410 |
|
411 |
-
# Input Processing
|
412 |
col1, col2 = st.columns([3, 1])
|
413 |
with col1:
|
414 |
pdf_file = st.file_uploader("Upload Document", type=["pdf"])
|
415 |
if pdf_file:
|
416 |
-
|
417 |
|
418 |
-
# Generation Controls
|
419 |
with col2:
|
420 |
if st.button("Start Generation"):
|
421 |
with st.status("Processing..."):
|
422 |
-
if not st.session_state
|
423 |
-
|
424 |
else:
|
425 |
-
gen.generate(st.session_state
|
426 |
|
427 |
-
# Results Display
|
428 |
if st.session_state.qa_data:
|
429 |
st.header("Generated Data")
|
430 |
df = pd.DataFrame(st.session_state.qa_data)
|
431 |
st.dataframe(df)
|
432 |
|
433 |
-
|
434 |
-
|
435 |
-
"Export CSV",
|
436 |
-
df.to_csv(index=False),
|
437 |
-
"synthetic_data.csv"
|
438 |
-
)
|
439 |
|
440 |
def main():
|
441 |
"""Main function to run the Streamlit application."""
|
@@ -443,5 +418,6 @@ def main():
|
|
443 |
api_key = input_sidebar(gen)
|
444 |
main_display(gen)
|
445 |
|
|
|
446 |
if __name__ == "__main__":
|
447 |
main()
|
|
|
1 |
# Standard library
import json
from typing import Dict, Any

# Third-party
# NOTE: the Google SDK (PyPI package "google-generativeai") is imported as
# `google.generativeai`; `google.generative_ai` does not exist.
import google.generativeai as genai
import groq
import pandas as pd
import pdfplumber
import requests
import sqlalchemy
import streamlit as st
from openai import OpenAI
from PIL import Image
|
12 |
|
13 |
+
# --- CONSTANTS ---
|
14 |
HF_API_URL = "https://api-inference.huggingface.co/models/"
|
15 |
+
DEFAULT_TEMPERATURE = 0.1
|
16 |
+
MODEL = "mixtral-8x7b-32768" # Groq model
|
17 |
+
API_HEADERS_HEIGHT = 70 # Minimum height for st.text_area
|
18 |
+
|
19 |
|
20 |
class SyntheticDataGenerator:
|
21 |
+
"""Generates synthetic Q&A data from various input sources using LLMs."""
|
|
|
|
|
22 |
|
23 |
def __init__(self):
    """Build the provider registry, the input-handler table and the session defaults."""
    setup_steps = (
        self._setup_providers,
        self._setup_input_handlers,
        self._initialize_session_state,
    )
    # Run the setup helpers in declaration order.
    for run_step in setup_steps:
        run_step()
|
27 |
+
|
28 |
+
def _setup_providers(self):
    """Populate ``self.providers`` with one entry per supported LLM backend.

    Each entry maps a provider display name to a ``client`` factory (called
    with the API key) and the list of ``models`` offered for that provider.
    """

    def deepseek_client(key):
        # Deepseek is reached through the OpenAI client with a custom base URL.
        return OpenAI(base_url="https://api.deepseek.com/v1", api_key=key)

    def openai_client(key):
        return OpenAI(api_key=key)

    def groq_client(key):
        return groq.Groq(api_key=key)

    def huggingface_client(key):
        # No SDK object here: the "client" is just the auth headers for requests.
        return {"headers": {"Authorization": f"Bearer {key}"}}

    def google_client(key):
        return self._configure_google_genai(key)

    registry = {}
    registry["Deepseek"] = {"client": deepseek_client, "models": ["deepseek-chat"]}
    registry["OpenAI"] = {"client": openai_client, "models": ["gpt-4-turbo"]}
    registry["Groq"] = {"client": groq_client, "models": [MODEL]}
    registry["HuggingFace"] = {"client": huggingface_client, "models": ["gpt2", "llama-2"]}
    registry["Google"] = {"client": google_client, "models": ["gemini-pro"]}
    self.providers = registry
|
52 |
|
53 |
+
def _setup_input_handlers(self):
    """Map each supported input type name to the method that ingests it."""
    source_kinds = ("pdf", "text", "csv", "api", "db")
    # Handler methods are named ``handle_<kind>``.
    self.input_handlers = {
        kind: getattr(self, f"handle_{kind}") for kind in source_kinds
    }
|
62 |
|
63 |
+
def _initialize_session_state(self):
    """Seed ``st.session_state`` with default entries, keeping any that already exist."""
    defaults = dict(
        inputs=[],
        qa_data=[],
        processing=dict(stage="idle", progress=0, errors=[]),
        config=dict(provider="Groq", model=MODEL, temperature=DEFAULT_TEMPERATURE),
        api_key="",  # present from the start so later reads never miss the key
    )
    for name in defaults:
        if name in st.session_state:
            continue
        st.session_state[name] = defaults[name]
|
75 |
|
76 |
def _configure_google_genai(self, api_key: str):
    """Register *api_key* with the Google Generative AI SDK.

    Returns:
        The ``genai.GenerativeModel`` class (not an instance) on success, or
        ``None`` after the failure has been shown with ``st.error``.
    """
    model_factory = None
    try:
        genai.configure(api_key=api_key)
        model_factory = genai.GenerativeModel
    except Exception as config_error:
        st.error(f"Error configuring Google GenAI: {config_error}")
    return model_factory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
# --- INPUT HANDLERS ---
|
86 |
def handle_pdf(self, file):
    """Turn every page of a PDF into one input record.

    Returns:
        A list with one dict per page holding the page ``text`` (empty string
        when nothing is extracted), its ``images`` and ``meta`` with the
        1-based page number; an empty list if anything fails (the error is
        reported through ``self._log_error``).
    """
    try:
        with pdfplumber.open(file) as document:
            return [
                {
                    "text": page.extract_text() or "",
                    "images": self.process_images(page),
                    "meta": {"type": "pdf", "page": page_number},
                }
                for page_number, page in enumerate(document.pages, start=1)
            ]
    except Exception as error:
        self._log_error(f"PDF Error: {error!s}")
        return []
|
101 |
|
102 |
def handle_text(self, text):
    """Wrap manually entered text as a single input record."""
    # Manual text is tagged as a "domain" input coming from a "manual" source.
    record = dict(text=text, meta=dict(type="domain", source="manual"))
    return [record]
|
|
|
|
|
|
|
105 |
|
106 |
def handle_csv(self, file):
    """Read a CSV file and turn every row into one text record.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        One dict per row with ``text`` (one "column: value" line per column)
        and ``meta`` (type ``"csv"`` plus the column names); an empty list if
        reading fails (the error is reported through ``self._log_error``).
    """
    try:
        df = pd.read_csv(file)
        columns = list(df.columns)
        # itertuples keeps each column's own dtype. iterrows would coerce the
        # whole row to a single dtype, rendering integer cells as "1.0" as soon
        # as any other column holds floats.
        return [
            {
                "text": "\n".join(f"{col}: {val}" for col, val in zip(columns, row)),
                # Fresh list per record so records never share a mutable object.
                "meta": {"type": "csv", "columns": list(columns)},
            }
            for row in df.itertuples(index=False, name=None)
        ]
    except Exception as e:
        self._log_error(f"CSV Error: {str(e)}")
        return []
|
117 |
|
118 |
def handle_api(self, config):
    """Fetch JSON from an HTTP endpoint and wrap it as a single text record.

    Args:
        config: Mapping with the endpoint ``url`` and, optionally, the request
            ``headers``.

    Returns:
        A one-element list whose ``text`` is the JSON body re-serialised with
        ``json.dumps`` and whose ``meta`` records the endpoint; an empty list
        when the request fails or the body is not JSON (the error is reported
        through ``self._log_error``).
    """
    try:
        # Headers are optional: .get avoids an uncaught KeyError when the
        # caller supplies only a URL.
        response = requests.get(config["url"], headers=config.get("headers"), timeout=10)
        response.raise_for_status()  # Raise HTTPError for bad responses
        payload = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSON decode error is not a RequestException.
        self._log_error(f"API Error: {str(e)}")
        return []
    return [{"text": json.dumps(payload), "meta": {"type": "api", "endpoint": config["url"]}}]
|
127 |
|
|
|
128 |
def handle_db(self, config):
    """Run a query against a database and turn each result row into a record.

    ``config`` supplies the ``connection`` string, the ``query`` text and an
    optional ``table`` name that is copied into each record's ``meta``.
    Returns an empty list if anything fails (the error is reported through
    ``self._log_error``).
    """
    try:
        engine = sqlalchemy.create_engine(config["connection"])
        with engine.connect() as conn:
            rows = conn.execute(sqlalchemy.text(config["query"]))
            records = []
            for row in rows:
                fields = [f"{col}: {val}" for col, val in row._asdict().items()]
                records.append(
                    {
                        "text": "\n".join(fields),
                        "meta": {"type": "db", "table": config.get("table", "")},
                    }
                )
            return records
    except Exception as error:
        self._log_error(f"DB Error: {error!s}")
        return []
|
144 |
|
145 |
def process_images(self, page):
|
|
|
147 |
images = []
|
148 |
for img in page.images:
|
149 |
try:
|
150 |
+
stream = img["stream"]
|
151 |
+
width = int(stream.get("Width", 0))
|
152 |
+
height = int(stream.get("Height", 0))
|
153 |
+
image_data = stream.get_data()
|
154 |
+
|
155 |
+
if width > 0 and height > 0 and image_data:
|
156 |
try:
|
157 |
image = Image.frombytes("RGB", (width, height), image_data)
|
158 |
+
images.append({"data": image, "meta": {"dims": (width, height)}})
|
|
|
|
|
|
|
159 |
except Exception as e:
|
160 |
+
self._log_error(f"Image Creation Error: {str(e)}. Width: {width}, Height: {height}")
|
161 |
else:
|
162 |
+
self._log_error(
|
163 |
+
f"Image Error: Insufficient data or invalid dimensions (w={width}, h={height})"
|
164 |
+
)
|
165 |
|
166 |
except Exception as e:
|
167 |
+
self._log_error(f"Image Extraction Error: {str(e)}")
|
168 |
return images
|
169 |
|
170 |
+
# --- LLM INFERENCE ---
|
171 |
def generate(self, api_key: str) -> bool:
    """Generate Q&A pairs for every queued input with the configured provider.

    Parsed pairs are appended to ``st.session_state.qa_data``.

    Returns:
        True when the loop over the inputs completes; False when the key is
        empty, the Google client could not be set up, or an exception was
        raised (it is reported through ``self._log_error``).
    """
    try:
        if not api_key:
            st.error("API Key cannot be empty.")
            return False

        provider = st.session_state.config["provider"]
        client = self.providers[provider]["client"](api_key)
        if provider == "Google" and not client:
            return False  # Google config failed

        # Choose the provider-specific inference routine once, before the loop.
        if provider == "HuggingFace":
            run_inference = self._huggingface_inference
        elif provider == "Google":
            run_inference = self._google_inference
        else:
            run_inference = self._standard_inference

        queued = st.session_state.inputs
        for position, input_data in enumerate(queued, start=1):
            st.session_state.processing["progress"] = position / len(queued)

            # Debugging: Display input data
            st.write("--- Input Data ---")
            st.write(input_data["text"])

            response = run_inference(client, input_data)
            if not response:
                continue

            # Debugging: Display raw response
            st.write("--- Raw Response ---")
            st.write(response)

            st.session_state.qa_data.extend(self._parse_response(response, provider))

        return True

    except Exception as error:
        self._log_error(f"Generation Error: {error!s}")
        return False
|
214 |
|
215 |
def _standard_inference(self, client, input_data):
    """Send one prompt through an OpenAI-style chat-completions client.

    Returns the completion object, or ``None`` after reporting any exception
    through ``self._log_error``.
    """
    try:
        create_completion = client.chat.completions.create
        settings = st.session_state.config
        model_name = settings["model"]
        user_message = {"role": "user", "content": self._build_prompt(input_data)}
        return create_completion(
            model=model_name,
            messages=[user_message],
            temperature=settings["temperature"],
        )
    except Exception as error:
        self._log_error(f"OpenAI Inference Error: {error}")
        return None
|
226 |
|
227 |
def _huggingface_inference(self, client, input_data):
    """Run one prompt through the Hugging Face Inference API.

    Args:
        client: Dict holding the request ``headers`` (including the bearer token).
        input_data: Input record passed to ``self._build_prompt``.

    Returns:
        The decoded JSON body, or ``None`` when the request fails, times out
        or returns a non-JSON body (the error is reported through
        ``self._log_error``).
    """
    try:
        response = requests.post(
            HF_API_URL + st.session_state.config["model"],
            headers=client["headers"],
            json={"inputs": self._build_prompt(input_data)},
            # requests waits forever without a timeout; bound the call so a
            # stalled endpoint cannot freeze the app.
            timeout=60,
        )
        response.raise_for_status()
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSON decode error is not a RequestException.
        self._log_error(f"Hugging Face Inference Error: {e}")
        return None
|
240 |
|
241 |
def _google_inference(self, client, input_data):
    """Run one prompt through Google Generative AI.

    ``client`` is called with the configured model name to obtain the model
    object. Returns the response object, or ``None`` after reporting any
    exception through ``self._log_error``.
    """
    try:
        settings = st.session_state.config
        model = client(settings["model"])
        prompt = self._build_prompt(input_data)
        generation_config = genai.types.GenerationConfig(temperature=settings["temperature"])
        return model.generate_content(prompt, generation_config=generation_config)
    except Exception as error:
        self._log_error(f"Google GenAI Inference Error: {error}")
        return None
|
253 |
|
254 |
+
# --- PROMPT ENGINEERING ---
|
255 |
def _build_prompt(self, input_data):
    """Compose the LLM prompt for one input record.

    The fixed instructions (including a three-pair JSON example) come first;
    CSV and API records get a short label before the record's text.
    """
    instruction_parts = [
        "You are an expert in extracting question and answer pairs from documents. ",
        "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries.\n",
        "Each dictionary must have the keys 'question' and 'answer'.\n",
        "The 'question' should be clear and concise, and the 'answer' should directly answer the question using only ",
        "information from the data. Do not hallucinate or invent information.\n",
        "Answer from the exact same document, not outside from the document\n",
        "Example JSON Output:\n",
        '[{"question": "What is the capital of France?", "answer": "The capital of France is Paris."}, ',
        '{"question": "What is the highest mountain in the world?", "answer": "The highest mountain in the world is Mount Everest."}, ',
        '{"question": "What is the chemical symbol for gold?", "answer": "The chemical symbol for gold is Au."}]\n',
        "Now, generate 3 Q&A pairs from this data:\n",
    ]
    instructions = "".join(instruction_parts)

    source_type = input_data["meta"]["type"]
    if source_type == "csv":
        label = "Data:\n"
    elif source_type == "api":
        label = "API response:\n"
    else:
        label = ""
    return instructions + label + input_data["text"]
|
276 |
+
|
277 |
+
# --- RESPONSE PARSING ---
|
278 |
+
def _parse_response(self, response: Any, provider: str) -> list[dict[str, str]]:
    """Parse an LLM response into a list of Q&A pairs.

    Args:
        response: Raw provider response. For ``"HuggingFace"`` a list whose
            first item has ``generated_text``; for ``"Google"`` an object with
            ``.text``; otherwise an OpenAI-style completion with ``.choices``.
        provider: Provider name that selects how the text is extracted.

    Returns:
        The validated list of dicts, each containing ``question`` and
        ``answer``. An empty list is returned (and the problem reported
        through ``self._log_error``) when the response is empty, is not valid
        JSON, or does not have the expected structure.
    """
    response_text = ""
    try:
        if provider == "HuggingFace":
            # Keep going to the JSON parsing below: returning the raw string
            # here would make the caller's extend() add single characters.
            response_text = response[0]["generated_text"]
        elif provider == "Google":
            response_text = response.text
        else:  # OpenAI, Deepseek, Groq
            if not response or not response.choices or not response.choices[0].message.content:
                self._log_error("Empty or malformed response from LLM.")
                return []
            response_text = response.choices[0].message.content

        payload = response_text.strip()
        if payload.startswith("```"):
            # Drop a markdown code fence (``` or ```json ... ```) around the JSON.
            payload = payload.split("\n", 1)[1] if "\n" in payload else ""
            payload = payload.rsplit("```", 1)[0].strip()

        json_output = json.loads(payload)

        if isinstance(json_output, list):
            qa_pairs = json_output
        elif isinstance(json_output, dict) and "questionList" in json_output:
            qa_pairs = json_output["questionList"]
        elif isinstance(json_output, dict) and "qa_pairs" in json_output:
            # Wrapper key used by the earlier revision of this parser.
            qa_pairs = json_output["qa_pairs"]
        else:
            self._log_error(f"Unexpected JSON structure: {response_text}")
            return []

        if not isinstance(qa_pairs, list):
            self._log_error(f"Expected a list of QA pairs, but got: {type(qa_pairs)}")
            return []

        for pair in qa_pairs:
            if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
                self._log_error(f"Invalid QA pair structure: {pair}")
                return []

        return qa_pairs

    except json.JSONDecodeError as e:
        self._log_error(f"JSON Parse Error: {e}. Raw Response: {response_text}")
        return []
    except Exception as e:
        self._log_error(f"Parse Error: {e}. Raw Response: {response}")
        return []
|
325 |
|
326 |
+
def _log_error(self, message):
    """Record *message* in the session's processing error list and show it in the UI."""
    error_log = st.session_state.processing["errors"]
    error_log.append(message)
    st.error(message)
|
330 |
|
|
|
|
|
|
|
|
|
331 |
|
332 |
+
# --- STREAMLIT UI COMPONENTS ---
|
333 |
+
def input_sidebar(gen: SyntheticDataGenerator) -> str:
|
334 |
+
"""Creates the input sidebar in the Streamlit UI."""
|
|
|
|
|
|
|
335 |
with st.sidebar:
|
336 |
st.header("⚙️ Configuration")
|
337 |
|
|
|
338 |
provider = st.selectbox("Provider", list(gen.providers.keys()))
|
339 |
+
st.session_state.config["provider"] = provider # Update session state immediately
|
340 |
provider_cfg = gen.providers[provider]
|
341 |
|
342 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
343 |
+
st.session_state["api_key"] = api_key
|
344 |
|
345 |
model = st.selectbox("Model", provider_cfg["models"])
|
346 |
+
st.session_state.config["model"] = model # Update model selection
|
347 |
|
348 |
+
temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
|
349 |
+
st.session_state.config["temperature"] = temp # Update temperature
|
|
|
|
|
|
|
|
|
350 |
|
351 |
# Input Source Selection
|
352 |
st.header("🔗 Data Sources")
|
|
|
360 |
elif input_type == "csv":
|
361 |
csv_file = st.file_uploader("Upload CSV", type=["csv"])
|
362 |
if csv_file:
|
363 |
+
st.session_state.inputs.extend(gen.input_handlers["csv"](csv_file))
|
364 |
|
365 |
elif input_type == "api":
|
366 |
api_url = st.text_input("API Endpoint")
|
367 |
+
api_headers = st.text_area("API Headers (JSON format, optional)", height=API_HEADERS_HEIGHT)
|
368 |
headers = {}
|
369 |
try:
|
370 |
if api_headers:
|
|
|
379 |
db_query = st.text_area("Database Query")
|
380 |
db_table = st.text_input("Table Name (optional)")
|
381 |
if st.button("Add DB Input"):
|
382 |
+
st.session_state.inputs.extend(
|
383 |
+
gen.input_handlers["db"]({"connection": db_connection, "query": db_query, "table": db_table})
|
384 |
+
)
|
385 |
|
386 |
return api_key
|
387 |
|
|
|
|
|
|
|
388 |
|
389 |
+
def main_display(gen: SyntheticDataGenerator):
    """Render the main page: PDF upload, the generation trigger and the results.

    Args:
        gen: Generator whose ``input_handlers`` ingest the uploaded PDF and
            whose ``generate`` is run when the button is pressed.
    """
    st.title("🚀 Enterprise Synthetic Data Factory")

    col1, col2 = st.columns([3, 1])
    with col1:
        pdf_file = st.file_uploader("Upload Document", type=["pdf"])
        if pdf_file:
            # Streamlit re-executes the script on every interaction while the
            # uploader still holds the file. Remember which upload was already
            # ingested so its pages are queued only once.
            upload_id = (pdf_file.name, pdf_file.size)
            already_ingested = (
                "ingested_pdf" in st.session_state
                and st.session_state["ingested_pdf"] == upload_id
            )
            if not already_ingested:
                st.session_state.inputs.extend(gen.input_handlers["pdf"](pdf_file))
                st.session_state["ingested_pdf"] = upload_id
        elif "ingested_pdf" in st.session_state:
            # Uploader was cleared: allow the same file to be added again later.
            del st.session_state["ingested_pdf"]

    with col2:
        if st.button("Start Generation"):
            with st.status("Processing..."):
                if not st.session_state["api_key"]:
                    st.error("Please provide an API Key.")
                else:
                    gen.generate(st.session_state["api_key"])

    if st.session_state.qa_data:
        st.header("Generated Data")
        df = pd.DataFrame(st.session_state.qa_data)
        st.dataframe(df)

        st.download_button("Export CSV", df.to_csv(index=False), "synthetic_data.csv")
|
413 |
+
|
|
|
|
|
|
|
|
|
414 |
|
415 |
def main():
|
416 |
"""Main function to run the Streamlit application."""
|
|
|
418 |
api_key = input_sidebar(gen)
|
419 |
main_display(gen)
|
420 |
|
421 |
+
|
422 |
if __name__ == "__main__":
|
423 |
main()
|