sythenticdata / app.py
mgbam's picture
Update app.py
68019c9 verified
import ast
import json
from typing import Any, Callable, Dict, List, Optional

import pandas as pd
import pdfplumber
import requests
import sqlalchemy
import streamlit as st

# Provider clients – ensure these libraries are installed
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None

try:
    import groq
except ImportError:
    groq = None
# Sampling temperature used to seed the session config and the sidebar slider.
DEFAULT_TEMPERATURE: float = 0.1

# The single model id offered for the Groq provider.
GROQ_MODEL: str = "mixtral-8x7b-32768"

# Base URL of the Hugging Face Inference API; the model id is appended to it.
HF_API_URL: str = "https://api-inference.huggingface.co/models/"
class QADataGenerator:
    """
    A Q&A Synthetic Generator that extracts and generates question-answer pairs
    from various input sources using an LLM provider.

    Every input handler returns a dict ``{"data": <str>, "source": <str>}``.
    ``generate_qa_pairs`` joins the stored inputs into one prompt, sends it to
    the provider selected in ``st.session_state.config`` and stores the parsed
    list of Q&A dictionaries in ``st.session_state.qa_pairs``.
    """

    # Prompt template filled via ``str.format(data=..., num_examples=...)``.
    # Doubled braces render as literal braces. The example output uses double
    # quotes so that it is itself valid JSON, steering the model toward output
    # that ``json.loads`` accepts without the ``ast.literal_eval`` fallback.
    CUSTOM_PROMPT_TEMPLATE: str = (
        "You are an expert in extracting question and answer pairs from documents. "
        "Generate {num_examples} Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
        'Each dictionary must have keys "question" and "answer". '
        "The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
        "Do not hallucinate. Respond with only the JSON list and no other text.\n\n"
        "Example JSON Output:\n"
        '[{{"question": "What is the capital of France?", "answer": "Paris"}}, '
        '{{"question": "What is the highest mountain in the world?", "answer": "Mount Everest"}}, '
        '{{"question": "What is the chemical symbol for gold?", "answer": "Au"}}]\n\n'
        "Now, generate {num_examples} Q&A pairs from this data:\n{data}"
    )

    def __init__(self) -> None:
        """Register providers and input handlers, then seed missing session-state keys."""
        self._setup_providers()
        self._setup_input_handlers()
        self._initialize_session_state()
        # build_prompt reads this instance attribute, so it can be overridden per instance.
        self.custom_prompt_template: str = self.CUSTOM_PROMPT_TEMPLATE

    def _setup_providers(self) -> None:
        """Configure available LLM providers and their client initialization routines.

        Each ``client`` entry is a factory taking the API key. The OpenAI-based
        and Groq factories return ``None`` when their library failed to import.
        For HuggingFace the "client" is only the request headers used by
        ``requests``.
        """
        self.providers: Dict[str, Dict[str, Any]] = {
            "Deepseek": {
                # Deepseek is reached through the OpenAI client with a custom base URL.
                "client": lambda key: (
                    OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None
                ),
                "models": ["deepseek-chat"],
            },
            "OpenAI": {
                "client": lambda key: OpenAI(api_key=key) if OpenAI else None,
                "models": ["gpt-4-turbo", "gpt-3.5-turbo"],
            },
            "Groq": {
                "client": lambda key: groq.Groq(api_key=key) if groq else None,
                "models": [GROQ_MODEL],
            },
            "HuggingFace": {
                "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
                # NOTE(review): "llama-2" is not an owner/name Hub repo id; confirm it resolves.
                "models": ["gpt2", "llama-2"],
            },
        }

    def _setup_input_handlers(self) -> None:
        """Register handlers for different input data types, keyed by source kind."""
        self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
            "text": self.handle_text,
            "pdf": self.handle_pdf,
            "csv": self.handle_csv,
            "api": self.handle_api,
            "db": self.handle_db,
        }

    def _initialize_session_state(self) -> None:
        """Initialize Streamlit session state with defaults, keeping any existing values."""
        defaults: Dict[str, Any] = {
            "config": {
                "provider": "OpenAI",
                "model": "gpt-4-turbo",
                "temperature": DEFAULT_TEMPERATURE,
                "num_examples": 3,  # Default number of Q&A pairs
            },
            "api_key": "",
            "inputs": [],  # List of handler results ({"data", "source"} dicts)
            "qa_pairs": None,  # Generated Q&A pairs output
            "error_logs": [],  # Error messages shown in the logs section
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def log_error(self, message: str) -> None:
        """Record an error message in session state and display it."""
        st.session_state.error_logs.append(message)
        st.error(message)

    # ----- Input Handlers -----
    @staticmethod
    def _rewind(file: Any) -> None:
        """Seek a file-like object back to its start.

        An uploaded file may be handled more than once; a previous read leaves
        the position at EOF, which would make e.g. ``pd.read_csv`` fail.
        """
        if hasattr(file, "seek"):
            file.seek(0)

    def handle_text(self, text: str) -> Dict[str, Any]:
        """Wrap plain text input as a handler result."""
        return {"data": text, "source": "text"}

    def handle_pdf(self, file) -> Dict[str, Any]:
        """Extract text from a PDF file.

        Pages without extractable text contribute an empty line. On failure the
        error is logged and ``data`` is an empty string.
        """
        try:
            self._rewind(file)
            with pdfplumber.open(file) as pdf:
                full_text = "\n".join(page.extract_text() or "" for page in pdf.pages)
            return {"data": full_text, "source": "pdf"}
        except Exception as e:
            self.log_error(f"PDF Processing Error: {e}")
            return {"data": "", "source": "pdf"}

    def handle_csv(self, file) -> Dict[str, Any]:
        """Read a CSV file and serialize its rows as a JSON list of records.

        On failure the error is logged and ``data`` is an empty string.
        """
        try:
            self._rewind(file)
            df = pd.read_csv(file)
            return {"data": df.to_json(orient="records"), "source": "csv"}
        except Exception as e:
            self.log_error(f"CSV Processing Error: {e}")
            return {"data": "", "source": "csv"}

    def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
        """Fetch data from an API endpoint.

        Args:
            config: ``{"url": ..., "headers": {...}}``; ``headers`` is optional.

        A JSON body is re-serialized with ``json.dumps``; any other body is used
        as plain text. On failure the error is logged and ``data`` is empty.
        """
        try:
            response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
            response.raise_for_status()
            try:
                data = json.dumps(response.json())
            except ValueError:
                # requests' JSON decode errors subclass ValueError; fall back to the raw body.
                data = response.text
            return {"data": data, "source": "api"}
        except Exception as e:
            self.log_error(f"API Processing Error: {e}")
            return {"data": "", "source": "api"}

    def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
        """Run an SQL query and serialize the result rows as a JSON list of objects.

        Args:
            config: ``{"connection": <SQLAlchemy URL>, "query": <SQL text>}``.

        Rows are converted through ``row._mapping`` (the SQLAlchemy 1.4+/2.0 Row
        API; plain ``dict(row)`` fails there). Values ``json`` cannot encode,
        such as dates or decimals, are stringified on purpose so they can be
        embedded in the prompt. The engine is always disposed. On failure the
        error is logged and ``data`` is empty.
        """
        engine = None
        try:
            engine = sqlalchemy.create_engine(config["connection"])
            with engine.connect() as conn:
                result = conn.execute(sqlalchemy.text(config["query"]))
                rows = [dict(row._mapping) for row in result]
            return {"data": json.dumps(rows, default=str), "source": "db"}
        except Exception as e:
            self.log_error(f"Database Processing Error: {e}")
            return {"data": "", "source": "db"}
        finally:
            if engine is not None:
                engine.dispose()

    def aggregate_inputs(self) -> str:
        """Combine all non-empty input sources into a single string.

        Each source becomes a ``Source: <kind>`` line followed by its data, and
        sources are separated by a blank line. Entries with empty data (e.g.
        failed handlers) are skipped.
        """
        sections = [
            f"Source: {item.get('source', 'unknown')}\n{item.get('data', '')}"
            for item in st.session_state.inputs
            if item.get("data")
        ]
        return "\n\n".join(sections).strip()

    def build_prompt(self) -> str:
        """
        Build the complete prompt using the custom template, aggregated inputs,
        and the number of examples, and display it for debugging.
        """
        data = self.aggregate_inputs()
        num_examples = st.session_state.config.get("num_examples", 3)
        prompt = self.custom_prompt_template.format(data=data, num_examples=num_examples)
        st.write("### Built Prompt")
        st.write(prompt)
        return prompt

    def generate_qa_pairs(self) -> bool:
        """
        Generate Q&A pairs by sending the built prompt to the selected LLM provider.

        Returns:
            True only when the provider answered and at least one Q&A pair was
            parsed and stored in ``st.session_state.qa_pairs``. False otherwise;
            the reason is logged through ``log_error``.
        """
        api_key: str = st.session_state.api_key
        if not api_key:
            self.log_error("API key is missing!")
            return False
        provider_name: str = st.session_state.config["provider"]
        provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
        if not provider_cfg:
            self.log_error(f"Provider {provider_name} is not configured.")
            return False
        # Without input data the prompt is empty, which invites hallucinated pairs.
        if not self.aggregate_inputs():
            self.log_error("No input data available. Add at least one input source first.")
            return False

        client_initializer: Callable[[str], Any] = provider_cfg["client"]
        try:
            client = client_initializer(api_key)
        except Exception as e:
            self.log_error(f"Failed to initialize {provider_name} client: {e}")
            return False
        if client is None:
            # The factories in _setup_providers return None when the library failed to import.
            self.log_error(f"The client library for {provider_name} is not installed.")
            return False

        model: str = st.session_state.config["model"]
        temperature: float = st.session_state.config["temperature"]
        prompt: str = self.build_prompt()
        st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
        try:
            if provider_name == "HuggingFace":
                response = self._huggingface_inference(client, prompt, model)
            else:
                response = self._standard_inference(client, prompt, model, temperature)
            if response is None:
                # The inference helper has already logged the cause.
                return False
            st.write("### Raw API Response")
            st.write(response)
            qa_pairs = self._parse_response(response, provider_name)
            st.write("### Parsed Q&A Pairs")
            st.write(qa_pairs)
            if not qa_pairs:
                self.log_error("The provider response contained no usable Q&A pairs.")
                return False
            st.session_state.qa_pairs = qa_pairs
            return True
        except Exception as e:
            self.log_error(f"Generation failed: {e}")
            return False

    def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
        """Send the prompt through an OpenAI-compatible chat-completions client.

        Returns the raw completion object, or None after logging on failure.
        """
        try:
            st.write("Sending prompt via standard inference...")
            result = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
            )
            st.write("Standard inference result received.")
            return result
        except Exception as e:
            self.log_error(f"Standard Inference Error: {e}")
            return None

    def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
        """Send the prompt to the Hugging Face Inference API.

        ``return_full_text`` is disabled so ``generated_text`` holds only the
        completion. Otherwise the echoed prompt, which contains its own example
        list, would corrupt JSON extraction. Returns the decoded JSON body, or
        None after logging on failure.
        """
        try:
            st.write("Sending prompt to HuggingFace API...")
            response = requests.post(
                HF_API_URL + model,
                headers=client["headers"],
                json={"inputs": prompt, "parameters": {"return_full_text": False}},
                timeout=30,
            )
            response.raise_for_status()
            st.write("HuggingFace API response received.")
            return response.json()
        except Exception as e:
            self.log_error(f"HuggingFace Inference Error: {e}")
            return None

    def _extract_raw_text(self, response: Any, provider: str) -> Optional[str]:
        """Pull the generated text out of a provider response.

        Returns None (after logging) when the response shape is unexpected or
        the message content is empty.
        """
        if provider == "HuggingFace":
            if (
                isinstance(response, list)
                and response
                and isinstance(response[0], dict)
                and "generated_text" in response[0]
            ):
                return response[0]["generated_text"]
            self.log_error("Unexpected HuggingFace response format.")
            return None
        if response and hasattr(response, "choices") and response.choices:
            content = response.choices[0].message.content
            if content:
                return content
            self.log_error("Provider returned an empty message.")
            return None
        self.log_error("Unexpected response format from provider.")
        return None

    @staticmethod
    def _extract_json_payload(raw_text: str) -> str:
        """Return the span from the first ``[`` to the last ``]`` of *raw_text*.

        This strips Markdown code fences and any prose the model wraps around
        the list. A bare JSON list is returned unchanged (minus surrounding
        whitespace). If no bracketed span exists, the stripped text is returned.
        """
        text = raw_text.strip()
        start = text.find("[")
        end = text.rfind("]")
        if start != -1 and end > start:
            return text[start : end + 1]
        return text

    def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
        """
        Parse the LLM response and return a list of Q&A dictionaries.

        The list is located inside the generated text and decoded with
        ``json.loads``. If that fails, ``ast.literal_eval`` is tried as a
        fallback, which handles single-quoted dicts. ``literal_eval`` only
        evaluates literals, so it is safe on model output. Non-dict items are
        dropped. Returns ``[]`` (after logging) on any failure.
        """
        st.write("Parsing response for provider:", provider)
        try:
            raw_text = self._extract_raw_text(response, provider)
            if raw_text is None:
                return []
            payload = self._extract_json_payload(raw_text)
            try:
                qa_list = json.loads(payload)
            except json.JSONDecodeError as e:
                self.log_error(
                    f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}"
                )
                try:
                    qa_list = ast.literal_eval(payload)
                except Exception as e2:
                    self.log_error(f"ast.literal_eval failed: {e2}")
                    return []
            if not isinstance(qa_list, list):
                self.log_error("Parsed output is not a list.")
                return []
            pairs = [item for item in qa_list if isinstance(item, dict)]
            dropped = len(qa_list) - len(pairs)
            if dropped:
                self.log_error(f"Dropped {dropped} non-dictionary item(s) from the parsed output.")
            return pairs
        except Exception as e:
            self.log_error(f"Response Parsing Error: {e}")
            return []
