mgbam committed on
Commit
8cd330b
·
verified ·
1 Parent(s): 5f0d3d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -100
app.py CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
5
  import sqlalchemy
6
  from typing import Any, Dict, List, Optional
7
  from functools import lru_cache
8
- import os # Import the 'os' module
9
 
10
  # Provider clients with import guards
11
  try:
@@ -20,17 +20,18 @@ except ImportError:
20
 
21
  try:
22
  import google.generativeai as genai
23
- from google.generativeai import GenerativeModel, configure
24
  except ImportError:
25
  GenerativeModel = None
26
  configure = None
27
- genai = None #Also set this to none
 
28
 
29
- import json # Ensure json is explicitly imported for enhanced use
30
 
31
  class SyntheticDataGenerator:
32
  """World's Most Advanced Synthetic Data Generation System"""
33
-
34
  PROVIDER_CONFIG = {
35
  "Deepseek": {
36
  "base_url": "https://api.deepseek.com/v1",
@@ -53,7 +54,7 @@ class SyntheticDataGenerator:
53
  "requires_library": None
54
  },
55
  "Google": {
56
- "models": ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"], # Include Gemini 2.0 Flash
57
  "requires_library": "google.generativeai"
58
  }
59
  }
@@ -76,14 +77,15 @@ class SyntheticDataGenerator:
76
  "error_count": 0
77
  },
78
  "debug_mode": False,
79
- "google_configured": False, # Track if Google API is configured
80
- "advanced_options": { # Store advanced generation options
81
- "temperature": 0.7, # Default temperature
82
- "top_p": 0.95, # Default top_p
83
- "top_k": 40, # Default top_k
84
- "max_output_tokens": 2000 # Default max_output_tokens
85
  },
86
- "generation_format": "json" # Default output format (json or text)
 
87
  }
88
  for key, val in defaults.items():
89
  if key not in st.session_state:
@@ -94,7 +96,7 @@ class SyntheticDataGenerator:
94
  self.available_providers = []
95
  for provider, config in self.PROVIDER_CONFIG.items():
96
  if config["requires_library"] and not globals().get(config["requires_library"].split('.')[0].title()):
97
- continue # Skip providers with missing dependencies
98
  self.available_providers.append(provider)
99
 
100
  def _setup_input_handlers(self):
@@ -106,12 +108,12 @@ class SyntheticDataGenerator:
106
  "api": self._process_api,
107
  "database": self._process_database,
108
  "web": self._process_web,
109
- "image": self._process_image #Add Image
110
  }
111
 
112
  # --- Core Generation Engine ---
113
  @lru_cache(maxsize=100)
114
- def generate(self, provider: str, model: str, prompt: str) -> Dict[str, Any]:
115
  """Unified generation endpoint with failover support"""
116
  try:
117
  if provider not in self.available_providers:
@@ -132,7 +134,7 @@ class SyntheticDataGenerator:
132
  config = self.PROVIDER_CONFIG[provider]
133
  api_key = st.session_state.api_keys.get(provider, "")
134
 
135
- if not api_key and provider != "Google": #Google API key is configured by configure()
136
  raise ValueError("API key required")
137
 
138
  try:
@@ -142,16 +144,18 @@ class SyntheticDataGenerator:
142
  return {"headers": {"Authorization": f"Bearer {api_key}"}}
143
  elif provider == "Google":
144
  if not st.session_state.google_configured:
145
- # Check if the API key is set as an environment variable
146
- if "GOOGLE_API_KEY" in os.environ:
147
- api_key = os.environ["GOOGLE_API_KEY"]
148
- else:
149
- # Use the API key from session state if available
150
- api_key = st.session_state.api_keys.get("Google", "")
151
- if not api_key:
152
- raise ValueError("Google API key is required. Please set it in the app or as the GOOGLE_API_KEY environment variable.")
153
- configure(api_key=api_key) #Configure the Google API key. Only do once
154
- st.session_state.google_configured = True
 
 
155
 
156
  generation_config = genai.GenerationConfig(
157
  temperature=st.session_state.advanced_options["temperature"],
@@ -159,7 +163,25 @@ class SyntheticDataGenerator:
159
  top_k=st.session_state.advanced_options["top_k"],
160
  max_output_tokens=st.session_state.advanced_options["max_output_tokens"]
161
  )
