mgbam commited on
Commit
e9a68df
·
verified ·
1 Parent(s): 7b16658

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -223
app.py CHANGED
@@ -1,156 +1,145 @@
1
  import streamlit as st
2
  import pdfplumber
3
- import pytesseract
4
  import pandas as pd
5
  import requests
6
  import json
7
  from PIL import Image
8
- from io import BytesIO
9
  from openai import OpenAI
10
- import google.generativeai as genai # Added Google GenAI
11
  import groq
12
  import sqlalchemy
13
  from typing import Dict, Any
14
 
15
- # Constants for Default Values and API URLs
16
  HF_API_URL = "https://api-inference.huggingface.co/models/"
17
- DEFAULT_TEMPERATURE = 0.1 # Lower Temperature
18
- MODEL = "mixtral-8x7b-32768" #constant string
 
 
19
 
20
  class SyntheticDataGenerator:
21
- """
22
- A class to generate synthetic Q&A data from various input sources using different LLM providers.
23
- """
24
 
25
  def __init__(self):
26
- """Initializes the SyntheticDataGenerator with supported providers, input handlers, and session state."""
 
 
 
 
 
27
  self.providers = {
28
  "Deepseek": {
29
  "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
30
- "models": ["deepseek-chat"]
31
  },
32
  "OpenAI": {
33
  "client": lambda key: OpenAI(api_key=key),
34
- "models": ["gpt-4-turbo"]
35
  },
36
  "Groq": {
37
  "client": lambda key: groq.Groq(api_key=key),
38
- "models": [MODEL]
39
  },
40
  "HuggingFace": {
41
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
42
- "models": ["gpt2", "llama-2"]
43
  },
44
- "Google": {
45
- "client": lambda key: self._configure_google_genai(key), # Using a custom configure function
46
- "models": ["gemini-pro"] # Use gemini-pro. Consider adding "gemini-pro" when released.
47
  },
48
  }
49
 
 
 
50
  self.input_handlers = {
51
  "pdf": self.handle_pdf,
52
  "text": self.handle_text,
53
  "csv": self.handle_csv,
54
  "api": self.handle_api,
55
- "db": self.handle_db
56
  }
57
 
58
- self.init_session()
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def _configure_google_genai(self, api_key: str):
61
  """Configures the Google Generative AI client."""
62
  try:
63
  genai.configure(api_key=api_key)
64
- return genai.GenerativeModel # return the model class, not an instantiation
65
  except Exception as e:
66
  st.error(f"Error configuring Google GenAI: {e}")
67
- return None # Important: Handle the case where configuration fails
68
-
69
- def init_session(self):
70
- """Initializes the Streamlit session state with default values."""
71
- session_defaults = {
72
- 'inputs': [],
73
- 'qa_data': [],
74
- 'processing': {
75
- 'stage': 'idle',
76
- 'progress': 0,
77
- 'errors': []
78
- },
79
- 'config': {
80
- 'provider': "Groq",
81
- 'model': MODEL,
82
- 'temperature': DEFAULT_TEMPERATURE
83
- }
84
- }
85
-
86
- for key, val in session_defaults.items():
87
- if key not in st.session_state:
88
- st.session_state[key] = val
89
 
90
- # Input Processors
91
  def handle_pdf(self, file):
92
- """Extracts text and images from a PDF file."""
93
- try:
94
  with pdfplumber.open(file) as pdf:
95
  extracted_data = []
96
  for i, page in enumerate(pdf.pages):
97
  page_text = page.extract_text() or ""
98
  page_images = self.process_images(page)
99
- extracted_data.append({
100
- "text": page_text,
101
- "images": page_images,
102
- "meta": {"type": "pdf", "page": i + 1}
103
- })
104
  return extracted_data
105
- except Exception as e:
106
- self.log_error(f"PDF Error: {str(e)}")
107
- return []
108
 
109
  def handle_text(self, text):
110
  """Handles manual text input."""
111
- return [{
112
- "text": text,
113
- "meta": {"type": "domain", "source": "manual"}
114
- }]
115
 
116
  def handle_csv(self, file):
117
  """Reads a CSV file and prepares data for Q&A generation."""