# ============ UI Components ============
def config_ui(generator: QADataGenerator) -> None:
    """Render the sidebar controls and copy each selection into session state.

    The chosen provider, model, temperature and number of examples are written
    to ``st.session_state.config``; the API key goes to
    ``st.session_state.api_key``.
    """
    with st.sidebar:
        st.header("Configuration")
        cfg = st.session_state.config

        chosen_provider = st.selectbox("Select Provider", list(generator.providers))
        cfg["provider"] = chosen_provider

        available_models = generator.providers[chosen_provider]["models"]
        cfg["model"] = st.selectbox("Select Model", available_models)

        cfg["temperature"] = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
        cfg["num_examples"] = st.number_input(
            "Number of Q&A Pairs", min_value=1, max_value=10, value=3, step=1
        )

        st.session_state.api_key = st.text_input(f"{chosen_provider} API Key", type="password")
def _store_input(result: Dict[str, Any], label: str) -> None:
    """Append a handler result to ``st.session_state.inputs`` only if it carries data.

    Handlers log their own errors and return an empty ``data`` string on
    failure. Such results are not stored, so they never reach the prompt.
    """
    if result.get("data"):
        st.session_state.inputs.append(result)
        st.success(f"{label} input added!")
    else:
        st.warning(f"{label} input produced no data and was not added.")


