"""Advanced Q&A Synthetic Generator.

A Streamlit application that aggregates data from text, PDF, CSV, API,
and database sources, then asks a configurable LLM provider to generate
synthetic question-answer pairs from the aggregated data.
"""

import ast
import json
from typing import Any, Callable, Dict, List

import pandas as pd
import pdfplumber
import requests
import sqlalchemy
import streamlit as st

# Provider clients – optional; the app degrades gracefully when missing.
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None
try:
    import groq
except ImportError:
    groq = None

# Hugging Face inference endpoint and defaults.
HF_API_URL: str = "https://api-inference.huggingface.co/models/"
DEFAULT_TEMPERATURE: float = 0.1
GROQ_MODEL: str = "mixtral-8x7b-32768"


class QADataGenerator:
    """
    A Q&A Synthetic Generator that extracts and generates question-answer
    pairs from various input sources using an LLM provider.
    """

    def __init__(self) -> None:
        self._setup_providers()
        self._setup_input_handlers()
        self._initialize_session_state()
        # Prompt template with a dynamic {num_examples} parameter.  Literal
        # braces in the JSON example are doubled ({{ }}) because the template
        # is rendered with str.format().  The example uses double quotes so
        # it is *valid* JSON — a single-quoted example nudges models toward
        # emitting Python literals instead of JSON.
        self.custom_prompt_template: str = (
            "You are an expert in extracting question and answer pairs from documents. "
            "Generate {num_examples} Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
            "Each dictionary must have keys 'question' and 'answer'. "
            "The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
            "Do not hallucinate.\n\n"
            "Example JSON Output:\n"
            '[{{"question": "What is the capital of France?", "answer": "Paris"}}, '
            '{{"question": "What is the highest mountain in the world?", "answer": "Mount Everest"}}, '
            '{{"question": "What is the chemical symbol for gold?", "answer": "Au"}}]\n\n'
            "Now, generate {num_examples} Q&A pairs from this data:\n{data}"
        )

    def _setup_providers(self) -> None:
        """Configure available LLM providers and their client initialization routines."""
        self.providers: Dict[str, Dict[str, Any]] = {
            "Deepseek": {
                # Deepseek exposes an OpenAI-compatible API, so it reuses the
                # OpenAI client with a custom base URL.
                "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
                "models": ["deepseek-chat"],
            },
            "OpenAI": {
                "client": lambda key: OpenAI(api_key=key) if OpenAI else None,
                "models": ["gpt-4-turbo", "gpt-3.5-turbo"],
            },
            "Groq": {
                "client": lambda key: groq.Groq(api_key=key) if groq else None,
                "models": [GROQ_MODEL],
            },
            "HuggingFace": {
                # HF uses a plain HTTP endpoint; the "client" is just the
                # auth headers passed to requests.post later.
                "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
                "models": ["gpt2", "llama-2"],
            },
        }

    def _setup_input_handlers(self) -> None:
        """Register handlers for different input data types."""
        self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
            "text": self.handle_text,
            "pdf": self.handle_pdf,
            "csv": self.handle_csv,
            "api": self.handle_api,
            "db": self.handle_db,
        }

    def _initialize_session_state(self) -> None:
        """Initialize Streamlit session state with default configuration."""
        defaults: Dict[str, Any] = {
            "config": {
                "provider": "OpenAI",
                "model": "gpt-4-turbo",
                "temperature": DEFAULT_TEMPERATURE,
                "num_examples": 3,  # Default number of Q&A pairs
            },
            "api_key": "",
            "inputs": [],      # List to store input sources
            "qa_pairs": None,  # Generated Q&A pairs output
            "error_logs": [],  # To store any error messages
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def log_error(self, message: str) -> None:
        """Log an error message to session state and display it."""
        st.session_state.error_logs.append(message)
        st.error(message)

    # ----- Input Handlers -----
    def handle_text(self, text: str) -> Dict[str, Any]:
        """Process plain text input."""
        return {"data": text, "source": "text"}

    def handle_pdf(self, file) -> Dict[str, Any]:
        """Extract text from a PDF file.

        Returns an empty-data record (and logs the error) on failure so the
        caller never has to special-case a broken upload.
        """
        try:
            with pdfplumber.open(file) as pdf:
                # extract_text() may return None for image-only pages.
                full_text = "\n".join(page.extract_text() or "" for page in pdf.pages)
            return {"data": full_text, "source": "pdf"}
        except Exception as e:
            self.log_error(f"PDF Processing Error: {e}")
            return {"data": "", "source": "pdf"}

    def handle_csv(self, file) -> Dict[str, Any]:
        """Process a CSV file by converting it to JSON."""
        try:
            df = pd.read_csv(file)
            json_data = df.to_json(orient="records")
            return {"data": json_data, "source": "csv"}
        except Exception as e:
            self.log_error(f"CSV Processing Error: {e}")
            return {"data": "", "source": "csv"}

    def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
        """Fetch data from an API endpoint.

        ``config`` must contain ``"url"`` and may contain ``"headers"``.
        """
        try:
            response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
            response.raise_for_status()
            return {"data": json.dumps(response.json()), "source": "api"}
        except Exception as e:
            self.log_error(f"API Processing Error: {e}")
            return {"data": "", "source": "api"}

    def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
        """Query a database using the provided connection string and SQL query.

        ``config`` must contain ``"connection"`` (SQLAlchemy URL) and ``"query"``.
        """
        try:
            engine = sqlalchemy.create_engine(config["connection"])
            with engine.connect() as conn:
                result = conn.execute(sqlalchemy.text(config["query"]))
                # SQLAlchemy 1.4+/2.0 Row objects are not directly dict()-able;
                # use Row._mapping, falling back to the row itself for legacy
                # versions where dict(row) still works.
                rows = [dict(getattr(row, "_mapping", row)) for row in result]
            return {"data": json.dumps(rows), "source": "db"}
        except Exception as e:
            self.log_error(f"Database Processing Error: {e}")
            return {"data": "", "source": "db"}

    def aggregate_inputs(self) -> str:
        """Combine all input sources into a single aggregated string."""
        parts: List[str] = []
        for item in st.session_state.inputs:
            parts.append(f"Source: {item.get('source', 'unknown')}\n" + item.get("data", "") + "\n\n")
        return "".join(parts).strip()

    def build_prompt(self) -> str:
        """
        Build the complete prompt using the custom template, aggregated
        inputs, and the number of examples.
        """
        data = self.aggregate_inputs()
        num_examples = st.session_state.config.get("num_examples", 3)
        prompt = self.custom_prompt_template.format(data=data, num_examples=num_examples)
        st.write("### Built Prompt")
        st.write(prompt)
        return prompt

    def generate_qa_pairs(self) -> bool:
        """
        Generate Q&A pairs by sending the built prompt to the selected LLM
        provider.  Returns True on success, False on any failure (which is
        also logged via :meth:`log_error`).
        """
        api_key: str = st.session_state.api_key
        if not api_key:
            self.log_error("API key is missing!")
            return False

        provider_name: str = st.session_state.config["provider"]
        provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
        if not provider_cfg:
            self.log_error(f"Provider {provider_name} is not configured.")
            return False

        client_initializer: Callable[[str], Any] = provider_cfg["client"]
        client = client_initializer(api_key)
        # The initializer returns None when the provider library failed to
        # import; fail fast with a clear message instead of an AttributeError.
        if client is None:
            self.log_error(f"Client library for {provider_name} is not installed.")
            return False

        model: str = st.session_state.config["model"]
        temperature: float = st.session_state.config["temperature"]
        prompt: str = self.build_prompt()

        st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
        try:
            if provider_name == "HuggingFace":
                response = self._huggingface_inference(client, prompt, model)
            else:
                response = self._standard_inference(client, prompt, model, temperature)
            st.write("### Raw API Response")
            st.write(response)
            qa_pairs = self._parse_response(response, provider_name)
            st.write("### Parsed Q&A Pairs")
            st.write(qa_pairs)
            st.session_state.qa_pairs = qa_pairs
            return True
        except Exception as e:
            self.log_error(f"Generation failed: {e}")
            return False

    def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
        """Inference method for providers using an OpenAI-compatible API.

        Returns the raw completion object, or None on error (logged).
        """
        try:
            st.write("Sending prompt via standard inference...")
            result = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
            )
            st.write("Standard inference result received.")
            return result
        except Exception as e:
            self.log_error(f"Standard Inference Error: {e}")
            return None

    def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
        """Inference method for the Hugging Face Inference API.

        ``client`` is the headers dict built in :meth:`_setup_providers`.
        Returns the decoded JSON response, or None on error (logged).
        """
        try:
            st.write("Sending prompt to HuggingFace API...")
            response = requests.post(
                HF_API_URL + model,
                headers=client["headers"],
                json={"inputs": prompt},
                timeout=30,
            )
            response.raise_for_status()
            st.write("HuggingFace API response received.")
            return response.json()
        except Exception as e:
            self.log_error(f"HuggingFace Inference Error: {e}")
            return None

    def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
        """
        Parse the LLM response and return a list of Q&A pairs.

        Expects the response to be JSON formatted; if JSON decoding fails,
        uses ast.literal_eval as a fallback (models sometimes emit Python
        literals with single quotes).  Returns [] on any parse failure.
        """
        st.write("Parsing response for provider:", provider)
        try:
            if provider == "HuggingFace":
                if isinstance(response, list) and response and "generated_text" in response[0]:
                    raw_text = response[0]["generated_text"]
                else:
                    self.log_error("Unexpected HuggingFace response format.")
                    return []
            else:
                if response and hasattr(response, "choices") and response.choices:
                    raw_text = response.choices[0].message.content
                else:
                    self.log_error("Unexpected response format from provider.")
                    return []

            try:
                qa_list = json.loads(raw_text)
            except json.JSONDecodeError as e:
                self.log_error(
                    f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}"
                )
                try:
                    qa_list = ast.literal_eval(raw_text)
                except Exception as e2:
                    self.log_error(f"ast.literal_eval failed: {e2}")
                    return []

            if isinstance(qa_list, list):
                return qa_list
            self.log_error("Parsed output is not a list.")
            return []
        except Exception as e:
            self.log_error(f"Response Parsing Error: {e}")
            return []


