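# NOTE(review): the original first lines read "Spaces: / Sleeping / Sleeping".
# They look like page-banner residue from wherever this file was copied
# (presumably a hosted "Space" status header) and are not part of the program.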
import json | |
import requests | |
import streamlit as st | |
import pdfplumber | |
import pandas as pd | |
import sqlalchemy | |
import time | |
import concurrent.futures | |
from typing import Any, Dict, List | |
# Provider clients (make sure you have these installed) | |
try: | |
from openai import OpenAI | |
except ImportError: | |
OpenAI = None | |
try: | |
import groq | |
except ImportError: | |
groq = None | |
# Hugging Face inference URL; the selected model id is appended to it when a
# request is sent.
HF_API_URL = "https://api-inference.huggingface.co/models/"
# Default sampling temperature: used for the initial session config and as the
# starting value of the temperature slider.
DEFAULT_TEMPERATURE = 0.1
# The single model offered for the "Groq" provider.
GROQ_MODEL = "mixtral-8x7b-32768"
class AdvancedSyntheticDataGenerator:
    """
    Advanced Synthetic Data Generator.

    Collects input data from several kinds of sources (text, PDF, CSV, HTTP
    API, database), builds a prompt from it and sends that prompt to one of
    the configured LLM providers. Configuration, inputs, the generated output
    and the error log are all kept in ``st.session_state``.
    """

    def __init__(self) -> None:
        self._setup_providers()
        self._setup_input_handlers()
        self._initialize_session_state()
        # Prompt template; ``build_prompt`` fills the {data}, {instructions}
        # and {format} placeholders.
        self.custom_prompt_template = (
            "You are an expert synthetic data generator. "
            "Given the data below and following the instructions provided, generate high-quality, diverse synthetic data. "
            "Ensure the output adheres to the specified format.\n\n"
            "-------------------------\n"
            "Data:\n{data}\n\n"
            "Instructions:\n{instructions}\n\n"
            "Output Format: {format}\n"
            "-------------------------\n"
        )

    def _setup_providers(self) -> None:
        """Configure available LLM providers and their initialization routines.

        Each entry maps a provider name to a ``client`` factory (taking the
        API key) and the list of ``models`` offered for it. A factory returns
        ``None`` when the provider's SDK could not be imported.
        """
        self.providers: Dict[str, Dict[str, Any]] = {
            "Deepseek": {
                "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
                "models": ["deepseek-chat"],
            },
            "OpenAI": {
                "client": lambda key: OpenAI(api_key=key) if OpenAI else None,
                "models": ["gpt-4-turbo", "gpt-3.5-turbo"],
            },
            "Groq": {
                "client": lambda key: groq.Groq(api_key=key) if groq else None,
                "models": [GROQ_MODEL],
            },
            "HuggingFace": {
                # No SDK client: only the auth headers for plain HTTP requests.
                "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
                "models": ["gpt2", "llama-2"],
            },
        }

    def _setup_input_handlers(self) -> None:
        """Register handlers for different input data types."""
        self.input_handlers: Dict[str, Any] = {
            "text": self.handle_text,
            "pdf": self.handle_pdf,
            "csv": self.handle_csv,
            "api": self.handle_api,
            "db": self.handle_db,
        }

    def _initialize_session_state(self) -> None:
        """Initialize Streamlit session state with default configuration.

        Keys that already exist in the session are left untouched.
        """
        defaults = {
            "config": {
                "provider": "OpenAI",
                "model": "gpt-4-turbo",
                "temperature": DEFAULT_TEMPERATURE,
                "output_format": "plain_text",  # Options: plain_text, json, csv
            },
            "api_key": "",
            "inputs": [],  # A list to store input sources
            "instructions": "",  # Custom instructions for data generation
            "synthetic_data": "",  # The generated output
            "error_logs": [],  # Any errors that occur during processing
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def log_error(self, message: str) -> None:
        """Log an error message both to the session state and in the UI."""
        st.session_state.error_logs.append(message)
        st.error(message)

    # ===== INPUT HANDLERS =====
    # Every handler returns {"data": <str>, "source": <kind>}; on failure the
    # error is logged and "data" is the empty string.

    def handle_text(self, text: str) -> Dict[str, Any]:
        """Wrap raw text as an input record."""
        return {"data": text, "source": "text"}

    def handle_pdf(self, file) -> Dict[str, Any]:
        """Extract the text of every page of a PDF, one page per line block."""
        try:
            with pdfplumber.open(file) as pdf:
                # extract_text() may return None for a page without text.
                full_text = "".join(
                    (page.extract_text() or "") + "\n" for page in pdf.pages
                )
            return {"data": full_text, "source": "pdf"}
        except Exception as e:
            self.log_error(f"PDF Processing Error: {e}")
            return {"data": "", "source": "pdf"}

    def handle_csv(self, file) -> Dict[str, Any]:
        """Read a CSV file and return its rows as a JSON array of records."""
        try:
            df = pd.read_csv(file)
            # For simplicity, we convert the dataframe to JSON.
            return {"data": df.to_json(orient="records"), "source": "csv"}
        except Exception as e:
            self.log_error(f"CSV Processing Error: {e}")
            return {"data": "", "source": "csv"}

    def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
        """GET ``config["url"]`` (with optional ``config["headers"]``) and
        return the JSON body re-serialized as a string."""
        try:
            response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
            response.raise_for_status()
            return {"data": json.dumps(response.json()), "source": "api"}
        except Exception as e:
            self.log_error(f"API Processing Error: {e}")
            return {"data": "", "source": "api"}

    def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
        """Run ``config["query"]`` against ``config["connection"]`` and return
        the result rows as a JSON array of objects."""
        try:
            engine = sqlalchemy.create_engine(config["connection"])
            try:
                with engine.connect() as conn:
                    result = conn.execute(sqlalchemy.text(config["query"]))
                    # Newer SQLAlchemy rows are tuple-like and expose their
                    # column mapping via ``_mapping``; older rows are
                    # mapping-like themselves, hence the fallback.
                    rows = [dict(getattr(row, "_mapping", row)) for row in result]
            finally:
                # The engine is created per call, so release its pool here.
                engine.dispose()
            return {"data": self._rows_to_json(rows), "source": "db"}
        except Exception as e:
            self.log_error(f"Database Processing Error: {e}")
            return {"data": "", "source": "db"}

    @staticmethod
    def _rows_to_json(rows: List[Dict[str, Any]]) -> str:
        """Serialize database rows to JSON.

        Values that JSON cannot represent natively (dates, decimals, ...) are
        deliberately rendered with ``str`` so one such column does not make
        the whole input unusable.
        """
        return json.dumps(rows, default=str)

    def aggregate_inputs(self) -> str:
        """Combine all input sources into a single data string."""
        sections = (
            f"Source: {item.get('source', 'unknown')}\n{item.get('data', '')}\n\n"
            for item in st.session_state.inputs
        )
        return "".join(sections).strip()

    def build_prompt(self) -> str:
        """
        Build the complete prompt by combining the aggregated input data with
        custom instructions and the desired output format.
        """
        aggregated_data = self.aggregate_inputs()
        instructions = st.session_state.instructions or "Generate diverse, coherent synthetic data."
        output_format = st.session_state.config.get("output_format", "plain_text")
        return self.custom_prompt_template.format(
            data=aggregated_data, instructions=instructions, format=output_format
        )

    def generate_synthetic_data(self) -> bool:
        """
        Generate synthetic data by sending the built prompt to the selected LLM provider.

        Returns True only when the provider produced non-empty output, which
        is then stored in ``st.session_state.synthetic_data``. On any failure
        the error is logged, the previously stored output is left untouched
        and False is returned.
        """
        api_key = st.session_state.api_key
        if not api_key:
            self.log_error("API key is missing!")
            return False

        provider_name = st.session_state.config["provider"]
        provider_cfg = self.providers.get(provider_name)
        if not provider_cfg:
            self.log_error(f"Provider {provider_name} is not configured.")
            return False

        client = provider_cfg["client"](api_key)
        if client is None:
            # The factory returns None when the provider SDK failed to import.
            self.log_error(f"Client library for provider {provider_name} is not installed.")
            return False

        model = st.session_state.config["model"]
        temperature = st.session_state.config["temperature"]
        prompt = self.build_prompt()
        st.info(f"Using provider {provider_name} with model {model} at temperature {temperature:.2f}")

        try:
            if provider_name == "HuggingFace":
                response = self._huggingface_inference(client, prompt, model)
            else:
                response = self._standard_inference(client, prompt, model, temperature)
            if response is None:
                # The inference helper has already logged the cause.
                return False
            synthetic_data = self._parse_response(response, provider_name)
            if not synthetic_data:
                # _parse_response has already logged why nothing was returned.
                return False
            st.session_state.synthetic_data = synthetic_data
            return True
        except Exception as e:
            self.log_error(f"Generation failed: {e}")
            return False

    def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
        """
        Inference method for providers using an OpenAI-compatible API.

        Returns the completion object, or None (after logging) on error.
        """
        try:
            return client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
            )
        except Exception as e:
            self.log_error(f"Standard Inference Error: {e}")
            return None

    def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
        """
        Inference method for the Hugging Face Inference API.

        Returns the decoded JSON body, or None (after logging) on error.
        """
        try:
            response = requests.post(
                HF_API_URL + model,
                headers=client["headers"],
                json={"inputs": prompt},
                timeout=30,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.log_error(f"HuggingFace Inference Error: {e}")
            return None

    def _parse_response(self, response: Any, provider: str) -> str:
        """
        Parse the LLM response into a synthetic data string.

        Returns the generated text, or "" (after logging) when the response
        has an unexpected shape or carries no text.
        """
        try:
            if provider == "HuggingFace":
                if (
                    isinstance(response, list)
                    and response  # guard against an empty list
                    and isinstance(response[0], dict)
                    and "generated_text" in response[0]
                ):
                    text = response[0]["generated_text"]
                else:
                    self.log_error("Unexpected HuggingFace response format.")
                    return ""
            elif response and hasattr(response, "choices") and response.choices:
                text = response.choices[0].message.content
            else:
                self.log_error("Unexpected response format.")
                return ""

            if not text:
                # e.g. message.content is None or an empty string.
                self.log_error("Provider returned an empty response.")
                return ""
            return text
        except Exception as e:
            self.log_error(f"Response Parsing Error: {e}")
            return ""
# ===== ADVANCED UI COMPONENTS ===== | |
def advanced_config_ui(generator: AdvancedSyntheticDataGenerator):
    """Render the sidebar controls and copy their values into session state.

    Provider, model, temperature and output format are written into
    ``st.session_state.config``; the API key and the custom instructions are
    stored directly on ``st.session_state``.
    """
    settings = st.session_state.config
    with st.sidebar:
        st.header("Advanced Configuration")

        provider_names = list(generator.providers.keys())
        chosen_provider = st.selectbox("Select Provider", provider_names)
        settings["provider"] = chosen_provider

        # The model choices depend on the provider picked above.
        available_models = generator.providers[chosen_provider]["models"]
        settings["model"] = st.selectbox("Select Model", available_models)

        settings["temperature"] = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
        settings["output_format"] = st.radio("Output Format", ["plain_text", "json", "csv"])

        st.session_state.api_key = st.text_input(f"{chosen_provider} API Key", type="password")
        st.session_state.instructions = st.text_area(
            "Custom Instructions",
            "Generate diverse, coherent synthetic data based on the input sources.",
            height=100,
        )
def _store_input(result: Dict[str, Any], success_message: str) -> None:
    """Append a handler result to the session inputs only if it carries data."""
    if result.get("data"):
        st.session_state.inputs.append(result)
        st.success(success_message)
    else:
        # Handlers return empty data on failure; do not register it as an input.
        st.warning("No data could be read from this source, so it was not added.")


def _store_upload_once(uploaded_file: Any, state_key: str, handler: Any, success_message: str) -> None:
    """Process an uploaded file a single time instead of on every rerun.

    The identity of the last processed upload is remembered in
    ``st.session_state[state_key]``; it is reset when the uploader is empty so
    that removing and re-uploading a file adds it again.
    """
    if uploaded_file is None:
        st.session_state[state_key] = None
        return
    # Prefer the uploader-assigned id when the object has one; otherwise fall
    # back to name and size.
    identity = getattr(uploaded_file, "file_id", None) or (
        getattr(uploaded_file, "name", None),
        getattr(uploaded_file, "size", None),
    )
    if st.session_state.get(state_key) == identity:
        return  # Already handled on a previous run of the script.
    st.session_state[state_key] = identity
    _store_input(handler(uploaded_file), success_message)


def advanced_input_ui(generator: AdvancedSyntheticDataGenerator):
    """UI for adding input sources using tabs.

    Each tab collects one kind of source, passes it to the matching handler
    on ``generator`` and, when the handler produced data, appends the result
    to ``st.session_state.inputs``.
    """
    st.header("Input Data Sources")
    text_tab, pdf_tab, csv_tab, api_tab, db_tab = st.tabs(["Text", "PDF", "CSV", "API", "Database"])

    with text_tab:
        text_input = st.text_area("Enter text input", height=150)
        if st.button("Add Text Input", key="text_input"):
            if text_input.strip():
                _store_input(generator.handle_text(text_input), "Text input added!")
            else:
                st.warning("Text input is empty; nothing was added.")

    with pdf_tab:
        pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
        # The uploader keeps returning the file on every rerun, so it must be
        # de-duplicated or the same PDF would be appended again and again.
        _store_upload_once(pdf_file, "last_pdf_upload", generator.handle_pdf, "PDF input added!")

    with csv_tab:
        csv_file = st.file_uploader("Upload CSV", type=["csv"])
        _store_upload_once(csv_file, "last_csv_upload", generator.handle_csv, "CSV input added!")

    with api_tab:
        api_url = st.text_input("API Endpoint URL")
        api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
        if st.button("Add API Input", key="api_input"):
            headers: Any = {}
            headers_ok = True
            if api_headers.strip():
                try:
                    headers = json.loads(api_headers)
                except ValueError as e:
                    generator.log_error(f"Invalid JSON for API Headers: {e}")
                    headers_ok = False
                else:
                    if not isinstance(headers, dict):
                        generator.log_error("API Headers must be a JSON object.")
                        headers_ok = False
            if not api_url.strip():
                generator.log_error("API endpoint URL is required.")
            elif headers_ok:
                # Only send the request when the headers the user typed are valid.
                _store_input(
                    generator.handle_api({"url": api_url.strip(), "headers": headers}),
                    "API input added!",
                )

    with db_tab:
        db_conn = st.text_input("Database Connection String")
        db_query = st.text_area("Database Query", height=100)
        if st.button("Add Database Input", key="db_input"):
            if not db_conn.strip() or not db_query.strip():
                generator.log_error("Database connection string and query are both required.")
            else:
                _store_input(
                    generator.handle_db({"connection": db_conn, "query": db_query}),
                    "Database input added!",
                )
def advanced_output_ui(generator: AdvancedSyntheticDataGenerator):
    """Display the generated synthetic data with various output options.

    JSON output is rendered as a tree when it parses, otherwise (and for all
    other formats) the raw text is shown. The download button uses a file
    name and MIME type matching the selected output format.

    ``generator`` is not used here; it is kept so the signature matches the
    other UI functions.
    """
    st.header("Synthetic Data Output")
    synthetic_data = st.session_state.synthetic_data
    if not synthetic_data:
        st.info("No synthetic data generated yet.")
        return

    output_format = st.session_state.config.get("output_format", "plain_text")
    if output_format == "json":
        try:
            st.json(json.loads(synthetic_data))
        except (TypeError, ValueError):
            # The model did not return valid JSON; fall back to raw text.
            st.text_area("Output", synthetic_data, height=300)
    else:
        st.text_area("Output", synthetic_data, height=300)

    # Pick the download name and MIME type from the requested format instead
    # of always offering a .txt file.
    download_targets = {
        "json": ("synthetic_data.json", "application/json"),
        "csv": ("synthetic_data.csv", "text/csv"),
    }
    file_name, mime = download_targets.get(output_format, ("synthetic_data.txt", "text/plain"))
    st.download_button("Download Output", synthetic_data, file_name=file_name, mime=mime)
def advanced_logs_ui():
    """Show the collected error log inside a collapsed expander."""
    with st.expander("Error Logs & Debug Info", expanded=False):
        entries = st.session_state.error_logs
        if not entries:
            st.write("No logs yet.")
            return
        # One line per logged message, in the order they were recorded.
        for entry in entries:
            st.write(entry)
# ===== MAIN APPLICATION ===== | |
def main() -> None:
    """Entry point: configure the page, then lay out the sidebar and tabs."""
    st.set_page_config(page_title="Advanced Synthetic Data Generator", layout="wide")
    app = AdvancedSyntheticDataGenerator()
    advanced_config_ui(app)

    # Main tabs for Input, Output, and Logs
    input_tab, output_tab, logs_tab = st.tabs(["Input", "Output", "Logs"])

    with input_tab:
        advanced_input_ui(app)
        clear_requested = st.button("Clear Inputs")
        if clear_requested:
            st.session_state.inputs = []
            st.success("Inputs cleared!")

    with output_tab:
        generate_requested = st.button("Generate Synthetic Data")
        if generate_requested:
            with st.spinner("Generating synthetic data..."):
                succeeded = app.generate_synthetic_data()
                if succeeded:
                    st.success("Data generated successfully!")
                else:
                    st.error("Data generation failed. Check logs for details.")
        advanced_output_ui(app)

    with logs_tab:
        advanced_logs_ui()


if __name__ == "__main__":
    main()