def _parse_api_headers(generator: QADataGenerator, raw_headers: str) -> Optional[Dict[str, Any]]:
    """Parse the optional API-headers text as a JSON object.

    Returns ``{}`` for blank input and the parsed dict on success. Returns
    ``None`` (after logging through ``generator``) when the text is not valid
    JSON or not a JSON object.
    """
    if not raw_headers.strip():
        return {}
    try:
        headers = json.loads(raw_headers)
    except json.JSONDecodeError as e:
        generator.log_error(f"Invalid JSON for API Headers: {e}")
        return None
    if not isinstance(headers, dict):
        generator.log_error("API headers must be a JSON object.")
        return None
    return headers


def input_ui(generator: QADataGenerator) -> None:
    """Display input data source options using tabs.

    Every source is added only when its "Add ..." button is pressed. Streamlit
    re-executes the script on each interaction, so acting on an uploaded file
    directly (without a button) would re-add the same file on every rerun.
    """
    st.subheader("Input Data Sources")
    text_tab, pdf_tab, csv_tab, api_tab, db_tab = st.tabs(["Text", "PDF", "CSV", "API", "Database"])

    with text_tab:
        text_input = st.text_area("Enter text input", height=150)
        if st.button("Add Text Input", key="text_input"):
            if text_input.strip():
                _store_input(generator.handle_text(text_input), "Text")
            else:
                st.warning("Empty text input.")

    with pdf_tab:
        pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
        if st.button("Add PDF Input", key="pdf_input"):
            if pdf_file is None:
                st.warning("Please upload a PDF file first.")
            else:
                _store_input(generator.handle_pdf(pdf_file), "PDF")

    with csv_tab:
        csv_file = st.file_uploader("Upload CSV", type=["csv"])
        if st.button("Add CSV Input", key="csv_input"):
            if csv_file is None:
                st.warning("Please upload a CSV file first.")
            else:
                _store_input(generator.handle_csv(csv_file), "CSV")

    with api_tab:
        api_url = st.text_input("API Endpoint URL")
        api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
        if st.button("Add API Input", key="api_input"):
            if not api_url.strip():
                st.warning("Please enter an API endpoint URL.")
            else:
                headers = _parse_api_headers(generator, api_headers)
                # Abort on malformed headers instead of calling the API without them.
                if headers is not None:
                    _store_input(
                        generator.handle_api({"url": api_url.strip(), "headers": headers}), "API"
                    )

    with db_tab:
        db_conn = st.text_input("Database Connection String")
        db_query = st.text_area("Database Query", height=100)
        if st.button("Add Database Input", key="db_input"):
            if not db_conn.strip() or not db_query.strip():
                st.warning("Please provide both a connection string and a query.")
            else:
                _store_input(
                    generator.handle_db({"connection": db_conn.strip(), "query": db_query}),
                    "Database",
                )