162
- return GenerativeModel(model_name=model, generation_config=generation_config) # Create the GenerativeModel with generation config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  else:
164
  return OpenAI(
165
  base_url=config["base_url"],
@@ -170,7 +192,8 @@ class SyntheticDataGenerator:
170
  self._log_error(f"Client Init Failed: {str(e)}")
171
  return None
172
 
173
- def _execute_generation(self, client, provider: str, model: str, prompt: str) -> Dict[str, Any]:
 
174
  """Execute provider-specific generation with circuit breaker"""
175
  st.session_state.system_metrics["api_calls"] += 1
176
 
@@ -185,16 +208,25 @@ class SyntheticDataGenerator:
185
  return response.json()
186
  elif provider == "Google":
187
  try:
188
- response = client.generate_content(prompt)
 
 
 
 
 
 
 
 
189
  content = response.text
190
 
191
- if st.session_state.generation_format == "json": # Check requested format
192
  try:
193
- return json.loads(content) # Attempt to parse as JSON
194
  except json.JSONDecodeError:
195
- return {"content": content, "warning": "Could not parse response as valid JSON. Returning raw text."} #Return raw content with warning
 
196
  else:
197
- return {"content": content} # Return raw content
198
 
199
  except Exception as e:
200
  self._log_error(f"Google Generation Error: {str(e)}")
@@ -203,21 +235,22 @@ class SyntheticDataGenerator:
203
  completion = client.chat.completions.create(
204
  model=model,
205
  messages=[{"role": "user", "content": prompt}],
206
- temperature=st.session_state.advanced_options["temperature"], #Use temp from session
207
  max_tokens=st.session_state.advanced_options["max_output_tokens"]
208
  )
209
  st.session_state.system_metrics["tokens_used"] += completion.usage.total_tokens
210
  try:
211
  return json.loads(completion.choices[0].message.content)
212
  except json.JSONDecodeError:
213
- return {"content": completion.choices[0].message.content, "warning": "Could not parse response as valid JSON. Returning raw text."}
 
214
 
215
  def _failover_generation(self, prompt: str) -> Dict[str, Any]:
216
  """Enterprise failover to secondary providers"""
217
  for backup_provider in self.available_providers:
218
  if backup_provider != st.session_state.active_provider:
219
  try:
220
- return self.generate(backup_provider, ..., prompt=prompt) # Corrected: include prompt
221
  except Exception:
222
  continue
223
  raise RuntimeError("All generation providers unavailable")
@@ -244,26 +277,24 @@ class SyntheticDataGenerator:
244
  return ""
245
 
246
  def _process_csv(self, file) -> str:
247
- """Process CSV files and return as a string representation."""
248
- try:
249
- df = pd.read_csv(file)
250
-
251
- # Attempt to infer a schema for the synthetic data generation
252
- column_names = df.columns.tolist()
253
- data_types = [str(df[col].dtype) for col in df.columns]
254
- schema_prompt = f"Column Names: {column_names}\nData Types: {data_types}"
255
- st.session_state.csv_schema = schema_prompt # Store the schema
256
-
257
- return df.to_string() # Convert DataFrame to string
258
- except Exception as e:
259
- self._log_error(f"CSV Processing Error: {str(e)}")
260
- return ""
261
 
262
  def _process_text(self, text: str) -> str:
263
  """Simple text passthrough processor"""
264
  return text
265
 
266
- def _process_api(self, url: str, method="GET", headers: Optional[Dict[str, str]] = None, data: Optional[Dict[str, Any]] = None) -> str:
 
267
  """Generic API endpoint processor with configurable methods and headers."""
268
  try:
269
  if method.upper() == "GET":
@@ -272,12 +303,12 @@ class SyntheticDataGenerator:
272
  response = requests.post(url, headers=headers or {}, json=data, timeout=10)
273
  else:
274
  raise ValueError("Unsupported HTTP method.")
275
- response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
276
 
277
  try:
278
- return json.dumps(response.json(), indent=2) # Pretty print JSON if possible
279
  except json.JSONDecodeError:
280
- return response.text # Otherwise, return raw text
281
  except requests.exceptions.RequestException as e:
282
  self._log_error(f"API Processing Error: {str(e)}")
283
  return ""
@@ -294,18 +325,16 @@ class SyntheticDataGenerator:
294
  self._log_error(f"Database Processing Error: {str(e)}")
295
  return ""
296
 
297
- def _process_image(self, image_file) -> str:
298
- """Processes image files for multimodal generation"""
299
  try:
300
- # For Google's Gemini, you need to prepare the image in a specific format
301
  image_data = image_file.read()
302
- image_part = {"mime_type": image_file.type, "data": image_data}
303
- st.session_state.image_part = image_part #Store image part
304
- return "Image uploaded. Include instructions for processing the image in your prompt." # Basic instruction to the LLM
305
 
306
  except Exception as e:
307
  self._log_error(f"Image Processing Error: {str(e)}")
308
- return ""
309
 
310
  # --- Enterprise Features ---
311
  def _log_error(self, message: str) -> None:
@@ -340,21 +369,24 @@ class SyntheticDataGenerator:
340
  return response.status_code == 200
341
  elif provider == "Google":
342
  try:
343
- #Need to initialize before listing models
344
- if not st.session_state.google_configured:
345
- api_key = st.session_state.api_keys.get("Google", "")
346
- if not api_key:
347
- api_key = os.environ.get("GOOGLE_API_KEY") #Check env variables
 
 
348
  if not api_key:
349
- return False
350
 
351
  configure(api_key=api_key) #Configure API Key
352
  st.session_state.google_configured = True
 
353
 
354
- genai.GenerativeModel(model_name=self.PROVIDER_CONFIG["Google"]["models"][0]).generate_content("test") #Send a test query
355
- return True #Connected if made it this far
356
 
357
- except Exception as e:
358
  print(e)
359
  return False
360
 
@@ -395,30 +427,46 @@ def provider_config_ui(gen: SyntheticDataGenerator):
395
  )