118
  try:
119
  df = pd.read_csv(file)
120
- return [{
121
- "text": "\n".join([f"{col}: {row[col]}" for col in df.columns]),
122
- "meta": {"type": "csv", "columns": list(df.columns)}
123
- } for _, row in df.iterrows()]
124
  except Exception as e:
125
- self.log_error(f"CSV Error: {str(e)}")
126
  return []
127
 
128
  def handle_api(self, config):
129
  """Fetches data from an API endpoint."""
130
  try:
131
- response = requests.get(config['url'], headers=config['headers'])
132
- response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
133
- return [{
134
- "text": json.dumps(response.json()),
135
- "meta": {"type": "api", "endpoint": config['url']}
136
- }]
137
  except requests.exceptions.RequestException as e:
138
- self.log_error(f"API Error: {str(e)}")
139
  return []
140
 
141
-
142
  def handle_db(self, config):
143
  """Connects to a database and executes a query."""
144
  try:
145
- engine = sqlalchemy.create_engine(config['connection'])
146
  with engine.connect() as conn:
147
- result = conn.execute(sqlalchemy.text(config['query']))
148
- return [{
149
- "text": "\n".join([f"{col}: {val}" for col, val in row._asdict().items()]),
150
- "meta": {"type": "db", "table": config.get('table', '')}
151
- } for row in result]
 
 
 
152
  except Exception as e:
153
- self.log_error(f"DB Error: {str(e)}")
154
  return []
155
 
156
  def process_images(self, page):
@@ -158,211 +147,206 @@ class SyntheticDataGenerator:
158
  images = []
159
  for img in page.images:
160
  try:
161
- stream = img['stream']
162
- width = int(stream.get('Width', 0))
163
- height = int(stream.get('Height', 0))
164
- image_data = stream.get_data() # Get the image data
165
- if width > 0 and height > 0 and image_data: #CHECK image_data
 
166
  try:
167
  image = Image.frombytes("RGB", (width, height), image_data)
168
- images.append({
169
- "data": image,
170
- "meta": {"dims": (width, height)}
171
- })
172
  except Exception as e:
173
- self.log_error(f"Image Creation Error: {str(e)}") # Log specific image creation errors.
174
  else:
175
- self.log_error(f"Image Error: Insufficient image data or invalid dimensions (width={width}, height={height})")
176
-
 
177
 
178
  except Exception as e:
179
- self.log_error(f"Image Extraction Error: {str(e)}") # More general extraction error
180
  return images
181
 
182
- # Core Generation Engine
183
  def generate(self, api_key: str) -> bool:
184
- """
185
- Generates Q&A pairs using the selected LLM provider.
186
-
187
- Args:
188
- api_key (str): The API key for the selected LLM provider.
189
-
190
- Returns:
191
- bool: True if generation was successful, False otherwise.
192
- """
193
  try:
194
- provider_cfg = self.providers[st.session_state.config['provider']]
195
- client_initializer = provider_cfg["client"] #Get the client init function.
196
-
197
- # Check that the key is not an empty string
198
  if not api_key:
199
  st.error("API Key cannot be empty.")
200
  return False
201
 
202
- # Initialize the client
203
- if st.session_state.config['provider'] == "Google":
204
- client = client_initializer(api_key) # Client is the class
 
 
205
  if not client:
206
  return False # Google config failed
207
  else:
208
  client = client_initializer(api_key)
209
 
210
  for i, input_data in enumerate(st.session_state.inputs):
 
211
 
212
- st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)
 
 
213
 
214
- if st.session_state.config['provider'] == "HuggingFace":
215
  response = self._huggingface_inference(client, input_data)
216
- elif st.session_state.config['provider'] == "Google":
217
- response = self._google_inference(client, input_data)
218
  else:
219
  response = self._standard_inference(client, input_data)
220
 
221
  if response:
222
- # Check if the parsing function needs access to the provider
223
- st.session_state.qa_data.extend(self._parse_response(response, st.session_state.config['provider']))
 
 
 
224
 
225
  return True
 
226
  except Exception as e:
227
- self.log_error(f"Generation Error: {str(e)}")
228
  return False