def output_ui(generator: QADataGenerator) -> None:
    """Show the stored Q&A pairs and offer JSON and CSV downloads.

    When nothing has been generated yet, only an info message is shown. A
    failure while building the CSV is reported inline and does not hide the
    JSON download.
    """
    st.subheader("Q&A Pairs Output")
    pairs = st.session_state.qa_pairs
    if not pairs:
        st.info("No Q&A pairs generated yet.")
        return

    st.write("### Generated Q&A Pairs")
    st.write(pairs)

    json_blob = json.dumps(pairs, indent=2)
    st.download_button(
        "Download as JSON",
        json_blob,
        file_name="qa_pairs.json",
        mime="application/json",
    )

    try:
        csv_blob = pd.DataFrame(pairs).to_csv(index=False)
        st.download_button(
            "Download as CSV",
            csv_blob,
            file_name="qa_pairs.csv",
            mime="text/csv",
        )
    except Exception as e:
        st.error(f"Error generating CSV: {e}")
def logs_ui() -> None:
    """List every recorded error message inside a collapsed expander."""
    with st.expander("Error Logs & Debug Info", expanded=False):
        entries = st.session_state.error_logs
        if not entries:
            st.write("No logs yet.")
            return
        for entry in entries:
            st.write(entry)
def main() -> None:
    """Main Streamlit application entry point.

    Renders, top to bottom: the page header, sidebar configuration, input
    collection with a reset button, the generation trigger, the output section
    and the error log.
    """
    st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
    st.title("Advanced Q&A Synthetic Generator")
    st.markdown(
        """
Welcome to the Advanced Q&A Synthetic Generator. This tool extracts and generates question-answer pairs
from various input sources. Configure your provider in the sidebar, add input data, and click the button below to generate Q&A pairs.
"""
    )

    # Constructing the generator seeds only missing session-state keys, so
    # existing inputs, results and logs are kept.
    generator = QADataGenerator()
    config_ui(generator)

    st.header("1. Input Data")
    input_ui(generator)
    clear_requested = st.button("Clear All Inputs")
    if clear_requested:
        st.session_state.inputs = []
        st.success("All inputs have been cleared!")

    st.header("2. Generate Q&A Pairs")
    generate_requested = st.button("Generate Q&A Pairs", key="generate_qa")
    if generate_requested:
        with st.spinner("Generating Q&A pairs..."):
            succeeded = generator.generate_qa_pairs()
            if succeeded:
                st.success("Q&A pairs generated successfully!")
            else:
                st.error("Q&A generation failed. Check logs for details.")

    st.header("3. Output")
    output_ui(generator)

    st.header("4. Logs & Debug Information")
    logs_ui()
# Render the app when this file is executed as a script (presumably via `streamlit run`).
if __name__ == "__main__":
    main()