import os
import re
from typing import Optional
from urllib.parse import urlparse

from starfish.data_ingest.formatter.template_format import PromptFormatter, QAGenerationPrompt
from starfish.data_ingest.splitter.base_splitter import TextSplitter

# Import parsers from parsers folder
from starfish.data_ingest.parsers import (
    BaseParser,
    PDFParser,
    HTMLDocumentParser,
    YouTubeParser,
    WordDocumentParser,
    PPTParser,
    TXTParser,
    ExcelParser,
    GoogleDriveParser,
)

PARSER_MAPPING = {
    # URL patterns
    "youtube.com": YouTubeParser,
    "youtu.be": YouTubeParser,
    # File extensions
    ".pdf": PDFParser,
    ".html": HTMLDocumentParser,
    ".htm": HTMLDocumentParser,
    ".docx": WordDocumentParser,
    ".pptx": PPTParser,
    ".txt": TXTParser,
    ".xlsx": ExcelParser,
}


def determine_parser(file_path: str) -> BaseParser:
    """Determine the appropriate parser for a file or URL.

    Args:
        file_path: Path to the file or URL to parse

    Returns:
        Appropriate parser instance

    Raises:
        ValueError: If the file extension is not supported
        FileNotFoundError: If the file does not exist
    """
    # URL: pick a parser by matching known patterns, otherwise treat it as HTML
    if file_path.startswith(("http://", "https://")):
        for pattern, parser in PARSER_MAPPING.items():
            if pattern in file_path:
                return parser()
        return HTMLDocumentParser()  # Default for other URLs

    # Local file: determine the parser by extension
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    ext = os.path.splitext(file_path)[1].lower()
    if ext in PARSER_MAPPING:
        return PARSER_MAPPING[ext]()

    raise ValueError(f"Unsupported file extension: {ext}")


def _generate_output_name(file_path: str) -> str:
    """Generate an output filename based on the input file or URL.

    Args:
        file_path: Path to the file or URL

    Returns:
        Generated filename with .txt extension
    """
    if file_path.startswith(("http://", "https://")):
        if "youtube.com" in file_path or "youtu.be" in file_path:
            # Extract the video id; fall back to the domain-based name if the URL
            # doesn't match the expected watch?v=... or youtu.be/... patterns.
            match = re.search(r"(?:v=|\.be/)([^&?/]+)", file_path)
            if match:
                return f"youtube_{match.group(1)}.txt"
        domain = urlparse(file_path).netloc.replace(".", "_")
        return f"{domain}.txt"

    base_name = os.path.basename(file_path)
    return os.path.splitext(base_name)[0] + ".txt"


def process_file(
    file_path: str,
    output_dir: Optional[str] = None,
    output_name: Optional[str] = None,
) -> str:
    """Process a file using the appropriate parser.

    Args:
        file_path: Path to the file or URL to parse
        output_dir: Directory to save the parsed text
        output_name: Custom filename for the output

    Returns:
        Path to the output file

    Raises:
        ValueError: If output_dir is not provided
    """
    if not output_dir:
        raise ValueError("Output directory must be specified")

    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Determine and use the parser
    parser = determine_parser(file_path)
    content = parser.parse(file_path)

    # Generate the output filename
    output_name = output_name or _generate_output_name(file_path)
    if not output_name.endswith(".txt"):
        output_name += ".txt"

    # Save the content
    output_path = os.path.join(output_dir, output_name)
    parser.save(content, output_path)

    return output_path


def generate_input_data(
    document_text: str,
    splitter: TextSplitter,
    prompt_formatter: PromptFormatter,  # Accepts any PromptFormatter implementation
    num_pairs: int = 5,  # Only used by QA-specific formatters
) -> list:
    """Generate input data from document text using a given PromptFormatter.

    Args:
        document_text: The text to split and process.
        splitter: The text splitter used to divide the text into chunks.
        prompt_formatter: An instance of a PromptFormatter subclass.
        num_pairs: The total number of QA pairs to generate (used by QA-specific formatters).

    Returns:
        A list of formatted prompts, one per chunk.
    """
    chunks = splitter.split_text(document_text)
    all_messages = []

    # For QAGenerationPrompt, spread the requested pairs evenly across the chunks
    if chunks and isinstance(prompt_formatter, QAGenerationPrompt):
        pairs_per_chunk = max(1, round(num_pairs / len(chunks)))
        prompt_formatter.num_pairs = pairs_per_chunk

    print(f"Processing {len(chunks)} chunks to generate prompts...")

    for chunk in chunks:
        # Point the formatter at the current chunk and format the prompt
        prompt_formatter.text = chunk
        prompt = prompt_formatter.format()
        all_messages.append(prompt)

    return all_messages
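

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This module does not define concrete
# TextSplitter or PromptFormatter subclasses, so the _Demo* classes below are
# hypothetical stand-ins that implement only the interface used above
# (split_text(), .text, .format()); the real classes in starfish.data_ingest
# and their constructor arguments may differ, and the input path is a
# placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DemoSplitter(TextSplitter):
        """Stand-in splitter: breaks text on blank lines."""

        def __init__(self) -> None:
            # Skip TextSplitter.__init__ on purpose; its signature is not shown here.
            pass

        def split_text(self, text: str) -> list:
            return [part for part in text.split("\n\n") if part.strip()]

    class _DemoPrompt(PromptFormatter):
        """Stand-in formatter exposing the .text / .format() interface
        that generate_input_data relies on (assumes PromptFormatter has
        no other required overrides)."""

        def __init__(self) -> None:
            self.text = ""

        def format(self) -> str:
            return f"Write question-answer pairs for the following text:\n\n{self.text}"

    # 1. Parse a document (or URL) into plain text on disk.
    txt_path = process_file("docs/example.pdf", output_dir="parsed_output")

    # 2. Read the parsed text back in.
    with open(txt_path, encoding="utf-8") as f:
        document_text = f.read()

    # 3. Split the text and build one prompt per chunk.
    prompts = generate_input_data(document_text, _DemoSplitter(), _DemoPrompt())
    print(f"Generated {len(prompts)} prompts")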