229
 
230
  def _standard_inference(self, client, input_data):
231
- """Performs inference using standard OpenAI-compatible API."""
232
- try:
233
-
234
- #st.write(input_data['text']) # debugging data
235
  return client.chat.completions.create(
236
- model=st.session_state.config['model'],
237
- messages=[{
238
- "role": "user",
239
- "content": self._build_prompt(input_data)
240
- }],
241
- temperature=st.session_state.config['temperature'],
242
- response_format={"type": "json_object"} #Request json
243
  )
244
- except Exception as e:
245
- self.log_error(f"OpenAI Inference Error: {e}")
246
- return None
247
 
248
  def _huggingface_inference(self, client, input_data):
249
  """Performs inference using Hugging Face Inference API."""
250
  try:
251
  response = requests.post(
252
- HF_API_URL + st.session_state.config['model'],
253
  headers=client["headers"],
254
- json={"inputs": self._build_prompt(input_data)}
255
  )
256
- response.raise_for_status() #Check for HTTP errors
257
  return response.json()
258
  except requests.exceptions.RequestException as e:
259
- self.log_error(f"Hugging Face Inference Error: {e}")
260
  return None
261
 
262
  def _google_inference(self, client, input_data):
263
  """Performs inference using Google Generative AI API."""
264
  try:
265
- model = client(st.session_state.config['model']) # Instantiate the model with the selected model name
266
  response = model.generate_content(
267
  self._build_prompt(input_data),
268
- generation_config = genai.types.GenerationConfig(temperature=st.session_state.config['temperature'])
269
-
270
  )
271
-
272
- st.write("Google API Response:") # Debugging: Print the raw response
273
- st.write(response.text)
274
-
275
  return response
276
  except Exception as e:
277
- self.log_error(f"Google GenAI Inference Error: {e}")
278
  return None
279
 
 
280
  def _build_prompt(self, input_data):
281
  """Builds the prompt for the LLM based on the input data type."""
282
- base = "Generate a JSON list of 3 dictionaries like this: \n"
283
- base+= '[{"question":"Example Question", "answer":"Example Answer"},'
284
- base+= '{"question":"Example Question", "answer":"Example Answer"},'
285
- base+= '{"question":"Example Question", "answer":"Example Answer"}]'
286
- base+= 'Here is the data:\n'
287
- if input_data['meta']['type'] == 'csv':
288
- return base + "Data:\n" + input_data['text']
289
- elif input_data['meta']['type'] == 'api':
290
- return base + "API response:\n" + input_data['text']
291
- return base + input_data['text']
292
-
293
- def _parse_response(self, response, provider):
294
- """Parses the response from the LLM into a list of Q&A pairs."""
 
 
 
 
 
 
 
 
 
 
295
  try:
 
 
296
  if provider == "HuggingFace":
297
- return response[0]['generated_text']
 
298
  elif provider == "Google":
299
- # Expecting a text response from Gemini
300
- try:
301
- json_string = response.text.strip() # Removes surrounding whitespace that can cause errors
302
- qa_pairs = json.loads(json_string).get("qa_pairs", []) # Extract the qa_pairs
303
-
304
- # Validate the structure of qa_pairs
305
- if not isinstance(qa_pairs, list):
306
- raise ValueError("Expected a list of QA pairs.")
307
-
308
- for pair in qa_pairs:
309
- if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
310
- raise ValueError("Each item in the list must be a dictionary with 'question' and 'answer' keys.")
311
- return qa_pairs # Return the extracted and validated list
312
- except (json.JSONDecodeError, ValueError) as e:
313
- self.log_error(f"Google JSON Parse Error: {e}. Raw Response: {response.text}")
314
- return [] # Return empty in case of parsing failure
315
- else:
316
- # Assuming JSON response from other providers (OpenAI, Deepseek, Groq)
317
  if not response or not response.choices or not response.choices[0].message.content:
318
- self.log_error("Empty or malformed response from LLM.")
319
  return []
320
 