# ============ UI Components ============

def config_ui(generator: QADataGenerator) -> None:
    """Display configuration options in the sidebar."""
    with st.sidebar:
        st.header("Configuration")
        provider = st.selectbox("Select Provider", list(generator.providers.keys()))
        st.session_state.config["provider"] = provider
        provider_cfg = generator.providers[provider]
        model = st.selectbox("Select Model", provider_cfg["models"])
        st.session_state.config["model"] = model
        temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
        st.session_state.config["temperature"] = temperature
        num_examples = st.number_input("Number of Q&A Pairs", min_value=1, max_value=10, value=3, step=1)
        st.session_state.config["num_examples"] = num_examples
        api_key = st.text_input(f"{provider} API Key", type="password")
        st.session_state.api_key = api_key


def input_ui(generator: QADataGenerator) -> None:
    """Display input data source options using tabs.

    Every tab uses an explicit "Add" button: appending directly when a
    file_uploader returns a file would re-append the same file on every
    Streamlit rerun (each widget interaction reruns the script).
    """
    st.subheader("Input Data Sources")
    tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
    with tabs[0]:
        text_input = st.text_area("Enter text input", height=150)
        if st.button("Add Text Input", key="text_input"):
            if text_input.strip():
                st.session_state.inputs.append(generator.handle_text(text_input))
                st.success("Text input added!")
            else:
                st.warning("Empty text input.")
    with tabs[1]:
        pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
        if st.button("Add PDF Input", key="pdf_input"):
            if pdf_file is not None:
                st.session_state.inputs.append(generator.handle_pdf(pdf_file))
                st.success("PDF input added!")
            else:
                st.warning("No PDF file uploaded.")
    with tabs[2]:
        csv_file = st.file_uploader("Upload CSV", type=["csv"])
        if st.button("Add CSV Input", key="csv_input"):
            if csv_file is not None:
                st.session_state.inputs.append(generator.handle_csv(csv_file))
                st.success("CSV input added!")
            else:
                st.warning("No CSV file uploaded.")
    with tabs[3]:
        api_url = st.text_input("API Endpoint URL")
        api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
        if st.button("Add API Input", key="api_input"):
            headers = {}
            try:
                if api_headers:
                    headers = json.loads(api_headers)
            except Exception as e:
                generator.log_error(f"Invalid JSON for API Headers: {e}")
            st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
            st.success("API input added!")
    with tabs[4]:
        db_conn = st.text_input("Database Connection String")
        db_query = st.text_area("Database Query", height=100)
        if st.button("Add Database Input", key="db_input"):
            st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
            st.success("Database input added!")