396
  st.session_state.active_model = model
397
 
398
- # Advanced Options (for providers that support it)
399
- if provider == "Google" or provider == "OpenAI": #Only add if OpenAI
400
  st.subheader("Advanced Generation Options")
401
- st.session_state.advanced_options["temperature"] = st.slider("Temperature", min_value=0.0, max_value=1.0, value=st.session_state.advanced_options["temperature"], step=0.05, help="Controls randomness. Lower values = more deterministic.")
402
-
 
 
 
403
  if provider == "Google":
404
- st.session_state.advanced_options["top_p"] = st.slider("Top P", min_value=0.0, max_value=1.0, value=st.session_state.advanced_options["top_p"], step=0.05, help="Nucleus sampling: Considers the most probable tokens.")
405
- st.session_state.advanced_options["top_k"] = st.slider("Top K", min_value=1, max_value=100, value=st.session_state.advanced_options["top_k"], step=1, help="Considers the top K most probable tokens.")
406
-
407
- st.session_state.advanced_options["max_output_tokens"] = st.number_input("Max Output Tokens", min_value=50, max_value=4096, value=st.session_state.advanced_options["max_output_tokens"], step=50, help="Maximum number of tokens in the generated output.")
408
-
409
- # Output format
410
- st.session_state.generation_format = st.selectbox("Output Format", ["json", "text"], help="Choose the desired output format.")
 
 
 
 
 
 
 
 
 
 
411
 
412
  # System monitoring
413
  if st.button("Run Health Check"):
414
  report = gen.health_check()
415
  st.json(report)
416
 
 
417
  def input_ui():
418
  """Creates the input method UI"""
419
- input_method = st.selectbox("Input Method", ["Text", "PDF", "Web URL", "CSV", "Image", "Structured Prompt (Advanced)"]) #Add Image input, Add Structured Prompt (Advanced)
 
 
420
  input_content = None
421
- additional_instructions = "" #For structured prompt
422
 
423
  if input_method == "Text":
424
  input_content = st.text_area("Enter Text", height=200)