321
- try:
322
- json_output = json.loads(response.choices[0].message.content) # load the JSON data
323
- return json_output.get("qa_pairs", []) # Return the qa_pairs
324
- except json.JSONDecodeError as e:
325
- self.log_error(f"JSON Parse Error: {e}. Raw Response: {response.choices[0].message.content}")
 
 
 
 
 
 
 
 
 
 
326
  return []
 
 
 
 
 
 
 
 
 
 
 
 
327
  except Exception as e:
328
- self.log_error(f"Parse Error: {e}. Raw Response: {response}")
329
  return []
330
 
331
- def log_error(self, message):
332
- """Logs an error message to the Streamlit session state and displays it in the UI."""
333
- st.session_state.processing['errors'].append(message)
334
  st.error(message)
335
 
336
- # Streamlit UI Components
337
- def input_sidebar(gen: SyntheticDataGenerator):
338
- """
339
- Creates the input sidebar in the Streamlit UI.
340
 
341
- Args:
342
- gen (SyntheticDataGenerator): The SyntheticDataGenerator instance.
343
-
344
- Returns:
345
- str: The API key entered by the user.
346
- """
347
  with st.sidebar:
348
  st.header("⚙️ Configuration")
349
 
350
- # AI Provider Settings
351
  provider = st.selectbox("Provider", list(gen.providers.keys()))
 
352
  provider_cfg = gen.providers[provider]
353
 
354
  api_key = st.text_input(f"{provider} API Key", type="password")
355
- st.session_state['api_key'] = api_key #Store API Key
356
 
357
  model = st.selectbox("Model", provider_cfg["models"])
358
- temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE) #Lower
359
 
360
- # Update session config
361
- st.session_state.config.update({
362
- "provider": provider,
363
- "model": model,
364
- "temperature": temp
365
- })
366
 
367
  # Input Source Selection
368
  st.header("🔗 Data Sources")
@@ -376,11 +360,11 @@ def input_sidebar(gen: SyntheticDataGenerator):
376
  elif input_type == "csv":
377
  csv_file = st.file_uploader("Upload CSV", type=["csv"])
378
  if csv_file:
379
- st.session_state.inputs.extend(gen.input_handlers["csv"](csv_file))
380
 
381
  elif input_type == "api":
382
  api_url = st.text_input("API Endpoint")
383
- api_headers = st.text_area("API Headers (JSON format, optional)", height=50)
384
  headers = {}
385
  try:
386
  if api_headers:
@@ -395,47 +379,38 @@ def input_sidebar(gen: SyntheticDataGenerator):
395
  db_query = st.text_area("Database Query")
396
  db_table = st.text_input("Table Name (optional)")
397
  if st.button("Add DB Input"):
398
- st.session_state.inputs.extend(gen.input_handlers["db"]({"connection": db_connection, "query": db_query, "table": db_table}))
 
 
399
 
400
  return api_key
401
 
402
- def main_display(gen: SyntheticDataGenerator):
403
- """
404
- Creates the main display area in the Streamlit UI.
405
 
406
- Args:
407
- gen (SyntheticDataGenerator): The SyntheticDataGenerator instance.
408
- """
409
  st.title("🚀 Enterprise Synthetic Data Factory")
410
 
411
- # Input Processing
412
  col1, col2 = st.columns([3, 1])
413
  with col1:
414
  pdf_file = st.file_uploader("Upload Document", type=["pdf"])
415
  if pdf_file:
416
- st.session_state.inputs.extend(gen.input_handlers["pdf"](pdf_file))
417
 
418
- # Generation Controls
419
  with col2:
420
  if st.button("Start Generation"):
421
  with st.status("Processing..."):
422
- if not st.session_state.get('api_key'):
423
- st.error("Please provide an API Key.")
424
  else:
425
- gen.generate(st.session_state.get('api_key'))
426
 
427
- # Results Display
428
  if st.session_state.qa_data:
429
  st.header("Generated Data")
430
  df = pd.DataFrame(st.session_state.qa_data)
431
  st.dataframe(df)
432
 
433
- # Export Options
434
- st.download_button(
435
- "Export CSV",
436
- df.to_csv(index=False),
437
- "synthetic_data.csv"
438
- )
439
 
440
  def main():
441
  """Main function to run the Streamlit application."""
