mgbam committed on
Commit
fc2842c
·
verified ·
1 Parent(s): ad9d3f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -327
app.py CHANGED
@@ -4,446 +4,387 @@ import streamlit as st
4
  import pdfplumber
5
  import pandas as pd
6
  import sqlalchemy
7
- from PIL import Image
 
8
  from typing import Any, Dict, List
9
 
10
- # Provider clients
11
- from openai import OpenAI
12
- import google.generativeai as genai
13
- import groq
 
14
 
15
- # --- CONSTANTS ---
 
 
 
 
 
16
  HF_API_URL = "https://api-inference.huggingface.co/models/"
17
  DEFAULT_TEMPERATURE = 0.1
18
- GROQ_MODEL = "mixtral-8x7b-32768" # Groq model
19
- API_HEADERS_HEIGHT = 70 # Height for the API headers text area
20
 
21
 
22
- class SyntheticDataGenerator:
23
  """
24
- Generates synthetic Q&A data from various input sources using multiple LLM providers.
 
 
 
25
  """
26
  def __init__(self) -> None:
27
  self._setup_providers()
28
  self._setup_input_handlers()
29
  self._initialize_session_state()
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def _setup_providers(self) -> None:
32
- """Configure available LLM providers and their client initializations."""
33
  self.providers: Dict[str, Dict[str, Any]] = {
34
  "Deepseek": {
35
- "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
36
  "models": ["deepseek-chat"],
37
  },
38
  "OpenAI": {
39
- "client": lambda key: OpenAI(api_key=key),
40
- "models": ["gpt-4-turbo"],
41
  },
42
  "Groq": {
43
- "client": lambda key: groq.Groq(api_key=key),
44
  "models": [GROQ_MODEL],
45
  },
46
  "HuggingFace": {
47
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
48
  "models": ["gpt2", "llama-2"],
49
  },
50
- "Google": {
51
- "client": lambda key: self._configure_google_genai(key),
52
- "models": ["gemini-pro"],
53
- },
54
  }
55
 
56
  def _setup_input_handlers(self) -> None:
57
- """Define handlers for different input data types."""
58
  self.input_handlers: Dict[str, Any] = {
59
- "pdf": self.handle_pdf,
60
  "text": self.handle_text,
 
61
  "csv": self.handle_csv,
62
  "api": self.handle_api,
63
  "db": self.handle_db,
64
  }
65
 
66
  def _initialize_session_state(self) -> None:
67
- """Initialize Streamlit session state with default configurations."""
68
- session_defaults = {
69
- "inputs": [],
70
- "qa_data": [],
71
- "processing": {"stage": "idle", "progress": 0, "errors": []},
72
  "config": {
73
- "provider": "Groq",
74
- "model": GROQ_MODEL,
75
  "temperature": DEFAULT_TEMPERATURE,
 
76
  },
77
- "api_key": "", # Explicitly initialize the API key
 
 
 
 
78
  }
79
- for key, value in session_defaults.items():
80
  if key not in st.session_state:
81
  st.session_state[key] = value
82
 
83
- def _configure_google_genai(self, api_key: str) -> Any:
84
- """Configure and return the Google Generative AI client."""
85
- try:
86
- genai.configure(api_key=api_key)
87
- return genai.GenerativeModel
88
- except Exception as e:
89
- st.error(f"Error configuring Google GenAI: {e}")
90
- return None
91
 
92
- # --- INPUT HANDLERS ---
93
- def handle_pdf(self, file) -> List[Dict[str, Any]]:
94
- """
95
- Extract text and images from a PDF file.
96
 
97
- Returns:
98
- A list of dictionaries containing text, images, and metadata.
99
- """
100
  try:
101
  with pdfplumber.open(file) as pdf:
102
- extracted_data = []
103
- for i, page in enumerate(pdf.pages):
104
  page_text = page.extract_text() or ""
105
- page_images = self.process_images(page)
106
- extracted_data.append({
107
- "text": page_text,
108
- "images": page_images,
109
- "meta": {"type": "pdf", "page": i + 1},
110
- })
111
- return extracted_data
112
  except Exception as e:
113
- self._log_error(f"PDF Error: {e}")
114
- return []
115
 