@@ -435,7 +483,7 @@ def input_ui():
435
  input_content = uploaded_file
436
  if "csv_schema" in st.session_state:
437
  st.write("Inferred CSV Schema:")
438
- st.write(st.session_state.csv_schema) #Display inferred schema
439
 
440
  elif input_method == "Image":
441
  uploaded_file = st.file_uploader("Upload an Image file", type=["png", "jpg", "jpeg"])
@@ -445,9 +493,11 @@ def input_ui():
445
  elif input_method == "Structured Prompt (Advanced)":
446
  st.subheader("Structured Prompt")
447
  input_content = st.text_area("Enter the base prompt/instructions", height=100)
448
- additional_instructions = st.text_area("Specify constraints, data format, or other requirements:", height=100)
 
 
 
449
 
450
- return input_method, input_content, additional_instructions #Also return additional instructions
451
 
452
  def main():
453
  """Enterprise-grade user interface"""
@@ -467,10 +517,10 @@ def main():
467
 
468
  provider_config_ui(gen)
469
 
470
- input_method, input_content, additional_instructions = input_ui() #Get additonal instructions
471
 
472
  if st.button("Generate Data"):
473
- if input_content or input_method == "Structured Prompt (Advanced)": #Allow generation with *just* structured prompt
474
  processed_input = None
475
 
476
  if input_method == "Text":
@@ -482,23 +532,24 @@ def main():
482
  elif input_method == "CSV":
483
  processed_input = gen._process_csv(input_content)
484
  elif input_method == "Image":
485
- processed_input = gen._process_image(input_content)
 
 
 
 
486
  elif input_method == "Structured Prompt (Advanced)":
487
- processed_input = input_content + "\n" + additional_instructions #Combine instructions and constraints
488
- #st.write("Combined Prompt:")
489
- #st.write(processed_input) #Debug
490
 
491
  if processed_input:
492
  try:
493
- #Handle Google image case - requires a list of content. Other providers just use the text
494
  if st.session_state.active_provider == "Google" and input_method == "Image":
495
- prompt_parts = [processed_input, st.session_state.image_part] # Image part already stored
496
- result = gen.generate(st.session_state.active_provider, st.session_state.active_model, prompt_parts) # Process Google Images
497
  else:
498
- result = gen.generate(st.session_state.active_provider, st.session_state.active_model, processed_input) # Generic text case
499
 
500
  st.subheader("Generated Output:")
501
- st.json(result) # Display the JSON output
502
  except Exception as e:
503
  st.error(f"Error during generation: {e}")
504
  else:
@@ -506,7 +557,5 @@ def main():
506
  else:
507
  st.warning("Please provide input data.")
508
 
509
- # Input management and generation UI components...
510
-
511
  if __name__ == "__main__":
512
  main()
 
5
  import sqlalchemy
6
  from typing import Any, Dict, List, Optional
7
  from functools import lru_cache
8
+ import os
9
 
10
  # Provider clients with import guards
11
  try:
 
20
 
21
  try:
22
  import google.generativeai as genai
23
+ from google.generativeai import GenerativeModel, configure, Part
24
  except ImportError:
25
  GenerativeModel = None
26
  configure = None
27
+ genai = None
28
+ Part = None
29
 
30
+ import json
31
 
32
  class SyntheticDataGenerator:
33
  """World's Most Advanced Synthetic Data Generation System"""
34
+
35
  PROVIDER_CONFIG = {
36
  "Deepseek": {
37
  "base_url": "https://api.deepseek.com/v1",
 
54
  "requires_library": None
55
  },
56
  "Google": {
57
+ "models": ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest", "gemini-pro", "gemini-pro-vision"],
58
  "requires_library": "google.generativeai"
59
  }
60
  }
 
77
  "error_count": 0
78
  },
79
  "debug_mode": False,
80
+ "google_configured": False,
81
+ "advanced_options": {
82
+ "temperature": 0.7,
83
+ "top_p": 0.95,
84
+ "top_k": 40,
85
+ "max_output_tokens": 2000
86
  },
87
+ "generation_format": "json",
88
+ "csv_schema": ""
89
  }
90
  for key, val in defaults.items():