@@ -443,5 +418,6 @@ def main():
443
  api_key = input_sidebar(gen)
444
  main_display(gen)
445
 
 
446
  if __name__ == "__main__":
447
  main()
 
1
  import streamlit as st
2
  import pdfplumber
 
3
  import pandas as pd
4
  import requests
5
  import json
6
  from PIL import Image
 
7
  from openai import OpenAI
8
+ import google.generative_ai as genai
9
  import groq
10
  import sqlalchemy
11
  from typing import Dict, Any
12
 
13
+ # --- CONSTANTS ---
14
  HF_API_URL = "https://api-inference.huggingface.co/models/"
15
+ DEFAULT_TEMPERATURE = 0.1
16
+ MODEL = "mixtral-8x7b-32768" # Groq model
17
+ API_HEADERS_HEIGHT = 70 # Minimum height for st.text_area
18
+
19
 
20
  class SyntheticDataGenerator:
21
+ """Generates synthetic Q&A data from various input sources using LLMs."""
 
 
22
 
23
  def __init__(self):
24
+ self._setup_providers()
25
+ self._setup_input_handlers()
26
+ self._initialize_session_state()
27
+
28
+ def _setup_providers(self):
29
+ """Defines the available LLM providers and their configurations."""
30
  self.providers = {
31
  "Deepseek": {
32
  "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
33
+ "models": ["deepseek-chat"],
34
  },
35
  "OpenAI": {
36
  "client": lambda key: OpenAI(api_key=key),
37
+ "models": ["gpt-4-turbo"],
38
  },
39
  "Groq": {
40
  "client": lambda key: groq.Groq(api_key=key),
41
+ "models": [MODEL],
42
  },
43
  "HuggingFace": {
44
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
45
+ "models": ["gpt2", "llama-2"],
46
  },
47
+ "Google": {
48
+ "client": lambda key: self._configure_google_genai(key),
49
+ "models": ["gemini-pro"],
50
  },
51
  }
52
 
53
+ def _setup_input_handlers(self):
54
+ """Defines handlers for different input data types."""
55
  self.input_handlers = {
56
  "pdf": self.handle_pdf,
57
  "text": self.handle_text,
58
  "csv": self.handle_csv,
59
  "api": self.handle_api,
60
+ "db": self.handle_db,
61
  }
62
 
63
+ def _initialize_session_state(self):
64
+ """Initializes Streamlit session state variables."""
65
+ session_defaults = {
66
+ "inputs": [],
67
+ "qa_data": [],
68
+ "processing": {"stage": "idle", "progress": 0, "errors": []},
69
+ "config": {"provider": "Groq", "model": MODEL, "temperature": DEFAULT_TEMPERATURE},
70
+ "api_key": "", # Explicitly initialize api_key in session state
71
+ }
72
+ for key, value in session_defaults.items():
73
+ if key not in st.session_state:
74
+ st.session_state[key] = value
75
 
76
  def _configure_google_genai(self, api_key: str):
77
  """Configures the Google Generative AI client."""
78
  try:
79
  genai.configure(api_key=api_key)
80
+ return genai.GenerativeModel
81
  except Exception as e:
82
  st.error(f"Error configuring Google GenAI: {e}")
83
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ # --- INPUT HANDLERS ---
86
  def handle_pdf(self, file):
87
+ """Extracts text and images from a PDF file."""
88
+ try:
89
  with pdfplumber.open(file) as pdf:
90
  extracted_data = []
91
  for i, page in enumerate(pdf.pages):
92
  page_text = page.extract_text() or ""
93
  page_images = self.process_images(page)
94
+ extracted_data.append(
95
+ {"text": page_text, "images": page_images, "meta": {"type": "pdf", "page": i + 1}}
96
+ )
 
 
97
  return extracted_data
98
+ except Exception as e:
99
+ self._log_error(f"PDF Error: {str(e)}")
100
+ return []
101
 
102
  def handle_text(self, text):
103
  """Handles manual text input."""
104
+ return [{"text": text, "meta": {"type": "domain", "source": "manual"}}]
 
 
 
105
 
106
  def handle_csv(self, file):
107
  """Reads a CSV file and prepares data for Q&A generation."""
