Update app.py
app.py CHANGED
@@ -4,446 +4,387 @@ import streamlit as st
 import pdfplumber
 import pandas as pd
 import sqlalchemy
-…
 from typing import Any, Dict, List

-# Provider clients
-…

 HF_API_URL = "https://api-inference.huggingface.co/models/"
 DEFAULT_TEMPERATURE = 0.1
-GROQ_MODEL = "mixtral-8x7b-32768"
-API_HEADERS_HEIGHT = 70  # Height for the API headers text area


-class …:
     """
-    …
     """
     def __init__(self) -> None:
         self._setup_providers()
         self._setup_input_handlers()
         self._initialize_session_state()

     def _setup_providers(self) -> None:
-        """Configure available LLM providers and their …"""
         self.providers: Dict[str, Dict[str, Any]] = {
             "Deepseek": {
-                "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
                 "models": ["deepseek-chat"],
             },
             "OpenAI": {
-                "client": lambda key: OpenAI(api_key=key),
-                "models": ["gpt-4-turbo"],
             },
             "Groq": {
-                "client": lambda key: groq.Groq(api_key=key),
                 "models": [GROQ_MODEL],
             },
             "HuggingFace": {
                 "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
                 "models": ["gpt2", "llama-2"],
             },
-            "Google": {
-                "client": lambda key: self._configure_google_genai(key),
-                "models": ["gemini-pro"],
-            },
         }

     def _setup_input_handlers(self) -> None:
-        """…"""
         self.input_handlers: Dict[str, Any] = {
-            "pdf": self.handle_pdf,
             "text": self.handle_text,
             "csv": self.handle_csv,
             "api": self.handle_api,
             "db": self.handle_db,
         }

     def _initialize_session_state(self) -> None:
-        """Initialize Streamlit session state with default …"""
-        … = {
-            "inputs": [],
-            "qa_data": [],
-            "processing": {"stage": "idle", "progress": 0, "errors": []},
             "config": {
-                "provider": "…",
-                "model": …,
                 "temperature": DEFAULT_TEMPERATURE,
             },
-            "api_key": "",
         }
-        for key, value in …:
             if key not in st.session_state:
                 st.session_state[key] = value

-    def _configure_google_genai(self, key):
-        """…"""
-        …
-            return genai.GenerativeModel…
-        except Exception as e:
-            st.error(f"Error configuring Google GenAI: {e}")
-            return None

-    # …
-    def handle_pdf(self, file) -> List[Dict[str, Any]]:
-        """
-        Extract text and images from a PDF file.

-        …
-            A list of dictionaries containing text, images, and metadata.
-        """
         try:
             with pdfplumber.open(file) as pdf:
-                …
-                for i, page in enumerate(pdf.pages):
                     page_text = page.extract_text() or ""
-                    …
-                        "text": page_text,
-                        "images": page_images,
-                        "meta": {"type": "pdf", "page": i + 1},
-                    })
-            return extracted_data
         except Exception as e:
-            self.…
-            return …

-    def handle_text(self, text: str) -> List[Dict[str, Any]]:
-        """Handle manual text input."""
-        return [{"text": text, "meta": {"type": "domain", "source": "manual"}}]

-    def handle_csv(self, file) -> List[Dict[str, Any]]:
-        """Process a CSV file and format the data for Q&A generation."""
         try:
             df = pd.read_csv(file)
-            …
-                    "text": "\n".join([f"{col}: {row[col]}" for col in df.columns]),
-                    "meta": {"type": "csv", "columns": list(df.columns)},
-                }
-                for _, row in df.iterrows()
-            ]
         except Exception as e:
-            self.…
-            return …

-    def handle_api(self, config: Dict[str, …
-        """Fetch data from an API endpoint and format it for processing."""
         try:
-            response = requests.get(config["url"], headers=config…
             response.raise_for_status()
-            return …
-            …
-            }
-            …
-            return []

-    def handle_db(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
-        """Connect to a database, execute a query, and format the results."""
         try:
             engine = sqlalchemy.create_engine(config["connection"])
             with engine.connect() as conn:
                 result = conn.execute(sqlalchemy.text(config["query"]))
-                …
-                        "text": "\n".join([f"{col}: {val}" for col, val in row._asdict().items()]),
-                        "meta": {"type": "db", "table": config.get("table", "")},
-                    }
-                    for row in result
-                ]
         except Exception as e:
-            self.…
-            return …

-    def …
-        """…
-        …
-        return images


-    def generate(self, api_key: str) -> bool:
         """
-        Generate …

-        Iterates over all the input data, calls the appropriate inference method,
-        and aggregates the generated Q&A pairs into session state.
         """
         if not api_key:
-            …
             return False

-        …
-        # Initialize the client
-        if provider_name == "Google":
-            client = client_initializer(api_key)
-            if not client:
-                return False
-        else:
-            client = client_initializer(api_key)

-        for i, input_data in enumerate(st.session_state.inputs):
-            st.session_state.processing["progress"] = (i + 1) / len(st.session_state.inputs)
-            st.write("--- Input Data ---")
-            st.write(input_data["text"])

-            …
-            response = self._standard_inference(client, input_data)
-            …

             return True
         except Exception as e:
-            self.…
             return False

-    def _standard_inference(self, client: Any, …
-        """…
         try:
-            …
-                model=…,
-                messages=[{"role": "user", "content": …
-                temperature=…,
             )
         except Exception as e:
-            self.…
             return None

-    def _huggingface_inference(self, client: Dict[str, Any], …
-        """…
         try:
             response = requests.post(
-                HF_API_URL + …,
                 headers=client["headers"],
-                json={"inputs": …
             )
             response.raise_for_status()
             return response.json()
-        except requests.exceptions.RequestException as e:
-            self._log_error(f"Hugging Face Inference Error: {e}")
-            return None

-    def _google_inference(self, client: Any, input_data: Dict[str, Any]) -> Any:
-        """Perform inference using the Google Generative AI API."""
-        try:
-            model = client(st.session_state.config["model"])
-            response = model.generate_content(
-                self._build_prompt(input_data),
-                generation_config=genai.types.GenerationConfig(
-                    temperature=st.session_state.config["temperature"]
-                ),
-            )
-            return response
         except Exception as e:
-            self.…
             return None

-    def _build_prompt(self, input_data: Dict[str, Any]) -> str:
-        """
-        Build the prompt for the LLM based on the input data.

-        The prompt instructs the LLM to extract 3 Q&A pairs in JSON format.
-        """
-        base_prompt = (
-            "You are an expert in extracting question and answer pairs from documents. "
-            "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries.\n"
-            "Each dictionary must have the keys 'question' and 'answer'.\n"
-            "The 'question' should be clear and concise, and the 'answer' should directly answer the question "
-            "using only information from the provided data. Do not hallucinate or invent information.\n"
-            "Answer using the exact information from the document, not external knowledge.\n"
-            "Example JSON Output:\n"
-            '[{"question": "What is the capital of France?", "answer": "The capital of France is Paris."}, '
-            '{"question": "What is the highest mountain in the world?", "answer": "The highest mountain in the world is Mount Everest."}, '
-            '{"question": "What is the chemical symbol for gold?", "answer": "The chemical symbol for gold is Au."}]\n'
-            "Now, generate 3 Q&A pairs from this data:\n"
-        )
-        data_type = input_data["meta"].get("type", "text")
-        if data_type == "csv":
-            return base_prompt + "Data:\n" + input_data["text"]
-        elif data_type == "api":
-            return base_prompt + "API response:\n" + input_data["text"]
-        return base_prompt + input_data["text"]

-    # --- RESPONSE PARSING ---
-    def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
         """
-        Parse the LLM response into a …

-        Expects the response to be a JSON formatted string.
         """
         try:
-            response_text = ""
             if provider == "HuggingFace":
-                …
-                self._log_error("Empty or malformed response from LLM.")
-                return []
-            response_text = response.choices[0].message.content

-            try:
-                json_output = json.loads(response_text)
-            except json.JSONDecodeError as e:
-                self._log_error(f"JSON Parse Error: {e}. Raw Response: {response_text}")
-                return []

-            if isinstance(json_output, list):
-                qa_pairs = json_output
-            elif isinstance(json_output, dict) and "questionList" in json_output:
-                qa_pairs = json_output["questionList"]
             else:
-                …
-                return []

-            for pair in qa_pairs:
-                if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
-                    self._log_error(f"Invalid QA pair structure: {pair}")
-                    return []

-            return qa_pairs

         except Exception as e:
-            self.…
-            return …

-    def _log_error(self, message: str) -> None:
-        """Log an error message to the session state and display it."""
-        st.session_state.processing["errors"].append(message)
-        st.error(message)


-def …
-    """Create the input sidebar in the Streamlit UI."""
     with st.sidebar:
-        st.header("…
-        provider = st.selectbox("Provider", list(generator.providers.keys()))
-        st.session_state.config["provider"] = provider
         provider_cfg = generator.providers[provider]

-        …
-        st.session_state["api_key"] = api_key

-        model = st.selectbox("Model", provider_cfg["models"])
         st.session_state.config["model"] = model

         temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
         st.session_state.config["temperature"] = temperature

-        st.…
-        input_type = st.selectbox("Input Type", list(generator.input_handlers.keys()))

-        …
         headers = {}
-        …
         headers = json.loads(api_headers)
-        …
-        st.download_button("Export CSV", df.to_csv(index=False), "synthetic_data.csv")


 def main() -> None:
-    "…
-    generator = …
-    …

 if __name__ == "__main__":
     main()
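For reference, the removed pipeline above asked the model for exactly three Q&A pairs as a JSON list and validated the result in `_parse_response`. A minimal standalone sketch of that contract (the sample response is taken from the prompt's own example; the validation mirrors the removed parser):

import json

# A response in the exact shape the removed prompt requested.
response_text = '[{"question": "What is the capital of France?", "answer": "The capital of France is Paris."}]'

qa_pairs = json.loads(response_text)
# The removed parser rejected anything that was not a list of dicts
# carrying both a "question" and an "answer" key.
for pair in qa_pairs:
    assert isinstance(pair, dict) and "question" in pair and "answer" in pair
print(qa_pairs[0]["answer"])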
 import pdfplumber
 import pandas as pd
 import sqlalchemy
+import time
+import concurrent.futures
 from typing import Any, Dict, List

+# Provider clients (make sure you have these installed)
+try:
+    from openai import OpenAI
+except ImportError:
+    OpenAI = None

+try:
+    import groq
+except ImportError:
+    groq = None
+
+# Hugging Face inference URL
 HF_API_URL = "https://api-inference.huggingface.co/models/"
 DEFAULT_TEMPERATURE = 0.1
+GROQ_MODEL = "mixtral-8x7b-32768"


+class AdvancedSyntheticDataGenerator:
     """
+    Advanced Synthetic Data Generator
+
+    This class handles multiple input sources, advanced prompt engineering, and
+    supports multiple LLM providers to generate synthetic data.
     """
     def __init__(self) -> None:
         self._setup_providers()
         self._setup_input_handlers()
         self._initialize_session_state()
+        # A customizable prompt template (you can modify it via the UI)
+        self.custom_prompt_template = (
+            "You are an expert synthetic data generator. "
+            "Given the data below and following the instructions provided, generate high-quality, diverse synthetic data. "
+            "Ensure the output adheres to the specified format.\n\n"
+            "-------------------------\n"
+            "Data:\n{data}\n\n"
+            "Instructions:\n{instructions}\n\n"
+            "Output Format: {format}\n"
+            "-------------------------\n"
+        )

     def _setup_providers(self) -> None:
+        """Configure available LLM providers and their initialization routines."""
         self.providers: Dict[str, Dict[str, Any]] = {
             "Deepseek": {
+                "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
                 "models": ["deepseek-chat"],
             },
             "OpenAI": {
+                "client": lambda key: OpenAI(api_key=key) if OpenAI else None,
+                "models": ["gpt-4-turbo", "gpt-3.5-turbo"],
             },
             "Groq": {
+                "client": lambda key: groq.Groq(api_key=key) if groq else None,
                 "models": [GROQ_MODEL],
             },
             "HuggingFace": {
                 "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
                 "models": ["gpt2", "llama-2"],
             },
         }

     def _setup_input_handlers(self) -> None:
+        """Register handlers for different input data types."""
         self.input_handlers: Dict[str, Any] = {
             "text": self.handle_text,
+            "pdf": self.handle_pdf,
             "csv": self.handle_csv,
             "api": self.handle_api,
             "db": self.handle_db,
         }

     def _initialize_session_state(self) -> None:
+        """Initialize Streamlit session state with default configuration."""
+        defaults = {
             "config": {
+                "provider": "OpenAI",
+                "model": "gpt-4-turbo",
                 "temperature": DEFAULT_TEMPERATURE,
+                "output_format": "plain_text",  # Options: plain_text, json, csv
             },
+            "api_key": "",
+            "inputs": [],  # A list to store input sources
+            "instructions": "",  # Custom instructions for data generation
+            "synthetic_data": "",  # The generated output
+            "error_logs": [],  # Any errors that occur during processing
         }
+        for key, value in defaults.items():
             if key not in st.session_state:
                 st.session_state[key] = value

+    def log_error(self, message: str) -> None:
+        """Log an error message both to the session state and in the UI."""
+        st.session_state.error_logs.append(message)
+        st.error(message)

+    # ===== INPUT HANDLERS =====
+    def handle_text(self, text: str) -> Dict[str, Any]:
+        return {"data": text, "source": "text"}

+    def handle_pdf(self, file) -> Dict[str, Any]:
         try:
             with pdfplumber.open(file) as pdf:
+                full_text = ""
+                for page in pdf.pages:
                     page_text = page.extract_text() or ""
+                    full_text += page_text + "\n"
+                return {"data": full_text, "source": "pdf"}
         except Exception as e:
+            self.log_error(f"PDF Processing Error: {e}")
+            return {"data": "", "source": "pdf"}

+    def handle_csv(self, file) -> Dict[str, Any]:
         try:
             df = pd.read_csv(file)
+            # For simplicity, we convert the dataframe to JSON.
+            return {"data": df.to_json(orient="records"), "source": "csv"}
         except Exception as e:
+            self.log_error(f"CSV Processing Error: {e}")
+            return {"data": "", "source": "csv"}

+    def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
         try:
+            response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
             response.raise_for_status()
+            return {"data": json.dumps(response.json()), "source": "api"}
+        except Exception as e:
+            self.log_error(f"API Processing Error: {e}")
+            return {"data": "", "source": "api"}
+
+    def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
         try:
             engine = sqlalchemy.create_engine(config["connection"])
             with engine.connect() as conn:
                 result = conn.execute(sqlalchemy.text(config["query"]))
+                # Row objects are tuple-like in SQLAlchemy 1.4+; go through
+                # ._mapping to get a dict per row. default=str keeps dates and
+                # decimals from breaking JSON serialization.
+                rows = [dict(row._mapping) for row in result]
+                return {"data": json.dumps(rows, default=str), "source": "db"}
         except Exception as e:
+            self.log_error(f"Database Processing Error: {e}")
+            return {"data": "", "source": "db"}
+
+    def aggregate_inputs(self) -> str:
+        """Combine all input sources into a single data string."""
+        aggregated_data = ""
+        for item in st.session_state.inputs:
+            aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
+            aggregated_data += item.get("data", "") + "\n\n"
+        return aggregated_data.strip()
+
+    def build_prompt(self) -> str:
+        """
+        Build the complete prompt by combining the aggregated input data with
+        custom instructions and the desired output format.
+        """
+        aggregated_data = self.aggregate_inputs()
+        instructions = st.session_state.instructions or "Generate diverse, coherent synthetic data."
+        output_format = st.session_state.config.get("output_format", "plain_text")
+        return self.custom_prompt_template.format(
+            data=aggregated_data, instructions=instructions, format=output_format
+        )

+    def generate_synthetic_data(self) -> bool:
         """
+        Generate synthetic data by sending the built prompt to the selected LLM provider.
+        Returns True if generation succeeds.
         """
+        api_key = st.session_state.api_key
         if not api_key:
+            self.log_error("API key is missing!")
             return False

+        provider_name = st.session_state.config["provider"]
+        provider_cfg = self.providers.get(provider_name)
+        if not provider_cfg:
+            self.log_error(f"Provider {provider_name} is not configured.")
+            return False

+        client_initializer = provider_cfg["client"]
+        client = client_initializer(api_key)
+        if client is None:
+            # The provider lambdas return None when the client library is missing.
+            self.log_error(f"Client for {provider_name} could not be initialized (is the library installed?).")
+            return False
+        model = st.session_state.config["model"]
+        temperature = st.session_state.config["temperature"]
+        prompt = self.build_prompt()

+        st.info(f"Using provider {provider_name} with model {model} at temperature {temperature:.2f}")
+        # (Optionally) simulate asynchronous processing with a thread pool if needed.
+        try:
+            if provider_name == "HuggingFace":
+                response = self._huggingface_inference(client, prompt, model)
+            else:
+                response = self._standard_inference(client, prompt, model, temperature)

+            synthetic_data = self._parse_response(response, provider_name)
+            st.session_state.synthetic_data = synthetic_data
             return True
         except Exception as e:
+            self.log_error(f"Generation failed: {e}")
             return False

+    def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
+        """
+        Inference method for providers using an OpenAI-compatible API.
+        """
         try:
+            result = client.chat.completions.create(
+                model=model,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=temperature,
             )
+            return result
         except Exception as e:
+            self.log_error(f"Standard Inference Error: {e}")
             return None

+    def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
+        """
+        Inference method for the Hugging Face Inference API.
+        """
         try:
             response = requests.post(
+                HF_API_URL + model,
                 headers=client["headers"],
+                json={"inputs": prompt},
+                timeout=30,
             )
             response.raise_for_status()
             return response.json()
         except Exception as e:
+            self.log_error(f"HuggingFace Inference Error: {e}")
             return None

+    def _parse_response(self, response: Any, provider: str) -> str:
         """
+        Parse the LLM response into a synthetic data string.
         """
         try:
             if provider == "HuggingFace":
+                # Guard against an empty list before indexing into the response.
+                if isinstance(response, list) and response and "generated_text" in response[0]:
+                    return response[0]["generated_text"]
+                else:
+                    self.log_error("Unexpected HuggingFace response format.")
+                    return ""
             else:
+                if response and hasattr(response, "choices") and response.choices:
+                    return response.choices[0].message.content
+                else:
+                    self.log_error("Unexpected response format.")
+                    return ""
         except Exception as e:
+            self.log_error(f"Response Parsing Error: {e}")
+            return ""

+# ===== ADVANCED UI COMPONENTS =====

+def advanced_config_ui(generator: AdvancedSyntheticDataGenerator):
+    """Advanced configuration options in the sidebar."""
     with st.sidebar:
+        st.header("Advanced Configuration")
+        provider = st.selectbox("Select Provider", list(generator.providers.keys()))
+        st.session_state.config["provider"] = provider
         provider_cfg = generator.providers[provider]

+        model = st.selectbox("Select Model", provider_cfg["models"])
         st.session_state.config["model"] = model

         temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
         st.session_state.config["temperature"] = temperature

+        output_format = st.radio("Output Format", ["plain_text", "json", "csv"])
+        st.session_state.config["output_format"] = output_format

+        api_key = st.text_input(f"{provider} API Key", type="password")
+        st.session_state.api_key = api_key
+
+        instructions = st.text_area("Custom Instructions",
+                                    "Generate diverse, coherent synthetic data based on the input sources.",
+                                    height=100)
+        st.session_state.instructions = instructions
+
+def advanced_input_ui(generator: AdvancedSyntheticDataGenerator):
+    """UI for adding input sources using tabs."""
+    st.header("Input Data Sources")
+    tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
+
+    with tabs[0]:
+        text_input = st.text_area("Enter text input", height=150)
+        if st.button("Add Text Input", key="text_input"):
+            if text_input.strip():
+                st.session_state.inputs.append(generator.handle_text(text_input))
+                st.success("Text input added!")
+
+    with tabs[1]:
+        pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
+        # Gate on a button so the file is not re-added on every Streamlit rerun.
+        if pdf_file is not None and st.button("Add PDF Input", key="pdf_input"):
+            st.session_state.inputs.append(generator.handle_pdf(pdf_file))
+            st.success("PDF input added!")
+
+    with tabs[2]:
+        csv_file = st.file_uploader("Upload CSV", type=["csv"])
+        if csv_file is not None and st.button("Add CSV Input", key="csv_input"):
+            st.session_state.inputs.append(generator.handle_csv(csv_file))
+            st.success("CSV input added!")
+
+    with tabs[3]:
+        api_url = st.text_input("API Endpoint URL")
+        api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
+        if st.button("Add API Input", key="api_input"):
             headers = {}
+            try:
+                if api_headers:
                     headers = json.loads(api_headers)
+            except Exception as e:
+                generator.log_error(f"Invalid JSON for API Headers: {e}")
+            st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
+            st.success("API input added!")
+
+    with tabs[4]:
+        db_conn = st.text_input("Database Connection String")
+        db_query = st.text_area("Database Query", height=100)
+        if st.button("Add Database Input", key="db_input"):
+            st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
+            st.success("Database input added!")
+
+
def advanced_output_ui(generator: AdvancedSyntheticDataGenerator):
|
334 |
+
"""Display the generated synthetic data with various output options."""
|
335 |
+
st.header("Synthetic Data Output")
|
336 |
+
if st.session_state.synthetic_data:
|
337 |
+
output_format = st.session_state.config.get("output_format", "plain_text")
|
338 |
+
if output_format == "json":
|
339 |
+
try:
|
340 |
+
json_output = json.loads(st.session_state.synthetic_data)
|
341 |
+
st.json(json_output)
|
342 |
+
except Exception:
|
343 |
+
st.text_area("Output", st.session_state.synthetic_data, height=300)
|
344 |
+
else:
|
345 |
+
st.text_area("Output", st.session_state.synthetic_data, height=300)
|
346 |
+
st.download_button("Download Output", st.session_state.synthetic_data,
|
347 |
+
file_name="synthetic_data.txt", mime="text/plain")
|
348 |
+
else:
|
349 |
+
st.info("No synthetic data generated yet.")
|
350 |
+
|
351 |
+
def advanced_logs_ui():
|
352 |
+
"""Display error logs and debug information in an expandable section."""
|
353 |
+
with st.expander("Error Logs & Debug Info", expanded=False):
|
354 |
+
if st.session_state.error_logs:
|
355 |
+
for log in st.session_state.error_logs:
|
356 |
+
st.write(log)
|
357 |
+
else:
|
358 |
+
st.write("No logs yet.")
|
359 |
+
|
360 |
+
|
361 |
+
# ===== MAIN APPLICATION =====
|
|
|
|
|

 def main() -> None:
+    st.set_page_config(page_title="Advanced Synthetic Data Generator", layout="wide")
+    generator = AdvancedSyntheticDataGenerator()
+    advanced_config_ui(generator)
+
+    # Create main tabs for Input, Output, and Logs
+    main_tabs = st.tabs(["Input", "Output", "Logs"])
+
+    with main_tabs[0]:
+        advanced_input_ui(generator)
+        if st.button("Clear Inputs"):
+            st.session_state.inputs = []
+            st.success("Inputs cleared!")
+
+    with main_tabs[1]:
+        if st.button("Generate Synthetic Data"):
+            with st.spinner("Generating synthetic data..."):
+                if generator.generate_synthetic_data():
+                    st.success("Data generated successfully!")
+                else:
+                    st.error("Data generation failed. Check logs for details.")
+        advanced_output_ui(generator)
+
+    with main_tabs[2]:
+        advanced_logs_ui()

 if __name__ == "__main__":
     main()
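The new pipeline is easiest to see end to end through the prompt it builds. A standalone sketch using the template text from `__init__` above; the sample data and instructions are illustrative, not part of the app:

# Demo of the prompt that generate_synthetic_data() sends to a provider.
template = (
    "You are an expert synthetic data generator. "
    "Given the data below and following the instructions provided, generate high-quality, diverse synthetic data. "
    "Ensure the output adheres to the specified format.\n\n"
    "-------------------------\n"
    "Data:\n{data}\n\n"
    "Instructions:\n{instructions}\n\n"
    "Output Format: {format}\n"
    "-------------------------\n"
)

prompt = template.format(
    data='Source: csv\n[{"name": "Alice", "age": 31}]',  # illustrative input record
    instructions="Generate 5 similar records.",
    format="json",
)
print(prompt)

Run the app itself with `streamlit run app.py` (after installing streamlit, pdfplumber, pandas, sqlalchemy, and requests, plus openai or groq for those providers); `_parse_response` then returns the raw `generated_text` field for HuggingFace, or `choices[0].message.content` for the OpenAI-compatible clients.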