# Utilities for parsing files and URLs into plain text and for building
# prompt inputs for downstream QA data generation.
import os
import re
from typing import Optional
from urllib.parse import urlparse
from starfish.data_ingest.formatter.template_format import PromptFormatter, QAGenerationPrompt
from starfish.data_ingest.splitter.base_splitter import TextSplitter
# Import parsers from parsers folder
from starfish.data_ingest.parsers import (
BaseParser,
PDFParser,
HTMLDocumentParser,
YouTubeParser,
WordDocumentParser,
PPTParser,
TXTParser,
ExcelParser,
GoogleDriveParser,
)
# Maps URL domain substrings and lowercase file extensions to parser classes.
# NOTE(review): determine_parser() iterates this dict in insertion order when
# matching URLs, so keep the URL-pattern entries before the extension entries.
PARSER_MAPPING = {
    # URL patterns (matched against the URL)
    "youtube.com": YouTubeParser,
    "youtu.be": YouTubeParser,
    # File extensions (matched against os.path.splitext(...)[1].lower())
    ".pdf": PDFParser,
    ".html": HTMLDocumentParser,
    ".htm": HTMLDocumentParser,
    ".docx": WordDocumentParser,
    ".pptx": PPTParser,
    ".txt": TXTParser,
    ".xlsx": ExcelParser,
}
def determine_parser(file_path: str) -> BaseParser:
    """Determine the appropriate parser for a file or URL.

    Args:
        file_path: Path to the file or URL to parse

    Returns:
        Appropriate parser instance

    Raises:
        ValueError: If file extension is not supported
        FileNotFoundError: If file does not exist
    """
    # Check if it's a URL.
    if file_path.startswith(("http://", "https://")):
        # Match domain patterns against the hostname and extension patterns
        # against the URL path only. A raw `pattern in file_path` check would
        # let keys like ".txt" or ".htm" match anywhere in the URL (domain,
        # query string) and "youtube.com" match e.g. "notyoutube.com.evil".
        parsed = urlparse(file_path)
        hostname = parsed.netloc.lower()
        url_ext = os.path.splitext(parsed.path)[1].lower()
        for pattern, parser_cls in PARSER_MAPPING.items():
            if pattern.startswith("."):
                if pattern == url_ext:
                    return parser_cls()
            elif pattern in hostname:
                return parser_cls()
        return HTMLDocumentParser()  # Default for other URLs

    # File path - determine by extension
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    ext = os.path.splitext(file_path)[1].lower()
    if ext in PARSER_MAPPING:
        return PARSER_MAPPING[ext]()
    raise ValueError(f"Unsupported file extension: {ext}")
def _generate_output_name(file_path: str) -> str:
"""Generate output filename based on input file or URL.
Args:
file_path: Path to the file or URL
Returns:
Generated filename with .txt extension
"""
if file_path.startswith(("http://", "https://")):
if "youtube.com" in file_path or "youtu.be" in file_path:
video_id = re.search(r"(?:v=|\.be/)([^&]+)", file_path).group(1)
return f"youtube_{video_id}.txt"
domain = urlparse(file_path).netloc.replace(".", "_")
return f"{domain}.txt"
base_name = os.path.basename(file_path)
return os.path.splitext(base_name)[0] + ".txt"
def process_file(
    file_path: str,
    output_dir: Optional[str] = None,
    output_name: Optional[str] = None,
) -> str:
    """Parse a file or URL with the appropriate parser and save the text.

    Args:
        file_path: Path to the file or URL to parse
        output_dir: Directory to save parsed text
        output_name: Custom filename for output

    Returns:
        Path to the output file

    Raises:
        ValueError: If output_dir is not provided
    """
    # The output directory is mandatory; fail fast before doing any work.
    if not output_dir:
        raise ValueError("Output directory must be specified")
    os.makedirs(output_dir, exist_ok=True)

    # Pick a parser based on the input, then extract its text content.
    selected_parser = determine_parser(file_path)
    parsed_text = selected_parser.parse(file_path)

    # Decide on the output filename, normalizing the extension to .txt.
    target_name = output_name if output_name else _generate_output_name(file_path)
    if not target_name.endswith(".txt"):
        target_name = f"{target_name}.txt"

    # Persist the extracted text and report where it landed.
    destination = os.path.join(output_dir, target_name)
    selected_parser.save(parsed_text, destination)
    return destination
def generate_input_data(
    document_text: str,
    splitter: TextSplitter,
    prompt_formatter: PromptFormatter,  # Accept any PromptFormatter implementation
    num_pairs: int = 5,  # Optional parameter for QA-specific formatters
) -> list:
    """Generate input data from document text using a given PromptFormatter.

    Args:
        document_text: The text to split and process.
        splitter: The text splitter to use for dividing the text into chunks.
        prompt_formatter: An instance of a PromptFormatter subclass.
        num_pairs: The number of QA pairs to generate (used for QA-specific formatters).

    Returns:
        A list of formatted prompts (empty if the text yields no chunks).
    """
    chunks = splitter.split_text(document_text)
    # Guard: an empty document produces no chunks; return early instead of
    # dividing by zero when computing pairs_per_chunk below.
    if not chunks:
        return []
    # Announce the work before doing it (the original printed after the loop).
    print(f"Processing {len(chunks)} chunks to generate prompts...")

    # If the formatter is QAGenerationPrompt, spread the requested pair count
    # across the chunks, generating at least one pair per chunk.
    if isinstance(prompt_formatter, QAGenerationPrompt):
        prompt_formatter.num_pairs = max(1, round(num_pairs / len(chunks)))

    all_messages = []
    for chunk in chunks:
        # Point the (mutated, shared) formatter at the current chunk.
        prompt_formatter.text = chunk
        all_messages.append(prompt_formatter.format())
    return all_messages
|