def output_ui(generator: QADataGenerator) -> None:
    """Display the generated Q&A pairs and provide download options."""
    st.subheader("Q&A Pairs Output")
    if st.session_state.qa_pairs:
        st.write("### Generated Q&A Pairs")
        st.write(st.session_state.qa_pairs)
        # Download as JSON
        st.download_button(
            "Download as JSON",
            json.dumps(st.session_state.qa_pairs, indent=2),
            file_name="qa_pairs.json",
            mime="application/json",
        )
        # Download as CSV
        try:
            df = pd.DataFrame(st.session_state.qa_pairs)
            csv_data = df.to_csv(index=False)
            st.download_button(
                "Download as CSV",
                csv_data,
                file_name="qa_pairs.csv",
                mime="text/csv",
            )
        except Exception as e:
            st.error(f"Error generating CSV: {e}")
    else:
        st.info("No Q&A pairs generated yet.")


def logs_ui() -> None:
    """Display error logs and debugging information in an expandable section."""
    with st.expander("Error Logs & Debug Info", expanded=False):
        if st.session_state.error_logs:
            for log in st.session_state.error_logs:
                st.write(log)
        else:
            st.write("No logs yet.")


def main() -> None:
    """Main Streamlit application entry point."""
    st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
    st.title("Advanced Q&A Synthetic Generator")
    st.markdown(
        """
        Welcome to the Advanced Q&A Synthetic Generator. This tool extracts and generates question-answer pairs
        from various input sources. Configure your provider in the sidebar, add input data, and click the button
        below to generate Q&A pairs.
        """
    )
    # Initialize generator and display configuration UI
    generator = QADataGenerator()
    config_ui(generator)

    st.header("1. Input Data")
    input_ui(generator)
    if st.button("Clear All Inputs"):
        st.session_state.inputs = []
        st.success("All inputs have been cleared!")

    st.header("2. Generate Q&A Pairs")
    if st.button("Generate Q&A Pairs", key="generate_qa"):
        with st.spinner("Generating Q&A pairs..."):
            if generator.generate_qa_pairs():
                st.success("Q&A pairs generated successfully!")
            else:
                st.error("Q&A generation failed. Check logs for details.")

    st.header("3. Output")
    output_ui(generator)

    st.header("4. Logs & Debug Information")
    logs_ui()


if __name__ == "__main__":
    main()