108
  try:
109
  df = pd.read_csv(file)
110
+ return [
111
+ {"text": "\n".join([f"{col}: {row[col]}" for col in df.columns]), "meta": {"type": "csv", "columns": list(df.columns)}}
112
+ for _, row in df.iterrows()
113
+ ]
114
  except Exception as e:
115
+ self._log_error(f"CSV Error: {str(e)}")
116
  return []
117
 
118
  def handle_api(self, config):
119
  """Fetches data from an API endpoint."""
120
  try:
121
+ response = requests.get(config["url"], headers=config["headers"], timeout=10) # Add timeout
122
+ response.raise_for_status() # Raise HTTPError for bad responses
123
+ return [{"text": json.dumps(response.json()), "meta": {"type": "api", "endpoint": config["url"]}}]
 
 
 
124
  except requests.exceptions.RequestException as e:
125
+ self._log_error(f"API Error: {str(e)}")
126
  return []
127
 
 
128
  def handle_db(self, config):
129
  """Connects to a database and executes a query."""
130
  try:
131
+ engine = sqlalchemy.create_engine(config["connection"])
132
  with engine.connect() as conn:
133
+ result = conn.execute(sqlalchemy.text(config["query"]))
134
+ return [
135
+ {
136
+ "text": "\n".join([f"{col}: {val}" for col, val in row._asdict().items()]),
137
+ "meta": {"type": "db", "table": config.get("table", "")},
138
+ }
139
+ for row in result
140
+ ]
141
  except Exception as e:
142
+ self._log_error(f"DB Error: {str(e)}")
143
  return []
144
 
145
  def process_images(self, page):
 
147
  images = []
148
  for img in page.images:
149
  try:
150
+ stream = img["stream"]
151
+ width = int(stream.get("Width", 0))
152
+ height = int(stream.get("Height", 0))
153
+ image_data = stream.get_data()
154
+
155
+ if width > 0 and height > 0 and image_data:
156
  try:
157
  image = Image.frombytes("RGB", (width, height), image_data)
158
+ images.append({"data": image, "meta": {"dims": (width, height)}})
 
 
 
159
  except Exception as e:
160
+ self._log_error(f"Image Creation Error: {str(e)}. Width: {width}, Height: {height}")
161
  else:
162
+ self._log_error(
163
+ f"Image Error: Insufficient data or invalid dimensions (w={width}, h={height})"
164
+ )
165
 
166
  except Exception as e:
167
+ self._log_error(f"Image Extraction Error: {str(e)}")
168
  return images
169
 
170
+ # --- LLM INFERENCE ---
171
  def generate(self, api_key: str) -> bool:
172
+ """Generates Q&A pairs using the selected LLM provider."""
 
 
 
 
 
 
 
 
173
  try:
 
 
 
 
174
  if not api_key:
175
  st.error("API Key cannot be empty.")
176
  return False
177
 
178
+ provider_cfg = self.providers[st.session_state.config["provider"]]
179
+ client_initializer = provider_cfg["client"]
180
+
181
+ if st.session_state.config["provider"] == "Google":
182
+ client = client_initializer(api_key)
183
  if not client:
184
  return False # Google config failed
185
  else:
186
  client = client_initializer(api_key)
187
 
188
  for i, input_data in enumerate(st.session_state.inputs):
189
+ st.session_state.processing["progress"] = (i + 1) / len(st.session_state.inputs)
190
 
191
+ # Debugging: Display input data
192
+ st.write("--- Input Data ---")
193
+ st.write(input_data["text"])
194
 
195
+ if st.session_state.config["provider"] == "HuggingFace":
196
  response = self._huggingface_inference(client, input_data)
197
+ elif st.session_state.config["provider"] == "Google":
198
+ response = self._google_inference(client, input_data)
199
  else:
200
  response = self._standard_inference(client, input_data)
201
 
202
  if response:
203
+ # Debugging: Display raw response
204
+ st.write("--- Raw Response ---")
205
+ st.write(response)
206
+
207
+ st.session_state.qa_data.extend(self._parse_response(response, st.session_state.config["provider"]))
208
 
209
  return True
