import json
import ast
import requests
import streamlit as st
import pdfplumber
import pandas as pd
import sqlalchemy
from typing import Any, Dict, List, Callable

# Provider clients – ensure these libraries are installed
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None

try:
    import groq
except ImportError:
    groq = None

# Hugging Face inference endpoint and defaults
HF_API_URL: str = "https://api-inference.huggingface.co/models/"
DEFAULT_TEMPERATURE: float = 0.1
GROQ_MODEL: str = "mixtral-8x7b-32768"

class QADataGenerator:
    """
    A synthetic Q&A data generator that extracts and generates question-answer
    pairs from various input sources using an LLM provider.
    """

    def __init__(self) -> None:
        self._setup_providers()
        self._setup_input_handlers()
        self._initialize_session_state()
        # Prompt template with a dynamic {num_examples} parameter; literal
        # braces in the JSON example are escaped as {{ and }}.
        self.custom_prompt_template: str = (
            "You are an expert in extracting question and answer pairs from documents. "
            "Generate {num_examples} Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
            "Each dictionary must have the keys 'question' and 'answer'. "
            "The questions should be clear and concise, and the answers must be based solely on the provided data, with no external information. "
            "Do not hallucinate.\n\n"
            "Example JSON output:\n"
            '[{{"question": "What is the capital of France?", "answer": "Paris"}}, '
            '{{"question": "What is the highest mountain in the world?", "answer": "Mount Everest"}}, '
            '{{"question": "What is the chemical symbol for gold?", "answer": "Au"}}]\n\n'
            "Now, generate {num_examples} Q&A pairs from this data:\n{data}"
        )
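
    # A minimal illustration of the escaping (hypothetical template, not part
    # of the app's flow): str.format renders "{{" / "}}" as literal braces, so
    # only {num_examples} and {data} are substituted, e.g.
    #   '[{{"q": "..."}}] from {data}'.format(data="X")  ->  '[{"q": "..."}] from X'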

    def _setup_providers(self) -> None:
        """Configure available LLM providers and their client initialization routines."""
        self.providers: Dict[str, Dict[str, Any]] = {
            "Deepseek": {
                "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key) if OpenAI else None,
                "models": ["deepseek-chat"],
            },
            "OpenAI": {
                "client": lambda key: OpenAI(api_key=key) if OpenAI else None,
                "models": ["gpt-4-turbo", "gpt-3.5-turbo"],
            },
            "Groq": {
                "client": lambda key: groq.Groq(api_key=key) if groq else None,
                "models": [GROQ_MODEL],
            },
            "HuggingFace": {
                "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
                "models": ["gpt2", "llama-2"],
            },
        }
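
    # Illustrative usage (hypothetical key): each "client" entry is a factory
    # that takes an API key, e.g.
    #   client = self.providers["Groq"]["client"]("gsk_...")
    # and returns a ready client, or None when the library is not installed.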

    def _setup_input_handlers(self) -> None:
        """Register handlers for the supported input data types."""
        self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
            "text": self.handle_text,
            "pdf": self.handle_pdf,
            "csv": self.handle_csv,
            "api": self.handle_api,
            "db": self.handle_db,
        }

    def _initialize_session_state(self) -> None:
        """Initialize Streamlit session state with default configuration."""
        defaults: Dict[str, Any] = {
            "config": {
                "provider": "OpenAI",
                "model": "gpt-4-turbo",
                "temperature": DEFAULT_TEMPERATURE,
                "num_examples": 3,  # Default number of Q&A pairs
            },
            "api_key": "",
            "inputs": [],  # List of input sources
            "qa_pairs": None,  # Generated Q&A pairs
            "error_logs": [],  # Collected error messages
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def log_error(self, message: str) -> None:
        """Log an error message to session state and display it."""
        st.session_state.error_logs.append(message)
        st.error(message)

    # ----- Input handlers -----

    def handle_text(self, text: str) -> Dict[str, Any]:
        """Process plain text input."""
        return {"data": text, "source": "text"}

    def handle_pdf(self, file) -> Dict[str, Any]:
        """Extract text from a PDF file."""
        try:
            with pdfplumber.open(file) as pdf:
                full_text = "\n".join(page.extract_text() or "" for page in pdf.pages)
            return {"data": full_text, "source": "pdf"}
        except Exception as e:
            self.log_error(f"PDF Processing Error: {e}")
            return {"data": "", "source": "pdf"}

    def handle_csv(self, file) -> Dict[str, Any]:
        """Process a CSV file by converting it to JSON."""
        try:
            df = pd.read_csv(file)
            json_data = df.to_json(orient="records")
            return {"data": json_data, "source": "csv"}
        except Exception as e:
            self.log_error(f"CSV Processing Error: {e}")
            return {"data": "", "source": "csv"}

    def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
        """Fetch data from an API endpoint."""
        try:
            response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
            response.raise_for_status()
            return {"data": json.dumps(response.json()), "source": "api"}
        except Exception as e:
            self.log_error(f"API Processing Error: {e}")
            return {"data": "", "source": "api"}

    def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
        """Query a database using the provided connection string and SQL query."""
        try:
            engine = sqlalchemy.create_engine(config["connection"])
            with engine.connect() as conn:
                result = conn.execute(sqlalchemy.text(config["query"]))
                # .mappings() yields dict-like rows; in SQLAlchemy 1.4+ Row
                # objects are no longer directly convertible via dict(row).
                rows = [dict(row) for row in result.mappings()]
            return {"data": json.dumps(rows), "source": "db"}
        except Exception as e:
            self.log_error(f"Database Processing Error: {e}")
            return {"data": "", "source": "db"}

    def aggregate_inputs(self) -> str:
        """Combine all input sources into a single aggregated string."""
        aggregated_data = ""
        for item in st.session_state.inputs:
            aggregated_data += f"Source: {item.get('source', 'unknown')}\n"
            aggregated_data += item.get("data", "") + "\n\n"
        return aggregated_data.strip()
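
    # With a text and a CSV input queued, the aggregate looks like (illustrative):
    #   Source: text
    #   <the raw text>
    #
    #   Source: csv
    #   [{"a":1},{"a":2}]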

    def build_prompt(self) -> str:
        """
        Build the complete prompt from the custom template, the aggregated
        inputs, and the requested number of examples.
        """
        data = self.aggregate_inputs()
        num_examples = st.session_state.config.get("num_examples", 3)
        prompt = self.custom_prompt_template.format(data=data, num_examples=num_examples)
        st.write("### Built Prompt")
        st.write(prompt)
        return prompt

    def generate_qa_pairs(self) -> bool:
        """Generate Q&A pairs by sending the built prompt to the selected LLM provider."""
        api_key: str = st.session_state.api_key
        if not api_key:
            self.log_error("API key is missing!")
            return False

        provider_name: str = st.session_state.config["provider"]
        provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
        if not provider_cfg:
            self.log_error(f"Provider {provider_name} is not configured.")
            return False

        client_initializer: Callable[[str], Any] = provider_cfg["client"]
        client = client_initializer(api_key)
        # The factory returns None when the provider's client library is missing.
        if client is None:
            self.log_error(f"Client library for {provider_name} is not installed.")
            return False

        model: str = st.session_state.config["model"]
        temperature: float = st.session_state.config["temperature"]
        prompt: str = self.build_prompt()
        st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
        try:
            if provider_name == "HuggingFace":
                response = self._huggingface_inference(client, prompt, model)
            else:
                response = self._standard_inference(client, prompt, model, temperature)
            st.write("### Raw API Response")
            st.write(response)
            qa_pairs = self._parse_response(response, provider_name)
            st.write("### Parsed Q&A Pairs")
            st.write(qa_pairs)
            st.session_state.qa_pairs = qa_pairs
            return True
        except Exception as e:
            self.log_error(f"Generation failed: {e}")
            return False

    def _standard_inference(self, client: Any, prompt: str, model: str, temperature: float) -> Any:
        """Inference method for providers using an OpenAI-compatible API."""
        try:
            st.write("Sending prompt via standard inference...")
            result = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
            )
            st.write("Standard inference result received.")
            return result
        except Exception as e:
            self.log_error(f"Standard Inference Error: {e}")
            return None

    def _huggingface_inference(self, client: Dict[str, Any], prompt: str, model: str) -> Any:
        """Inference method for the Hugging Face Inference API."""
        try:
            st.write("Sending prompt to HuggingFace API...")
            response = requests.post(
                HF_API_URL + model,
                headers=client["headers"],
                json={"inputs": prompt},
                timeout=30,
            )
            response.raise_for_status()
            st.write("HuggingFace API response received.")
            return response.json()
        except Exception as e:
            self.log_error(f"HuggingFace Inference Error: {e}")
            return None
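
    # The Hugging Face text-generation endpoint typically responds with a list
    # such as [{"generated_text": "..."}], which is the shape _parse_response
    # checks for below.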

    def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
        """
        Parse the LLM response and return a list of Q&A pairs.
        Expects the response to be JSON-formatted; if JSON decoding fails,
        falls back to ast.literal_eval.
        """
        st.write("Parsing response for provider:", provider)
        try:
            if provider == "HuggingFace":
                if isinstance(response, list) and response and "generated_text" in response[0]:
                    raw_text = response[0]["generated_text"]
                else:
                    self.log_error("Unexpected HuggingFace response format.")
                    return []
            else:
                if response and hasattr(response, "choices") and response.choices:
                    raw_text = response.choices[0].message.content
                else:
                    self.log_error("Unexpected response format from provider.")
                    return []
            try:
                qa_list = json.loads(raw_text)
            except json.JSONDecodeError as e:
                self.log_error(
                    f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}"
                )
                try:
                    qa_list = ast.literal_eval(raw_text)
                except Exception as e2:
                    self.log_error(f"ast.literal_eval failed: {e2}")
                    return []
            if isinstance(qa_list, list):
                return qa_list
            self.log_error("Parsed output is not a list.")
            return []
        except Exception as e:
            self.log_error(f"Response Parsing Error: {e}")
            return []
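
# Parsing note (illustrative): models sometimes emit Python-style single quotes
# instead of JSON, e.g. "[{'question': 'Q', 'answer': 'A'}]". json.loads raises
# JSONDecodeError on that string, while ast.literal_eval parses it safely, which
# is why _parse_response falls back to ast.literal_eval.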

# ============ UI Components ============

def config_ui(generator: QADataGenerator) -> None:
    """Display configuration options in the sidebar."""
    with st.sidebar:
        st.header("Configuration")
        provider = st.selectbox("Select Provider", list(generator.providers.keys()))
        st.session_state.config["provider"] = provider

        provider_cfg = generator.providers[provider]
        model = st.selectbox("Select Model", provider_cfg["models"])
        st.session_state.config["model"] = model

        temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
        st.session_state.config["temperature"] = temperature

        num_examples = st.number_input("Number of Q&A Pairs", min_value=1, max_value=10, value=3, step=1)
        st.session_state.config["num_examples"] = num_examples

        api_key = st.text_input(f"{provider} API Key", type="password")
        st.session_state.api_key = api_key

def input_ui(generator: QADataGenerator) -> None:
    """Display input data source options using tabs."""
    st.subheader("Input Data Sources")
    tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
    with tabs[0]:
        text_input = st.text_area("Enter text input", height=150)
        if st.button("Add Text Input", key="text_input"):
            if text_input.strip():
                st.session_state.inputs.append(generator.handle_text(text_input))
                st.success("Text input added!")
            else:
                st.warning("Empty text input.")
    with tabs[1]:
        pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
        # Gate on a button so Streamlit reruns don't re-append the same file.
        if pdf_file is not None and st.button("Add PDF Input", key="pdf_input"):
            st.session_state.inputs.append(generator.handle_pdf(pdf_file))
            st.success("PDF input added!")
    with tabs[2]:
        csv_file = st.file_uploader("Upload CSV", type=["csv"])
        if csv_file is not None and st.button("Add CSV Input", key="csv_input"):
            st.session_state.inputs.append(generator.handle_csv(csv_file))
            st.success("CSV input added!")
    with tabs[3]:
        api_url = st.text_input("API Endpoint URL")
        api_headers = st.text_area("API Headers (JSON format, optional)", height=100)
        if st.button("Add API Input", key="api_input"):
            headers = {}
            try:
                if api_headers:
                    headers = json.loads(api_headers)
            except Exception as e:
                generator.log_error(f"Invalid JSON for API Headers: {e}")
            st.session_state.inputs.append(generator.handle_api({"url": api_url, "headers": headers}))
            st.success("API input added!")
    with tabs[4]:
        db_conn = st.text_input("Database Connection String")
        db_query = st.text_area("Database Query", height=100)
        if st.button("Add Database Input", key="db_input"):
            st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
            st.success("Database input added!")

def output_ui(generator: QADataGenerator) -> None:
    """Display the generated Q&A pairs and provide download options."""
    st.subheader("Q&A Pairs Output")
    if st.session_state.qa_pairs:
        st.write("### Generated Q&A Pairs")
        st.write(st.session_state.qa_pairs)
        # Download as JSON
        st.download_button(
            "Download as JSON",
            json.dumps(st.session_state.qa_pairs, indent=2),
            file_name="qa_pairs.json",
            mime="application/json",
        )
        # Download as CSV
        try:
            df = pd.DataFrame(st.session_state.qa_pairs)
            csv_data = df.to_csv(index=False)
            st.download_button(
                "Download as CSV",
                csv_data,
                file_name="qa_pairs.csv",
                mime="text/csv",
            )
        except Exception as e:
            st.error(f"Error generating CSV: {e}")
    else:
        st.info("No Q&A pairs generated yet.")

def logs_ui() -> None:
    """Display error logs and debugging information in an expandable section."""
    with st.expander("Error Logs & Debug Info", expanded=False):
        if st.session_state.error_logs:
            for log in st.session_state.error_logs:
                st.write(log)
        else:
            st.write("No logs yet.")

def main() -> None:
    """Main Streamlit application entry point."""
    st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
    st.title("Advanced Q&A Synthetic Generator")
    st.markdown(
        """
        Welcome to the Advanced Q&A Synthetic Generator. This tool extracts and generates question-answer pairs
        from various input sources. Configure your provider in the sidebar, add input data, and click the button below to generate Q&A pairs.
        """
    )
    # Initialize the generator and display the configuration UI
    generator = QADataGenerator()
    config_ui(generator)

    st.header("1. Input Data")
    input_ui(generator)
    if st.button("Clear All Inputs"):
        st.session_state.inputs = []
        st.success("All inputs have been cleared!")

    st.header("2. Generate Q&A Pairs")
    if st.button("Generate Q&A Pairs", key="generate_qa"):
        with st.spinner("Generating Q&A pairs..."):
            if generator.generate_qa_pairs():
                st.success("Q&A pairs generated successfully!")
            else:
                st.error("Q&A generation failed. Check logs for details.")

    st.header("3. Output")
    output_ui(generator)

    st.header("4. Logs & Debug Information")
    logs_ui()

if __name__ == "__main__":
    main()