116
- def handle_text(self, text: str) -> List[Dict[str, Any]]:
117
- """Handle manual text input."""
118
- return [{"text": text, "meta": {"type": "domain", "source": "manual"}}]
119
-
120
- def handle_csv(self, file) -> List[Dict[str, Any]]:
121
- """Process a CSV file and format the data for Q&A generation."""
122
  try:
123
  df = pd.read_csv(file)
124
- return [
125
- {
126
- "text": "\n".join([f"{col}: {row[col]}" for col in df.columns]),
127
- "meta": {"type": "csv", "columns": list(df.columns)},
128
- }
129
- for _, row in df.iterrows()
130
- ]
131
  except Exception as e:
132
- self._log_error(f"CSV Error: {e}")
133
- return []
134
 
135
- def handle_api(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
136
- """Fetch data from an API endpoint and format it for processing."""
137
  try:
138
- response = requests.get(config["url"], headers=config["headers"], timeout=10)
139
  response.raise_for_status()
140
- return [{
141
- "text": json.dumps(response.json()),
142
- "meta": {"type": "api", "endpoint": config["url"]},
143
- }]
144
- except requests.exceptions.RequestException as e:
145
- self._log_error(f"API Error: {e}")
146
- return []
147
-
148
- def handle_db(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
149
- """Connect to a database, execute a query, and format the results."""
150
  try:
151
  engine = sqlalchemy.create_engine(config["connection"])
152
  with engine.connect() as conn:
153
  result = conn.execute(sqlalchemy.text(config["query"]))
154
- return [
155
- {
156
- "text": "\n".join([f"{col}: {val}" for col, val in row._asdict().items()]),
157
- "meta": {"type": "db", "table": config.get("table", "")},
158
- }
159
- for row in result
160
- ]
161
  except Exception as e:
162
- self._log_error(f"DB Error: {e}")
163
- return []
164
-
165
- def process_images(self, page) -> List[Dict[str, Any]]:
166
- """Extract and process images from a PDF page."""
167
- images = []
168
- for img in page.images:
169
- try:
170
- stream = img["stream"]
171
- width = int(stream.get("Width", 0))
172
- height = int(stream.get("Height", 0))
173
- image_data = stream.get_data()
174
- if width > 0 and height > 0 and image_data:
175
- try:
176
- image = Image.frombytes("RGB", (width, height), image_data)
177
- images.append({"data": image, "meta": {"dims": (width, height)}})
178
- except Exception as e:
179
- self._log_error(f"Image Creation Error: {e} (Width: {width}, Height: {height})")
180
- else:
181
- self._log_error(f"Image Error: Insufficient data or invalid dimensions (w={width}, h={height})")
182
- except Exception as e:
183
- self._log_error(f"Image Extraction Error: {e}")
184
- return images
185
 
186
- # --- LLM INFERENCE ---
187
- def generate(self, api_key: str) -> bool:
188
  """
189
- Generate Q&A pairs using the selected LLM provider.
190
-
191
- Iterates over all the input data, calls the appropriate inference method,
192
- and aggregates the generated Q&A pairs into session state.
193
  """
 
194
  if not api_key:
195
- st.error("API Key cannot be empty.")
196
  return False
197
 
198
- try:
199
- provider_name = st.session_state.config["provider"]
200
- provider_cfg = self.providers[provider_name]
201
- client_initializer = provider_cfg["client"]
202
-
203
- # Initialize the client
204
- if provider_name == "Google":
205
- client = client_initializer(api_key)
206
- if not client:
207
- return False
208
- else:
209
- client = client_initializer(api_key)
210
-
211
- for i, input_data in enumerate(st.session_state.inputs):
212
- st.session_state.processing["progress"] = (i + 1) / len(st.session_state.inputs)
213
- st.write("--- Input Data ---")
214
- st.write(input_data["text"])
215
 
216
- if provider_name == "HuggingFace":
217
- response = self._huggingface_inference(client, input_data)
218
- elif provider_name == "Google":
219
- response = self._google_inference(client, input_data)
220
- else:
221
- response = self._standard_inference(client, input_data)
222
 
223
- if response:
224
- st.write("--- Raw Response ---")
225
- st.write(response)
226
- parsed_response = self._parse_response(response, provider_name)
227
- if parsed_response:
228
- st.session_state.qa_data.extend(parsed_response)
 
229
 
 
 
230
  return True
231
-
232
  except Exception as e:
233
- self._log_error(f"Generation Error: {e}")
234
  return False
235
 
236
- def _standard_inference(self, client: Any, input_data: Dict[str, Any]) -> Any:
237
- """Perform inference using an OpenAI-compatible API."""
 
 
238
  try:
239
- return client.chat.completions.create(
240
- model=st.session_state.config["model"],
241
- messages=[{"role": "user", "content": self._build_prompt(input_data)}],
242
- temperature=st.session_state.config["temperature"],
243
  )
 
244
  except Exception as e:
245
- self._log_error(f"OpenAI Inference Error: {e}")
246
  return None
247
 
248
- def _huggingface_inference(self, client: Dict[str, Any], input_data: Dict[str, Any]) -> Any:
249
- """Perform inference using the Hugging Face Inference API."""
 
 
250
  try:
251
  response = requests.post(
252
- HF_API_URL + st.session_state.config["model"],
253
  headers=client["headers"],
254
- json={"inputs": self._build_prompt(input_data)},
 
255
  )
256
  response.raise_for_status()
257
  return response.json()
258
- except requests.exceptions.RequestException as e:
259
- self._log_error(f"Hugging Face Inference Error: {e}")
260
- return None
261
-
262
- def _google_inference(self, client: Any, input_data: Dict[str, Any]) -> Any:
263
- """Perform inference using the Google Generative AI API."""
264
- try:
265
- model = client(st.session_state.config["model"])
266
- response = model.generate_content(
267
- self._build_prompt(input_data),
268
- generation_config=genai.types.GenerationConfig(
269
- temperature=st.session_state.config["temperature"]
270
- ),
271
- )
272
- return response
273
  except Exception as e:
274
- self._log_error(f"Google GenAI Inference Error: {e}")
275
  return None
276
 
277
- # --- PROMPT ENGINEERING ---
278
- def _build_prompt(self, input_data: Dict[str, Any]) -> str:
279
- """
280
- Build the prompt for the LLM based on the input data.
281
-
282
- The prompt instructs the LLM to extract 3 Q&A pairs in JSON format.
283
- """
284
- base_prompt = (
285
- "You are an expert in extracting question and answer pairs from documents. "
286
- "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries.\n"
287
- "Each dictionary must have the keys 'question' and 'answer'.\n"
288
- "The 'question' should be clear and concise, and the 'answer' should directly answer the question "
289
- "using only information from the provided data. Do not hallucinate or invent information.\n"
290
- "Answer using the exact information from the document, not external knowledge.\n"
291
- "Example JSON Output:\n"
292
- '[{"question": "What is the capital of France?", "answer": "The capital of France is Paris."}, '
293
- '{"question": "What is the highest mountain in the world?", "answer": "The highest mountain in the world is Mount Everest."}, '
294
- '{"question": "What is the chemical symbol for gold?", "answer": "The chemical symbol for gold is Au."}]\n'
295
- "Now, generate 3 Q&A pairs from this data:\n"
296
- )
297
- data_type = input_data["meta"].get("type", "text")
298
- if data_type == "csv":
299
- return base_prompt + "Data:\n" + input_data["text"]
300
- elif data_type == "api":
301
- return base_prompt + "API response:\n" + input_data["text"]
302
- return base_prompt + input_data["text"]
303
-
304
- # --- RESPONSE PARSING ---
305
- def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
306
  """
307
- Parse the LLM response into a list of Q&A pairs.
308
-
309
- Expects the response to be a JSON formatted string.
310
  """
311
  try:
312
- response_text = ""
313
  if provider == "HuggingFace":
314
- response_text = response[0].get("generated_text", "")
315
- elif provider == "Google":
316
- response_text = response.text.strip()
317
- else: # OpenAI, Deepseek, Groq
318
- if not response or not response.choices or not response.choices[0].message.content:
319
- self._log_error("Empty or malformed response from LLM.")
320
- return []
321
- response_text = response.choices[0].message.content
322
-
323
- try:
324
- json_output = json.loads(response_text)
325
- except json.JSONDecodeError as e:
326
- self._log_error(f"JSON Parse Error: {e}. Raw Response: {response_text}")
327
- return []
328
-
329
- if isinstance(json_output, list):
330
- qa_pairs = json_output
331
- elif isinstance(json_output, dict) and "questionList" in json_output:
332
- qa_pairs = json_output["questionList"]
333
  else:
334
- self._log_error(f"Unexpected JSON structure: {response_text}")
335
- return []
336
-
337
- if not isinstance(qa_pairs, list):
338
- self._log_error(f"Expected a list of QA pairs, but got: {type(qa_pairs)}")
339
- return []
340
-
341
- for pair in qa_pairs:
342
- if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
343
- self._log_error(f"Invalid QA pair structure: {pair}")
344
- return []
345
-
346
- return qa_pairs
347
-
348
  except Exception as e:
349
- self._log_error(f"Parse Error: {e}. Raw Response: {response}")
350
- return []
351
 
352
- def _log_error(self, message: str) -> None:
353
- """Log an error message to the session state and display it."""
354
- st.session_state.processing["errors"].append(message)
355
- st.error(message)
356
 
 
357
 
358
- # --- STREAMLIT UI COMPONENTS ---
359
- def input_sidebar(generator: SyntheticDataGenerator) -> str:
360
- """Create the input sidebar in the Streamlit UI."""
361
  with st.sidebar:
362
- st.header("⚙️ Configuration")
363
- provider = st.selectbox("Provider", list(generator.providers.keys()))
364
- st.session_state.config["provider"] = provider # Update provider in session state
365
  provider_cfg = generator.providers[provider]
366
 
367
- api_key = st.text_input(f"{provider} API Key", type="password")
368
- st.session_state["api_key"] = api_key
369
-
370
- model = st.selectbox("Model", provider_cfg["models"])
371
  st.session_state.config["model"] = model
372
 
373
  temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
374
  st.session_state.config["temperature"] = temperature
375
 
376
- # Data Source Input
377
- st.header("🔗 Data Sources")
378
- input_type = st.selectbox("Input Type", list(generator.input_handlers.keys()))
379
 
380
- if input_type == "text":
381
- domain_input = st.text_area("Domain Knowledge", height=150)
382
- if st.button("Add Domain Input"):
383
- st.session_state.inputs.append(generator.input_handlers["text"](domain_input)[0])
384
-
385
- elif input_type == "csv":
386
- csv_file = st.file_uploader("Upload CSV", type=["csv"])
387
- if csv_file:
388
- st.session_state.inputs.extend(generator.input_handlers["csv"](csv_file))
389
-
390
- elif input_type == "api":
391
- api_url = st.text_input("API Endpoint")
392
- api_headers = st.text_area("API Headers (JSON format, optional)", height=API_HEADERS_HEIGHT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  headers = {}
394
- if api_headers:
395
- try:
396
  headers = json.loads(api_headers)
397
- except json.JSONDecodeError:
398
- st.error("Invalid JSON format for API headers.")
399
- if st.button("Add API Input"):
400
- st.session_state.inputs.extend(generator.input_handlers["api"]({"url": api_url, "headers": headers}))
401
-
402
- elif input_type == "db":
403
- db_connection = st.text_input("Database Connection String")
404
- db_query = st.text_area("Database Query")
405
- db_table = st.text_input("Table Name (optional)")
406
- if st.button("Add DB Input"):
407
- st.session_state.inputs.extend(generator.input_handlers["db"]({
408
- "connection": db_connection,
409
- "query": db_query,
410
- "table": db_table
411
- }))
412
-
413
- return api_key
414
-
415
-
416
- def main_display(generator: SyntheticDataGenerator) -> None:
417
- """Create the main display area in the Streamlit UI."""
418
- st.title("🚀 Enterprise Synthetic Data Factory")
419
-
420
- col1, col2 = st.columns([3, 1])
421
- with col1:
422
- pdf_file = st.file_uploader("Upload Document", type=["pdf"])
423
- if pdf_file:
424
- st.session_state.inputs.extend(generator.input_handlers["pdf"](pdf_file))
425
-
426
- with col2:
427
- if st.button("Start Generation"):
428
- with st.spinner("Processing..."):
429
- if not st.session_state["api_key"]:
430
- st.error("Please provide an API Key.")
431
- else:
432
- generator.generate(st.session_state["api_key"])
433
-
434
- if st.session_state.qa_data:
435
- st.header("Generated Data")
436
- df = pd.DataFrame(st.session_state.qa_data)
437
- st.dataframe(df)
438
- st.download_button("Export CSV", df.to_csv(index=False), "synthetic_data.csv")
439
-
440
 
441
  def main() -> None:
442
- """Main function to run the Streamlit application."""
443
- generator = SyntheticDataGenerator()
444
- _ = input_sidebar(generator)
445
- main_display(generator)
446
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  if __name__ == "__main__":
449
  main()
 
4
  import pdfplumber
5
  import pandas as pd
6
  import sqlalchemy
7
+ import time
8
+ import concurrent.futures
9
  from typing import Any, Dict, List
10
 
11
+ # Provider clients (make sure you have these installed)
12
+ try:
13
+ from openai import OpenAI
14
+ except ImportError:
15
+ OpenAI = None
16
 
17
+ try:
18
+ import groq
19
+ except ImportError:
20
+ groq = None
21
+
22
+ # Hugging Face inference URL
23
  HF_API_URL = "https://api-inference.huggingface.co/models/"
24
  DEFAULT_TEMPERATURE = 0.1
25
+ GROQ_MODEL = "mixtral-8x7b-32768"
 
26
 
27
 
28
+ class AdvancedSyntheticDataGenerator:
29
  """
30
+ Advanced Synthetic Data Generator
31
+
32
+ This class handles multiple input sources, advanced prompt engineering, and
33
+ supports multiple LLM providers to generate synthetic data.
34
  """
35
  def __init__(self) -> None:
36
  self._setup_providers()
37
  self._setup_input_handlers()
38
  self._initialize_session_state()
39
+ # A customizable prompt template (you can modify it via the UI)
40
+ self.custom_prompt_template = (
41
+ "You are an expert synthetic data generator. "
42
+ "Given the data below and following the instructions provided, generate high-quality, diverse synthetic data. "
43
+ "Ensure the output adheres to the specified format.\n\n"
44
+ "-------------------------\n"
45
+ "Data:\n{data}\n\n"
46
+ "Instructions:\n{instructions}\n\n"
47
+ "Output Format: {format}\n"
48
+ "-------------------------\n"
49
+ )
50
 
51
  def _setup_providers(self) -> None:
52
+ """Configure available LLM providers and their initialization routines."""
53
  self.providers: Dict[str, Dict[str, Any]] = {
54
  "Deepseek": {
55
+ "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
56
  "models": ["deepseek-chat"],
57
  },
58
  "OpenAI": {
59
+ "client": lambda key: OpenAI(api_key=key) if OpenAI else None,
60
+ "models": ["gpt-4-turbo", "gpt-3.5-turbo"],
61
  },
62
  "Groq": {
63
+ "client": lambda key: groq.Groq(api_key=key) if groq else None,
64
  "models": [GROQ_MODEL],
65
  },
66
  "HuggingFace": {
67
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
68
  "models": ["gpt2", "llama-2"],
69
  },
 
 
 
 
70
  }
71
 
72
  def _setup_input_handlers(self) -> None:
73
+ """Register handlers for different input data types."""
74
  self.input_handlers: Dict[str, Any] = {
 
75
  "text": self.handle_text,
76
+ "pdf": self.handle_pdf,
77
  "csv": self.handle_csv,
78
  "api": self.handle_api,
79
  "db": self.handle_db,
80
  }
81
 
82
  def _initialize_session_state(self) -> None:
83
+ """Initialize Streamlit session state with default configuration."""
84
+ defaults = {
 
 
 
85
  "config": {
86
+ "provider": "OpenAI",
87
+ "model": "gpt-4-turbo",
88
  "temperature": DEFAULT_TEMPERATURE,
89
+ "output_format": "plain_text", # Options: plain_text, json, csv
90
  },
91
+ "api_key": "",
92
+ "inputs": [], # A list to store input sources
93
+ "instructions": "", # Custom instructions for data generation
94
+ "synthetic_data": "", # The generated output
95
+ "error_logs": [], # Any errors that occur during processing
96
  }
97
+ for key, value in defaults.items():
98
  if key not in st.session_state:
99
  st.session_state[key] = value
100
 
101
+ def log_error(self, message: str) -> None:
102
+ """Log an error message both to the session state and in the UI."""
103
+ st.session_state.error_logs.append(message)
104
+ st.error(message)
 
 
 
 
105
 
106
+ # ===== INPUT HANDLERS =====
107
+ def handle_text(self, text: str) -> Dict[str, Any]:
108
+ return {"data": text, "source": "text"}
 
109
 
110
+ def handle_pdf(self, file) -> Dict[str, Any]:
 
 
111
  try:
112
  with pdfplumber.open(file) as pdf:
113
+ full_text = ""
114
+ for page in pdf.pages:
115
  page_text = page.extract_text() or ""
116
+ full_text += page_text + "\n"
117
+ return {"data": full_text, "source": "pdf"}
 
 
 
 
 
118
  except Exception as e:
119
+ self.log_error(f"PDF Processing Error: {e}")
120
+ return {"data": "", "source": "pdf"}
121
 
122
+ def handle_csv(self, file) -> Dict[str, Any]:
 
 
 
 
 
123
  try:
124
  df = pd.read_csv(file)
125
+ # For simplicity, we convert the dataframe to JSON.
126
+ return {"data": df.to_json(orient="records"), "source": "csv"}
 
 
 
 
 
127
  except Exception as e:
128
+ self.log_error(f"CSV Processing Error: {e}")
129
+ return {"data": "", "source": "csv"}
130
 
131
+ def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
 
132
  try:
133
+ response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
134
  response.raise_for_status()
135
+ return {"data": json.dumps(response.json()), "source": "api"}
136
+ except Exception as e:
137
+ self.log_error(f"API Processing Error: {e}")
138
+ return {"data": "", "source": "api"}
139
+
140
+ def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
 
 
 
 
141
  try:
142
  engine = sqlalchemy.create_engine(config["connection"])
143
  with engine.connect() as conn:
144
  result = conn.execute(sqlalchemy.text(config["query"]))
145
+ rows = [dict(row) for row in result]
146
+ return {"data": json.dumps(rows), "source": "db"}
 
 
 
 
 
147
  except Exception as e:
148
+ self.log_error(f"Database Processing Error: {e}")
149
+ return {"data": "", "source": "db"}
150
+
151
+ def aggregate_inputs(self) -> str:
152
+ """Combine all input sources into a single data string."""
153
+ aggregated_data = ""
154
+ for item in st.session_state.inputs:
155
+ aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
156
+ aggregated_data += item.get("data", "") + "\n\n"
157
+ return aggregated_data.strip()
158
+
159
+ def build_prompt(self) -> str:
160
+ """
161
+ Build the complete prompt by combining the aggregated input data with
162
+ custom instructions and the desired output format.
163
+ """
164
+ aggregated_data = self.aggregate_inputs()
165
+ instructions = st.session_state.instructions or "Generate diverse, coherent synthetic data."
166
+ output_format = st.session_state.config.get("output_format", "plain_text")
167
+ return self.custom_prompt_template.format(
168
+ data=aggregated_data, instructions=instructions, format=output_format
169
+ )
 
170
 
171
+ def generate_synthetic_data(self) -> bool:
 
172
  """
173
+ Generate synthetic data by sending the built prompt to the selected LLM provider.
174
+ Returns True if generation succeeds.
 
 
175
  """
176
+ api_key = st.session_state.api_key
177
  if not api_key:
178
+ self.log_error("API key is missing!")
179
  return False
180
 
181
+ provider_name = st.session_state.config["provider"]
182
+ provider_cfg = self.providers.get(provider_name)
183
+ if not provider_cfg:
184
+ self.log_error(f"Provider {provider_name} is not configured.")
185
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
+ client_initializer = provider_cfg["client"]
188
+ client = client_initializer(api_key)
189
+ model = st.session_state.config["model"]
190
+ temperature = st.session_state.config["temperature"]
191
+ prompt = self.build_prompt()
 
192
 
193
+ st.info(f"Using provider {provider_name} with model {model} at temperature {temperature:.2f}")
194
+ # (Optionally) simulate asynchronous processing with a thread pool if needed.
195
+ try:
196
+ if provider_name == "HuggingFace":
197
+ response = self._huggingface_inference(client, prompt, model)
198
+ else:
199
+ response = self._standard_inference(client, prompt, model, temperature)
200
 
201
+ synthetic_data = self._parse_response(response, provider_name)
202
+ st.session_state.synthetic_data = synthetic_data
203
  return True
 
204
  except Exception as e:
205
+ self.log_error(f"Generation failed: {e}")
206
  return False
207
 
208
+ def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
209
+ """
210
+ Inference method for providers using an OpenAI-compatible API.
211
+ """
212
  try:
213
+ result = client.chat.completions.create(
214
+ model=model,
215
+ messages=[{"role": "user", "content": prompt}],
216
+ temperature=temperature,
217
  )
218
+ return result
219
  except Exception as e:
220
+ self.log_error(f"Standard Inference Error: {e}")
221
  return None
222
 
223
+ def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
224
+ """
225
+ Inference method for the Hugging Face Inference API.
226
+ """
227
  try:
228
  response = requests.post(
229
+ HF_API_URL + model,
230
  headers=client["headers"],
231
+ json={"inputs": prompt},
232
+ timeout=30,
233
  )
234
  response.raise_for_status()
235
  return response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  except Exception as e:
237
+ self.log_error(f"HuggingFace Inference Error: {e}")
238
  return None
239
 
240
+ def _parse_response(self, response: Any, provider: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  """
242
+ Parse the LLM response into a synthetic data string.
 
 
243
  """
244
  try:
 
245
  if provider == "HuggingFace":
246
+ if isinstance(response, list) and "generated_text" in response[0]:
247
+ return response[0]["generated_text"]
248
+ else:
249
+ self.log_error("Unexpected HuggingFace response format.")
250
+ return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  else:
252
+ if response and hasattr(response, "choices") and response.choices:
253
+ return response.choices[0].message.content
254
+ else:
255
+ self.log_error("Unexpected response format.")
256
+ return ""
 
 
 
 
 
 
 
 
 
257
  except Exception as e:
258
+ self.log_error(f"Response Parsing Error: {e}")
259
+ return ""
260
 
 
 
 
 
261
 
262
+ # ===== ADVANCED UI COMPONENTS =====
263
 
264
def advanced_config_ui(generator: AdvancedSyntheticDataGenerator):
    """Render the sidebar controls and mirror each choice into session state.

    Writes provider, model, temperature and output format into
    ``st.session_state.config``, and the API key and custom instructions
    into ``st.session_state.api_key`` / ``st.session_state.instructions``.
    """
    state = st.session_state
    with st.sidebar:
        st.header("Advanced Configuration")

        provider = st.selectbox("Select Provider", list(generator.providers.keys()))
        state.config["provider"] = provider

        # The model list depends on the provider picked just above.
        available_models = generator.providers[provider]["models"]
        state.config["model"] = st.selectbox("Select Model", available_models)

        state.config["temperature"] = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)

        state.config["output_format"] = st.radio("Output Format", ["plain_text", "json", "csv"])

        state.api_key = st.text_input(f"{provider} API Key", type="password")

        state.instructions = st.text_area(
            "Custom Instructions",
            "Generate diverse, coherent synthetic data based on the input sources.",
            height=100,
        )
288
+
289
def advanced_input_ui(generator: AdvancedSyntheticDataGenerator):
    """UI for adding input sources using tabs.

    Every source is added only on an explicit button click. Streamlit reruns
    the script on each interaction while an uploaded file stays attached to
    its uploader, so appending directly on "file present" would add the same
    file again on every rerun.

    An item is stored in ``st.session_state.inputs`` only when its handler
    produced data; handler failures are logged by the handler itself.
    """

    def _register(item, label):
        # Keep failed or empty extractions out of the prompt material.
        if item.get("data"):
            st.session_state.inputs.append(item)
            st.success(f"{label} input added!")
        else:
            st.warning(f"{label} input produced no data and was not added.")

    st.header("Input Data Sources")
    tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])

    with tabs[0]:
        text_input = st.text_area("Enter text input", height=150)
        if st.button("Add Text Input", key="text_input"):
            if text_input.strip():
                _register(generator.handle_text(text_input), "Text")
            else:
                st.warning("Please enter some text first.")

    with tabs[1]:
        pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
        if pdf_file is not None and st.button("Add PDF Input", key="pdf_input"):
            # Rewind in case the buffer was already consumed by an earlier add.
            pdf_file.seek(0)
            _register(generator.handle_pdf(pdf_file), "PDF")

    with tabs[2]:
        csv_file = st.file_uploader("Upload CSV", type=["csv"])
        if csv_file is not None and st.button("Add CSV Input", key="csv_input"):
            csv_file.seek(0)
            _register(generator.handle_csv(csv_file), "CSV")

    with tabs[3]:
        api_url = st.text_input("API Endpoint URL")
        api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
        if st.button("Add API Input", key="api_input"):
            headers = {}
            headers_ok = True
            if api_headers.strip():
                try:
                    headers = json.loads(api_headers)
                except ValueError as e:
                    generator.log_error(f"Invalid JSON for API Headers: {e}")
                    headers_ok = False
                else:
                    if not isinstance(headers, dict):
                        generator.log_error("Invalid JSON for API Headers: expected a JSON object.")
                        headers_ok = False

            if not api_url.strip():
                st.warning("Please enter an API endpoint URL.")
            elif headers_ok:
                # Do not send the request with headers the user mistyped.
                _register(generator.handle_api({"url": api_url.strip(), "headers": headers}), "API")

    with tabs[4]:
        db_conn = st.text_input("Database Connection String")
        db_query = st.text_area("Database Query", height=100)
        if st.button("Add Database Input", key="db_input"):
            if db_conn.strip() and db_query.strip():
                _register(generator.handle_db({"connection": db_conn, "query": db_query}), "Database")
            else:
                st.warning("Please enter both a connection string and a query.")
332
+
333
def advanced_output_ui(generator: AdvancedSyntheticDataGenerator):
    """Show the generated output and offer it as a text download.

    When the configured output format is "json" and the output parses as
    JSON, it is rendered as a JSON tree; in every other case it is shown in
    a plain text area.
    """
    st.header("Synthetic Data Output")

    generated = st.session_state.synthetic_data
    if not generated:
        st.info("No synthetic data generated yet.")
        return

    shown_as_json = False
    if st.session_state.config.get("output_format", "plain_text") == "json":
        try:
            st.json(json.loads(generated))
            shown_as_json = True
        except Exception:
            # Not valid JSON (or not renderable): fall through to raw text.
            shown_as_json = False

    if not shown_as_json:
        st.text_area("Output", generated, height=300)

    st.download_button(
        "Download Output",
        generated,
        file_name="synthetic_data.txt",
        mime="text/plain",
    )
350
+
351
def advanced_logs_ui():
    """List every logged error inside a collapsed expander, or a placeholder line when there are none."""
    with st.expander("Error Logs & Debug Info", expanded=False):
        entries = st.session_state.error_logs
        if not entries:
            st.write("No logs yet.")
            return
        for entry in entries:
            st.write(entry)
359
+
360
+
361
+ # ===== MAIN APPLICATION =====
 
 
362
 
363
def main() -> None:
    """Configure the page, build the generator, and lay out the sidebar plus the Input / Output / Logs tabs."""
    st.set_page_config(page_title="Advanced Synthetic Data Generator", layout="wide")
    generator = AdvancedSyntheticDataGenerator()
    advanced_config_ui(generator)

    # One top-level tab per concern.
    input_tab, output_tab, logs_tab = st.tabs(["Input", "Output", "Logs"])

    with input_tab:
        advanced_input_ui(generator)
        if st.button("Clear Inputs"):
            st.session_state.inputs = []
            st.success("Inputs cleared!")

    with output_tab:
        if st.button("Generate Synthetic Data"):
            with st.spinner("Generating synthetic data..."):
                succeeded = generator.generate_synthetic_data()
                if succeeded:
                    st.success("Data generated successfully!")
                else:
                    st.error("Data generation failed. Check logs for details.")
        advanced_output_ui(generator)

    with logs_tab:
        advanced_logs_ui()
388
 
389
  if __name__ == "__main__":
390
  main()