210
+
211
  except Exception as e:
212
+ self._log_error(f"Generation Error: {str(e)}")
213
  return False
214
 
215
  def _standard_inference(self, client, input_data):
216
+ """Performs inference using OpenAI-compatible API."""
217
+ try:
 
 
218
  return client.chat.completions.create(
219
+ model=st.session_state.config["model"],
220
+ messages=[{"role": "user", "content": self._build_prompt(input_data)}],
221
+ temperature=st.session_state.config["temperature"],
 
 
 
 
222
  )
223
+ except Exception as e:
224
+ self._log_error(f"OpenAI Inference Error: {e}")
225
+ return None
226
 
227
  def _huggingface_inference(self, client, input_data):
228
  """Performs inference using Hugging Face Inference API."""
229
  try:
230
  response = requests.post(
231
+ HF_API_URL + st.session_state.config["model"],
232
  headers=client["headers"],
233
+ json={"inputs": self._build_prompt(input_data)},
234
  )
235
+ response.raise_for_status()
236
  return response.json()
237
  except requests.exceptions.RequestException as e:
238
+ self._log_error(f"Hugging Face Inference Error: {e}")
239
  return None
240
 
241
  def _google_inference(self, client, input_data):
242
  """Performs inference using Google Generative AI API."""
243
  try:
244
+ model = client(st.session_state.config["model"])
245
  response = model.generate_content(
246
  self._build_prompt(input_data),
247
+ generation_config=genai.types.GenerationConfig(temperature=st.session_state.config["temperature"]),
 
248
  )
 
 
 
 
249
  return response
250
  except Exception as e:
251
+ self._log_error(f"Google GenAI Inference Error: {e}")
252
  return None
253
 
254
+ # --- PROMPT ENGINEERING ---
255
  def _build_prompt(self, input_data):
256
  """Builds the prompt for the LLM based on the input data type."""
257
+ base = (
258
+ "You are an expert in extracting question and answer pairs from documents. "
259
+ "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries.\n"
260
+ "Each dictionary must have the keys 'question' and 'answer'.\n"
261
+ "The 'question' should be clear and concise, and the 'answer' should directly answer the question using only "
262
+ "information from the data. Do not hallucinate or invent information.\n"
263
+ "Answer from the exact same document, not outside from the document\n"
264
+ "Example JSON Output:\n"
265
+ '[{"question": "What is the capital of France?", "answer": "The capital of France is Paris."}, '
266
+ '{"question": "What is the highest mountain in the world?", "answer": "The highest mountain in the world is Mount Everest."}, '
267
+ '{"question": "What is the chemical symbol for gold?", "answer": "The chemical symbol for gold is Au."}]\n'
268
+ "Now, generate 3 Q&A pairs from this data:\n"
269
+ )
270
+
271
+ if input_data["meta"]["type"] == "csv":
272
+ return base + "Data:\n" + input_data["text"]
273
+ elif input_data["meta"]["type"] == "api":
274
+ return base + "API response:\n" + input_data["text"]
275
+ return base + input_data["text"]
276
+
277
+ # --- RESPONSE PARSING ---
278
+ def _parse_response(self, response: Any, provider: str) -> list[dict[str, str]]:
279
+ """Parses the LLM response into a list of Q&A pairs."""
280
  try:
281
+ response_text = ""
282
+
283
  if provider == "HuggingFace":
284
+ response_text = response[0]["generated_text"]
285
+ return response_text
286
  elif provider == "Google":
287
+ response_text = response.text.strip()
288
+
289
+ else: # OpenAI, Deepseek, Groq
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  if not response or not response.choices or not response.choices[0].message.content:
291
+ self._log_error("Empty or malformed response from LLM.")
292
  return []
293
 
294
+ response_text = response.choices[0].message.content
295
+
296
+ try:
297
+ json_output = json.loads(response_text)
298
+
299
+ if isinstance(json_output, list):
300
+ qa_pairs = json_output
301
+ elif isinstance(json_output, dict) and "questionList" in json_output:
302
+ qa_pairs = json_output["questionList"]
303
+ else:
304
+ self._log_error(f"Unexpected JSON structure: {response_text}")
305
+ return []
306
+
307
+ if not isinstance(qa_pairs, list):
308
+ self._log_error(f"Expected a list of QA pairs, but got: {type(qa_pairs)}")
309
  return []