91
  if key not in st.session_state:
 
96
  self.available_providers = []
97
  for provider, config in self.PROVIDER_CONFIG.items():
98
  if config["requires_library"] and not globals().get(config["requires_library"].split('.')[0].title()):
99
+ continue
100
  self.available_providers.append(provider)
101
 
102
  def _setup_input_handlers(self):
 
108
  "api": self._process_api,
109
  "database": self._process_database,
110
  "web": self._process_web,
111
+ "image": self._process_image
112
  }
113
 
114
  # --- Core Generation Engine ---
115
  @lru_cache(maxsize=100)
116
+ def generate(self, provider: str, model: str, prompt: Any) -> Dict[str, Any]: # Allow "prompt" to be a list or a string
117
  """Unified generation endpoint with failover support"""
118
  try:
119
  if provider not in self.available_providers:
 
134
  config = self.PROVIDER_CONFIG[provider]
135
  api_key = st.session_state.api_keys.get(provider, "")
136
 
137
+ if not api_key and provider != "Google":
138
  raise ValueError("API key required")
139
 
140
  try:
 
144
  return {"headers": {"Authorization": f"Bearer {api_key}"}}
145
  elif provider == "Google":
146
  if not st.session_state.google_configured:
147
+ if "GOOGLE_API_KEY" in os.environ:
148
+ api_key = os.environ["GOOGLE_API_KEY"]
149
+ else:
150
+ api_key = st.session_state.api_keys.get("Google", "")
151
+ if not api_key:
152
+ raise ValueError(
153
+ "Google API key is required. Please set it in the app or as the GOOGLE_API_KEY environment variable.")
154
+ try:
155
+ configure(api_key=api_key) # Moved configure into try block
156
+ st.session_state.google_configured = True
157
+ except Exception as e:
158
+ raise ValueError(f"Error configuring Google API: {e}")
159
 
160
  generation_config = genai.GenerationConfig(
161
  temperature=st.session_state.advanced_options["temperature"],
 
163
  top_k=st.session_state.advanced_options["top_k"],
164
  max_output_tokens=st.session_state.advanced_options["max_output_tokens"]
165
  )
166
+ safety_settings = [
167
+ {
168
+ "category": "HARM_CATEGORY_HARASSMENT",
169
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
170
+ },
171
+ {
172
+ "category": "HARM_CATEGORY_HATE_SPEECH",
173
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
174
+ },
175
+ {
176
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
177
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
178
+ },
179
+ {
180
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
181
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
182
+ },
183
+ ]
184
+ return GenerativeModel(model_name=model, generation_config=generation_config, safety_settings=safety_settings)
185
  else:
