File size: 4,829 Bytes
5301c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import re
from typing import Optional
from urllib.parse import urlparse
from starfish.data_ingest.formatter.template_format import PromptFormatter, QAGenerationPrompt
from starfish.data_ingest.splitter.base_splitter import TextSplitter


# Import parsers from parsers folder
from starfish.data_ingest.parsers import (
    BaseParser,
    PDFParser,
    HTMLDocumentParser,
    YouTubeParser,
    WordDocumentParser,
    PPTParser,
    TXTParser,
    ExcelParser,
    GoogleDriveParser,
)


# Dispatch table used by determine_parser(): URL substrings (for http/https
# inputs) and lowercase file extensions (for local paths) map to parser classes.
# For URLs the table is scanned in insertion order with substring matching, so
# more specific patterns must precede shorter prefixes (".html" before ".htm").
PARSER_MAPPING = {
    # URL patterns
    "youtube.com": YouTubeParser,
    "youtu.be": YouTubeParser,
    # File extensions
    ".pdf": PDFParser,
    ".html": HTMLDocumentParser,
    ".htm": HTMLDocumentParser,
    ".docx": WordDocumentParser,
    ".pptx": PPTParser,
    ".txt": TXTParser,
    ".xlsx": ExcelParser,
}


def determine_parser(file_path: str) -> BaseParser:
    """Determine the appropriate parser for a file or URL.

    Args:
        file_path: Path to the file or URL to parse

    Returns:
        Appropriate parser instance

    Raises:
        ValueError: If file extension is not supported
        FileNotFoundError: If file does not exist
    """
    # Check if it's a URL
    if file_path.startswith(("http://", "https://")):
        # Match against the lowercased URL so uppercase variants (e.g. a URL
        # ending in ".PDF") resolve the same way local paths do below.
        lowered = file_path.lower()
        for pattern, parser_cls in PARSER_MAPPING.items():
            if pattern in lowered:
                return parser_cls()
        return HTMLDocumentParser()  # Default for other URLs

    # File path - determine by extension
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    ext = os.path.splitext(file_path)[1].lower()
    if ext in PARSER_MAPPING:
        return PARSER_MAPPING[ext]()

    raise ValueError(f"Unsupported file extension: {ext}")


def _generate_output_name(file_path: str) -> str:
    """Generate output filename based on input file or URL.

    Args:
        file_path: Path to the file or URL

    Returns:
        Generated filename with .txt extension
    """
    if file_path.startswith(("http://", "https://")):
        if "youtube.com" in file_path or "youtu.be" in file_path:
            video_id = re.search(r"(?:v=|\.be/)([^&]+)", file_path).group(1)
            return f"youtube_{video_id}.txt"
        domain = urlparse(file_path).netloc.replace(".", "_")
        return f"{domain}.txt"

    base_name = os.path.basename(file_path)
    return os.path.splitext(base_name)[0] + ".txt"


def process_file(
    file_path: str,
    output_dir: Optional[str] = None,
    output_name: Optional[str] = None,
) -> str:
    """Process a file using the appropriate parser.

    Args:
        file_path: Path to the file or URL to parse
        output_dir: Directory to save parsed text
        output_name: Custom filename for output

    Returns:
        Path to the output file

    Raises:
        ValueError: If output_dir is not provided
    """
    if not output_dir:
        raise ValueError("Output directory must be specified")

    # Ensure the destination directory exists before writing anything.
    os.makedirs(output_dir, exist_ok=True)

    # Pick a parser based on the input kind and extract its text content.
    parser = determine_parser(file_path)
    content = parser.parse(file_path)

    # Fall back to a derived name when the caller didn't supply one, and
    # normalize the extension to .txt.
    name = output_name if output_name else _generate_output_name(file_path)
    if not name.endswith(".txt"):
        name = f"{name}.txt"

    destination = os.path.join(output_dir, name)
    parser.save(content, destination)
    return destination


def generate_input_data(
    document_text: str,
    splitter: TextSplitter,
    prompt_formatter: PromptFormatter,  # Accept any PromptFormatter implementation
    num_pairs: int = 5,  # Optional parameter for QA-specific formatters
) -> list:
    """Generate input data from document text using a given PromptFormatter.

    Args:
        document_text: The text to split and process.
        splitter: The text splitter to use for dividing the text into chunks.
        prompt_formatter: An instance of a PromptFormatter subclass.
        num_pairs: The number of QA pairs to generate (used for QA-specific formatters).

    Returns:
        A list of formatted prompts (empty if the splitter produced no chunks).
    """
    chunks = splitter.split_text(document_text)

    # Guard against an empty document: without this, the pairs_per_chunk
    # division below raises ZeroDivisionError for QA formatters.
    if not chunks:
        return []

    all_messages = []

    # If the formatter is QAGenerationPrompt, spread the requested number of
    # pairs across the chunks (at least one pair per chunk).
    if isinstance(prompt_formatter, QAGenerationPrompt):
        pairs_per_chunk = max(1, round(num_pairs / len(chunks)))
        prompt_formatter.num_pairs = pairs_per_chunk

    for chunk in chunks:
        # Update the text for the current chunk
        prompt_formatter.text = chunk
        # Format the prompt using the provided formatter
        prompt = prompt_formatter.format()
        all_messages.append(prompt)

    print(f"Processing {len(chunks)} chunks to generate prompts...")
    return all_messages