310
+
311
+ for pair in qa_pairs:
312
+ if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
313
+ self._log_error(f"Invalid QA pair structure: {pair}")
314
+ return []
315
+
316
+ return qa_pairs
317
+
318
+ except json.JSONDecodeError as e:
319
+ self._log_error(f"JSON Parse Error: {e}. Raw Response: {response_text}")
320
+ return []
321
+
322
  except Exception as e:
323
+ self._log_error(f"Parse Error: {e}. Raw Response: {response}")
324
  return []
325
 
326
+ def _log_error(self, message):
327
+ """Logs an error message to Streamlit session state and displays it."""
328
+ st.session_state.processing["errors"].append(message)
329
  st.error(message)
330
 
 
 
 
 
331
 
332
+ # --- STREAMLIT UI COMPONENTS ---
333
+ def input_sidebar(gen: SyntheticDataGenerator) -> str:
334
+ """Creates the input sidebar in the Streamlit UI."""
 
 
 
335
  with st.sidebar:
336
  st.header("⚙️ Configuration")
337
 
 
338
  provider = st.selectbox("Provider", list(gen.providers.keys()))
339
+ st.session_state.config["provider"] = provider # Update session state immediately
340
  provider_cfg = gen.providers[provider]
341
 
342
  api_key = st.text_input(f"{provider} API Key", type="password")
343
+ st.session_state["api_key"] = api_key
344
 
345
  model = st.selectbox("Model", provider_cfg["models"])
346
+ st.session_state.config["model"] = model # Update model selection
347
 
348
+ temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
349
+ st.session_state.config["temperature"] = temp # Update temperature
 
 
 
 
350
 
351
  # Input Source Selection
352
  st.header("🔗 Data Sources")
 
360
  elif input_type == "csv":
361
  csv_file = st.file_uploader("Upload CSV", type=["csv"])
362
  if csv_file:
363
+ st.session_state.inputs.extend(gen.input_handlers["csv"](csv_file))
364
 
365
  elif input_type == "api":
366
  api_url = st.text_input("API Endpoint")
367
+ api_headers = st.text_area("API Headers (JSON format, optional)", height=API_HEADERS_HEIGHT)
368
  headers = {}
369
  try:
370
  if api_headers:
 
379
  db_query = st.text_area("Database Query")
380
  db_table = st.text_input("Table Name (optional)")
381
  if st.button("Add DB Input"):
382
+ st.session_state.inputs.extend(
383
+ gen.input_handlers["db"]({"connection": db_connection, "query": db_query, "table": db_table})
384
+ )
385
 
386
  return api_key
387
 
 
 
 
388
 
389
+ def main_display(gen: SyntheticDataGenerator):
390
+ """Creates the main display area in the Streamlit UI."""
 
391
  st.title("🚀 Enterprise Synthetic Data Factory")
392
 
 
393
  col1, col2 = st.columns([3, 1])
394
  with col1:
395
  pdf_file = st.file_uploader("Upload Document", type=["pdf"])
396
  if pdf_file:
397
+ st.session_state.inputs.extend(gen.input_handlers["pdf"](pdf_file))
398
 
 
399
  with col2:
400
  if st.button("Start Generation"):
401
  with st.status("Processing..."):
402
+ if not st.session_state["api_key"]:
403
+ st.error("Please provide an API Key.")
404
  else:
405
+ gen.generate(st.session_state["api_key"])
406
 
 
407
  if st.session_state.qa_data:
408
  st.header("Generated Data")
409
  df = pd.DataFrame(st.session_state.qa_data)
410
  st.dataframe(df)
411
 
412
+ st.download_button("Export CSV", df.to_csv(index=False), "synthetic_data.csv")
413
+
 
 
 
 
414
 
415
  def main():
416
  """Main function to run the Streamlit application."""
 
418
  api_key = input_sidebar(gen)
419
  main_display(gen)
420
 
421
+
422
  if __name__ == "__main__":
423
  main()