186
  return OpenAI(
187
  base_url=config["base_url"],
 
192
  self._log_error(f"Client Init Failed: {str(e)}")
193
  return None
194
 
195
+ def _execute_generation(self, client, provider: str, model: str, prompt: Any) -> Dict[str, Any]: # Use Any for prompt type
196
+
197
  """Execute provider-specific generation with circuit breaker"""
198
  st.session_state.system_metrics["api_calls"] += 1
199
 
 
208
  return response.json()
209
  elif provider == "Google":
210
  try:
211
+ if isinstance(prompt, list): #Multimodal case
212
+
213
+ response = client.generate_content(prompt)
214
+
215
+ else:
216
+
217
+ response = client.generate_content(prompt)
218
+
219
+
220
  content = response.text
221
 
222
+ if st.session_state.generation_format == "json":
223
  try:
224
+ return json.loads(content)
225
  except json.JSONDecodeError:
226
+ return {"content": content,
227
+ "warning": "Could not parse response as valid JSON. Returning raw text."}
228
  else:
229
+ return {"content": content}
230
 
231
  except Exception as e:
232
  self._log_error(f"Google Generation Error: {str(e)}")
 
235
  completion = client.chat.completions.create(
236
  model=model,
237
  messages=[{"role": "user", "content": prompt}],
238
+ temperature=st.session_state.advanced_options["temperature"],
239
  max_tokens=st.session_state.advanced_options["max_output_tokens"]
240
  )
241
  st.session_state.system_metrics["tokens_used"] += completion.usage.total_tokens
242
  try:
243
  return json.loads(completion.choices[0].message.content)
244
  except json.JSONDecodeError:
245
+ return {"content": completion.choices[0].message.content,
246
+ "warning": "Could not parse response as valid JSON. Returning raw text."}
247
 
248
  def _failover_generation(self, prompt: str) -> Dict[str, Any]:
249
  """Enterprise failover to secondary providers"""
250
  for backup_provider in self.available_providers:
251
  if backup_provider != st.session_state.active_provider:
252
  try:
253
+ return self.generate(backup_provider, ..., prompt=prompt)
254
  except Exception:
255
  continue
256
  raise RuntimeError("All generation providers unavailable")
 
277
  return ""
278
 
279
  def _process_csv(self, file) -> str:
280
+ """Process CSV files and return as a string representation."""
281
+ try:
282
+ df = pd.read_csv(file)
283
+ column_names = df.columns.tolist()
284
+ data_types = [str(df[col].dtype) for col in df.columns]
285
+ schema_prompt = f"Column Names: {column_names}\nData Types: {data_types}"
286
+ st.session_state.csv_schema = schema_prompt
287
+ return df.to_string()
288
+ except Exception as e:
289
+ self._log_error(f"CSV Processing Error: {str(e)}")
290
+ return ""
 
 
 
291
 
292
  def _process_text(self, text: str) -> str:
293
  """Simple text passthrough processor"""
294
  return text
295
 
296
+ def _process_api(self, url: str, method="GET", headers: Optional[Dict[str, str]] = None,
297
+ data: Optional[Dict[str, Any]] = None) -> str:
298
  """Generic API endpoint processor with configurable methods and headers."""
299
  try:
300
  if method.upper() == "GET":
 
303
  response = requests.post(url, headers=headers or {}, json=data, timeout=10)
304
  else:
305
  raise ValueError("Unsupported HTTP method.")
306
+ response.raise_for_status()
307
 
308
  try:
309
+ return json.dumps(response.json(), indent=2)
310
  except json.JSONDecodeError:
311
+ return response.text
312
  except requests.exceptions.RequestException as e:
313
  self._log_error(f"API Processing Error: {str(e)}")
314
  return ""
 
325
  self._log_error(f"Database Processing Error: {str(e)}")
326
  return ""
327
 
328
+ def _process_image(self, image_file) -> list: #Returns a list
329
+ """Processes image files for multimodal generation (Google Gemini)"""
330
  try:
 
331
  image_data = image_file.read()
332
+ image_part = Part.from_data(image_data, mime_type=image_file.type) #Use Part for google
333
+ return [image_part] #Return a list with the image part as a Google Part object
 
334
 
335
  except Exception as e:
336
  self._log_error(f"Image Processing Error: {str(e)}")
337
+ return []
338
 
339
  # --- Enterprise Features ---
340
  def _log_error(self, message: str) -> None:
 
369
  return response.status_code == 200
370
  elif provider == "Google":
371
  try:
372
+ if not st.session_state.google_configured: #Check if google has been configured
373
+
374
+ api_key = st.session_state.api_keys.get("Google", "") #Get Key from session state
375
+
376
+ if not api_key: #If that is not set, check environment variable.
377
+ api_key = os.environ.get("GOOGLE_API_KEY")
378
+
379
  if not api_key:
380
+ return False #Cant test API if no API Key
381
 
382
  configure(api_key=api_key) #Configure API Key
383
  st.session_state.google_configured = True
384
+ #st.write("configuring key")
385
 
386
+ genai.GenerativeModel(model_name=self.PROVIDER_CONFIG["Google"]["models"][0]).generate_content("test") #Test a generation
387
+ return True
388
 
389
+ except Exception as e: #Catch any exceptions
390
  print(e)
391
  return False
392
 
 
427
  )
428
  st.session_state.active_model = model
429
 
430
+ # Advanced Options
431
+ if provider == "Google" or provider == "OpenAI":
432
  st.subheader("Advanced Generation Options")
433
+ st.session_state.advanced_options["temperature"] = st.slider("Temperature", min_value=0.0,
434
+ max_value=1.0,
435
+ value=st.session_state.advanced_options[
436
+ "temperature"], step=0.05,
437
+ help="Controls randomness. Lower values = more deterministic.")
438
  if provider == "Google":
439
+ st.session_state.advanced_options["top_p"] = st.slider("Top P", min_value=0.0, max_value=1.0,
440
+ value=st.session_state.advanced_options["top_p"],
441
+ step=0.05,
442
+ help="Nucleus sampling: Considers the most probable tokens.")
443
+ st.session_state.advanced_options["top_k"] = st.slider("Top K", min_value=1, max_value=100,
444
+ value=st.session_state.advanced_options["top_k"],
445
+ step=1,
446
+ help="Considers the top K most probable tokens.")
447
+
448
+ st.session_state.advanced_options["max_output_tokens"] = st.number_input("Max Output Tokens",
449
+ min_value=50, max_value=4096,
450
+ value=st.session_state.advanced_options[
451
+ "max_output_tokens"], step=50,
452
+ help="Maximum number of tokens in the generated output.")
453
+
454
+ st.session_state.generation_format = st.selectbox("Output Format", ["json", "text"],
455
+ help="Choose the desired output format.")
456
 
457
  # System monitoring
458
  if st.button("Run Health Check"):
459
  report = gen.health_check()
460
  st.json(report)
461
 
462
+
463
  def input_ui():
464
  """Creates the input method UI"""
465
+ input_method = st.selectbox("Input Method",
466
+ ["Text", "PDF", "Web URL", "CSV", "Image",
467
+ "Structured Prompt (Advanced)"]) # Add Image input, Add Structured Prompt (Advanced)
468
  input_content = None
469
+ additional_instructions = "" # For structured prompt
470
 
471
  if input_method == "Text":
472
  input_content = st.text_area("Enter Text", height=200)
 
483
  input_content = uploaded_file
484
  if "csv_schema" in st.session_state:
485
  st.write("Inferred CSV Schema:")
486
+ st.write(st.session_state.csv_schema)
487
 
488
  elif input_method == "Image":
489
  uploaded_file = st.file_uploader("Upload an Image file", type=["png", "jpg", "jpeg"])
 
493
  elif input_method == "Structured Prompt (Advanced)":
494
  st.subheader("Structured Prompt")
495
  input_content = st.text_area("Enter the base prompt/instructions", height=100)
496
+ additional_instructions = st.text_area("Specify constraints, data format, or other requirements:",
497
+ height=100)
498
+
499
+ return input_method, input_content, additional_instructions
500
 
 
501
 
502
  def main():
503
  """Enterprise-grade user interface"""
 
517
 
518
  provider_config_ui(gen)
519
 
520
+ input_method, input_content, additional_instructions = input_ui()
521
 
522
  if st.button("Generate Data"):
523
+ if input_content or input_method == "Structured Prompt (Advanced)":
524
  processed_input = None
525
 
526
  if input_method == "Text":
 
532
  elif input_method == "CSV":
533
  processed_input = gen._process_csv(input_content)
534
  elif input_method == "Image":
535
+ processed_input = gen._process_image(input_content) #This is a list now
536
+ if not processed_input: #If something went wrong with image processing, don't proceed
537
+ st.error("Error processing image.")
538
+ return
539
+
540
  elif input_method == "Structured Prompt (Advanced)":
541
+ processed_input = input_content + "\n" + additional_instructions
 
 
542
 
543
  if processed_input:
544
  try:
 
545
  if st.session_state.active_provider == "Google" and input_method == "Image":
546
+ prompt_parts = [input_content] + processed_input #Keeps text and images separate for google
547
+ result = gen.generate(st.session_state.active_provider, st.session_state.active_model, prompt_parts)
548
  else:
549
+ result = gen.generate(st.session_state.active_provider, st.session_state.active_model, processed_input)
550
 
551
  st.subheader("Generated Output:")
552
+ st.json(result)
553
  except Exception as e:
554
  st.error(f"Error during generation: {e}")
555
  else:
 
557
  else:
558
  st.warning("Please provide input data.")
559
 
 
 
560
  if __name__ == "__main__":
561
  main()