File size: 16,361 Bytes
685adc8
 
aad7490
685adc8
aad7490
 
685adc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f14b334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685adc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07de967
685adc8
1d69af8
685adc8
 
 
 
 
 
 
 
 
 
 
 
 
f14b334
 
 
 
 
 
 
 
685adc8
 
 
 
 
 
 
 
 
 
 
f14b334
 
 
 
685adc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f14b334
 
 
 
 
 
685adc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
import gradio as gr
import pandas as pd
from utils.google_genai_llm import get_response, generate_with_gemini
from prompts.requirements_gathering import requirements_gathering_system_prompt
from prompts.planning import hf_query_gen_prompt

from PIL import Image
import os
import tempfile
import traceback
import hashlib

# Import Marker for document processing
try:
    from marker.converters.pdf import PdfConverter
    from marker.models import create_model_dict
    from marker.output import text_from_rendered
    MARKER_AVAILABLE = True
except ImportError:
    MARKER_AVAILABLE = False
    print("Warning: Marker library not available. PDF, PPT, and DOCX processing will be limited.")

def get_file_hash(file_path):
    """Return the MD5 hex digest of a file's contents, or None if it can't be read.

    The digest is only used to build cache keys for extracted document text,
    not for any security purpose.

    The file is streamed in 1 MiB chunks so that large uploads are never held
    in memory all at once just to be hashed.

    Args:
        file_path: Path of the file to hash.

    Returns:
        The 32-character hexadecimal MD5 digest, or None when the path is
        missing/unreadable or not a valid path (best-effort: callers embed
        the result in a cache key and carry on).
    """
    hasher = hashlib.md5()
    try:
        with open(file_path, 'rb') as f:
            # iter() with a b'' sentinel stops at EOF.
            for chunk in iter(lambda: f.read(1024 * 1024), b''):
                hasher.update(chunk)
    except (OSError, TypeError, ValueError):
        # OSError: missing/unreadable file; TypeError/ValueError: path is
        # None or otherwise not a usable filesystem path.
        return None
    return hasher.hexdigest()

# Marker's model artifacts (from create_model_dict()) are costly to build, so
# they are created on first use and reused for every later document instead
# of being rebuilt per file.
_marker_artifact_dict = None


def _get_marker_artifact_dict():
    """Return the shared Marker model artifacts, creating them on the first call."""
    global _marker_artifact_dict
    if _marker_artifact_dict is None:
        _marker_artifact_dict = create_model_dict()
    return _marker_artifact_dict


def extract_text_with_marker(file_path):
    """Extract plain text from a document using Marker's PdfConverter.

    Called for .pdf, .ppt/.pptx and .doc/.docx uploads.
    NOTE(review): only PdfConverter is used here; whether non-PDF formats
    convert depends on the installed Marker providers — confirm.

    Args:
        file_path: Path of the document to convert.

    Returns:
        Tuple ``(stats, text)``. On success ``stats`` reads
        "Extracted text (<words> words, <chars> characters)" and ``text`` is
        the extracted text. If Marker is not installed, or conversion raises,
        ``stats`` is a human-readable message and ``text`` is "".
    """
    if not MARKER_AVAILABLE:
        return "Marker library not available for document processing.", ""

    try:
        # The converter itself is cheap; only the model artifacts are shared.
        converter = PdfConverter(
            artifact_dict=_get_marker_artifact_dict(),
        )

        rendered = converter(file_path)

        # text_from_rendered returns (text, <unused>, images); only text is needed.
        text, _, _ = text_from_rendered(rendered)

        word_count = len(text.split())
        char_count = len(text)

        stats = f"Extracted text ({word_count} words, {char_count} characters)"

        return stats, text

    except Exception as e:
        # Best-effort: report the failure to the prompt, but keep the
        # traceback visible in the server log for debugging.
        print(f"Error processing document {file_path}: {traceback.format_exc()}")
        error_msg = f"Error processing document: {str(e)}"
        return error_msg, ""

# Upper bound on how much extracted document text is inlined into the prompt.
_MAX_DOCUMENT_PREVIEW_CHARS = 200000

# File extensions grouped by how they are summarised for the model.
_EXCEL_EXTENSIONS = ('.xlsx', '.xls')
_DOCUMENT_EXTENSIONS = ('.pdf', '.ppt', '.pptx', '.doc', '.docx')
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp')


def _format_conversation_history(history):
    """Render (user, assistant) pairs as a "User: ... / Assistant: ..." transcript.

    Assistant turns that are empty or None are omitted.
    """
    lines = []
    for user_msg, ai_msg in history or []:
        lines.append(f"User: {user_msg}\n")
        if ai_msg:
            lines.append(f"Assistant: {ai_msg}\n")
    return "".join(lines)


def _describe_tabular(file_name, kind, df):
    """Summarise a loaded DataFrame: row/column counts and column names."""
    # str() each label: Excel headers can be numbers/dates, which join() rejects.
    columns = ', '.join(str(column) for column in df.columns)
    return (
        f"- {file_name}: {kind} file with {len(df)} rows and {len(df.columns)} columns\n"
        f"  Columns: {columns}\n"
    )


def _describe_document(file_path, file_name, file_extension, file_cache):
    """Summarise a PDF/PowerPoint/Word upload, extracting its text via Marker.

    Extraction results are memoised in *file_cache* (mutated in place) under
    "<file name>_<md5>", so re-sending messages with the same upload does not
    re-run the extraction.
    """
    file_size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
    cache_key = f"{file_name}_{get_file_hash(file_path)}"

    if cache_key in file_cache:
        extraction_stats = file_cache[cache_key]['stats']
        extracted_text = file_cache[cache_key]['text']
        status = "(cached)"
    else:
        extraction_stats, extracted_text = extract_text_with_marker(file_path)
        file_cache[cache_key] = {
            'stats': extraction_stats,
            'text': extracted_text,
            'file_name': file_name,
            'file_path': file_path
        }
        status = "(newly processed)"

    if file_extension == '.pdf':
        doc_type = "PDF document"
    elif file_extension in ('.ppt', '.pptx'):
        doc_type = "PowerPoint presentation"
    else:
        doc_type = "Word document"

    info = (
        f"- {file_name}: {doc_type}, Size: {file_size_mb} MB {status}\n"
        f"  Content: {extraction_stats}\n"
    )

    # Include the extracted text so the model can reason about the document,
    # truncated to _MAX_DOCUMENT_PREVIEW_CHARS characters.
    if extracted_text and extracted_text.strip():
        if len(extracted_text) > _MAX_DOCUMENT_PREVIEW_CHARS:
            text_preview = extracted_text[:_MAX_DOCUMENT_PREVIEW_CHARS] + "..."
        else:
            text_preview = extracted_text
        info += f"  Text Preview: {text_preview}\n"
    return info


def _describe_image(file_path, file_name, file_extension):
    """Summarise an image upload: format, pixel dimensions, colour mode and size."""
    with Image.open(file_path) as img:
        width, height = img.size
        mode = img.mode
    file_size_mb = round(os.path.getsize(file_path) / (1024 * 1024), 2)
    return (
        f"- {file_name}: {file_extension.upper()[1:]} image file\n"
        f"  Dimensions: {width}x{height} pixels, Mode: {mode}, Size: {file_size_mb} MB\n"
    )


def _describe_text_file(file_path, file_name):
    """Summarise a .txt upload: line count and size."""
    # errors='replace' so non-UTF-8 files still get a line count instead of
    # falling through to the "unable to preview" path.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        line_count = sum(1 for _ in f)
    file_size_kb = round(os.path.getsize(file_path) / 1024, 2)
    return f"- {file_name}: Text file with {line_count} lines, Size: {file_size_kb} KB\n"


def _describe_uploaded_file(file_path, file_cache):
    """Return the [UPLOADED_FILES] summary line(s) for one uploaded file.

    Raises whatever the underlying reader raises; the caller turns that into
    an "unable to preview" line.
    """
    file_name = os.path.basename(file_path)
    file_extension = os.path.splitext(file_name)[1].lower()

    if file_extension == '.csv':
        return _describe_tabular(file_name, "CSV", pd.read_csv(file_path))
    if file_extension in _EXCEL_EXTENSIONS:
        return _describe_tabular(file_name, "Excel", pd.read_excel(file_path))
    if file_extension in _DOCUMENT_EXTENSIONS:
        return _describe_document(file_path, file_name, file_extension, file_cache)
    if file_extension in _IMAGE_EXTENSIONS:
        return _describe_image(file_path, file_name, file_extension)
    if file_extension == '.txt':
        return _describe_text_file(file_path, file_name)

    # JSON and any other type: report size only.
    file_size_kb = round(os.path.getsize(file_path) / 1024, 2)
    label = "JSON file" if file_extension == '.json' else "File uploaded"
    return f"- {file_name}: {label}, Size: {file_size_kb} KB\n"


def process_user_input(message, history, uploaded_files, file_cache):
    """Build the requirements-gathering prompt for *message* and return the model's reply.

    The prompt combines:
    - a transcript of *history*;
    - a summary of every uploaded file (CSV/Excel shape and columns, document
      text extracted with Marker, image dimensions, text line counts, sizes);
    - the current message.

    Args:
        message: The user's latest message.
        history: List of (user_message, assistant_message) pairs.
        uploaded_files: Iterable of uploaded file paths, or a falsy value.
        file_cache: Dict of previously extracted document results keyed by
            "<file name>_<md5>", or None. Updated in place when it is a dict.

    Returns:
        Tuple ``(ai_response, file_cache)`` where ``file_cache`` includes any
        newly extracted documents.
    """
    conversation_history = _format_conversation_history(history)

    if uploaded_files:
        # Collect new extractions in a copy, then merge back once at the end.
        new_file_cache = dict(file_cache) if file_cache else {}
        file_info_parts = ["\n[UPLOADED_FILES]\n"]

        for file_path in uploaded_files:
            try:
                file_info_parts.append(_describe_uploaded_file(file_path, new_file_cache))
            except Exception as e:
                # One unreadable file must not abort the whole turn.
                file_info_parts.append(
                    f"- {os.path.basename(str(file_path))}: File uploaded (unable to preview: {str(e)})\n"
                )
                print(f"Error processing file {file_path}: {traceback.format_exc()}")

        conversation_history += "".join(file_info_parts)

        if file_cache is None:
            file_cache = new_file_cache
        else:
            file_cache.update(new_file_cache)

    formatted_prompt = requirements_gathering_system_prompt.format(
        conversation_history=conversation_history,
        query=message
    )

    ai_response = get_response(formatted_prompt)

    return ai_response, file_cache

def chat_interface(message, history, uploaded_files, file_cache):
    """Run one chat turn and append it to *history* in place.

    Returns a 4-tuple matching the send handler's outputs:
    - the updated history, for the chatbot display;
    - the same history, for the chat-history state;
    - an empty string, to clear the message box;
    - the updated file cache.
    """
    reply, cache = process_user_input(message, history, uploaded_files, file_cache)
    history.append((message, reply))
    cleared_input = ""
    return history, history, cleared_input, cache

def clear_chat():
    """Reset the conversation.

    Returns fresh empty values, in order, for the chatbot display, the
    chat-history state and the file-cache state.
    """
    chatbot_value, history_value, cache_value = list(), list(), dict()
    return chatbot_value, history_value, cache_value

def upload_file_handler(files):
    """Pass the uploaded file list through, substituting [] when nothing was uploaded."""
    return files if files else []

def generate_plan(history, file_cache):
    """Ask Gemini for a project plan based on the conversation so far.

    The prompt is ``hf_query_gen_prompt``, a blank line, then a
    "User: ... / Assistant: ..." transcript of *history* (empty assistant
    turns are skipped).

    Args:
        history: List of (user_message, assistant_message) pairs.
        file_cache: Accepted to match the event wiring but currently unused,
            so extracted document text is not sent to the planner.

    Returns:
        Whatever ``generate_with_gemini`` returns (shown in the plan textbox).
    """
    transcript_parts = []
    for user_turn, assistant_turn in history or ():
        transcript_parts.append(f"User: {user_turn}\n")
        if assistant_turn:
            transcript_parts.append(f"Assistant: {assistant_turn}\n")
    transcript = "".join(transcript_parts)

    planning_prompt = hf_query_gen_prompt + "\n\n" + transcript
    return generate_with_gemini(planning_prompt, "Planning with gemini")

# Custom CSS passed to gr.Blocks(css=...) below.
# .chat-container, .upload-area, .btn-primary and .btn-secondary are attached
# to components via elem_classes; .title and .subtitle are used in the header
# HTML. .message, .user-message and .bot-message are not referenced by any
# component in this file — NOTE(review): confirm they still match Gradio's
# generated markup, or remove them.
custom_css = """
.gradio-container {
    max-width: 900px !important;
    margin: auto !important;
}

.chat-container {
    height: 600px !important;
}

#component-0 {
    height: 100vh;
}

.message {
    padding: 15px !important;
    margin: 10px 0 !important;
    border-radius: 15px !important;
}

.user-message {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    margin-left: 20% !important;
}

.bot-message {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important;
    color: white !important;
    margin-right: 20% !important;
}

.upload-area {
    border: 2px dashed #4f46e5 !important;
    border-radius: 10px !important;
    padding: 20px !important;
    text-align: center !important;
    background: linear-gradient(135deg, #f0f4ff 0%, #e0e7ff 100%) !important;
}

.btn-primary {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    border-radius: 25px !important;
    padding: 10px 25px !important;
    font-weight: bold !important;
}

.btn-secondary {
    background: linear-gradient(135deg, #ffeaa7 0%, #fab1a0 100%) !important;
    border: none !important;
    border-radius: 25px !important;
    padding: 10px 25px !important;
    font-weight: bold !important;
    color: #2d3436 !important;
}

.title {
    text-align: center !important;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    -webkit-background-clip: text !important;
    -webkit-text-fill-color: transparent !important;
    font-size: 2.5em !important;
    font-weight: bold !important;
    margin-bottom: 20px !important;
}

.subtitle {
    text-align: center !important;
    color: #6c757d !important;
    font-size: 1.2em !important;
    margin-bottom: 30px !important;
}
"""

# Create the Gradio interface.
# Layout: a wide left column (conversation, generated plan, message box and
# action buttons) and a narrow right column (file upload, uploaded-file list,
# usage instructions).
with gr.Blocks(css=custom_css, title="Data Science Requirements Gathering Agent") as app:

    # Header
    gr.HTML("""
        <div class="title">🔬 Data Science Consultant</div>
        <div class="subtitle">
            Transform your vague ideas into reality
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=3):
            # Chat interface
            chatbot = gr.Chatbot(
                label="Requirements Gathering Conversation",
                height=500,
                show_copy_button=True,
                bubble_full_width=False,
                elem_classes=["chat-container"]
            )

            plan_output = gr.Textbox(
                label="Generated Plan",
                interactive=False,
                visible=True,
                lines=10,
                max_lines=20
            )

            with gr.Row():
                with gr.Column(scale=4):
                    msg = gr.Textbox(
                        placeholder="Describe your data science project or ask a question...",
                        label="Your Message",
                        lines=2,
                        max_lines=5
                    )
                with gr.Column(scale=1):
                    send_btn = gr.Button("Send 📤", variant="primary", elem_classes=["btn-primary"])

                with gr.Column(scale=1):
                    plan_btn = gr.Button("Generate Plan 📋", variant="secondary", elem_classes=["btn-secondary"])

            with gr.Row():
                clear_btn = gr.Button("Clear Chat 🗑️", variant="secondary", elem_classes=["btn-secondary"])

        with gr.Column(scale=1):
            # File upload section
            gr.HTML("<h3 style='text-align: center; color: #4f46e5;'>📁 Upload Data Files</h3>")

            file_upload = gr.File(
                label="Upload your files (CSV, Excel, PDF, PPT, DOCX, Images, etc.)",
                file_count="multiple",
                file_types=[".csv", ".xlsx", ".xls", ".json", ".txt", ".pdf", ".ppt", ".pptx", ".doc", ".docx", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"],
                elem_classes=["upload-area"]
            )

            # Read-only mirror of the current uploads; this component (not
            # file_upload) is what the send handler reads.
            uploaded_files_display = gr.File(
                label="Uploaded Files",
                file_count="multiple",
                interactive=False,
                visible=True
            )

            # Instructions
            gr.HTML("""
                <div style="padding: 15px; background: linear-gradient(135deg, #e3f2fd 0%, #f3e5f5 100%); 
                           border-radius: 10px; margin-top: 20px;">
                    <h4 style="color: #4f46e5; margin-bottom: 10px;">💡 How it works:</h4>
                    <ol style="color: #555; font-size: 14px; line-height: 1.6;">
                        <li>Describe your data science project</li>
                        <li>Upload your files (data, documents, images)</li>
                        <li>Answer clarifying questions</li>
                        <li>Get a complete task specification</li>
                    </ol>
                    <p style="color: #666; font-size: 12px; margin-top: 10px;">
                        📄 Supports: CSV, Excel, PDF, PowerPoint, Word docs, Images, JSON, Text files
                    </p>
                </div>
            """)

    # Per-session state: the (user, assistant) history and the
    # extracted-document cache used by process_user_input.
    chat_history = gr.State([])
    file_cache = gr.State({})

    def handle_send(message, history, files, cache):
        """Send a non-blank message; otherwise leave every output unchanged."""
        # Guard against None as well as whitespace-only input.
        if message and message.strip():
            return chat_interface(message, history, files, cache)
        return history, history, message, cache

    # Clicking "Send" and pressing Enter in the message box do the same thing.
    send_inputs = [msg, chat_history, uploaded_files_display, file_cache]
    send_outputs = [chatbot, chat_history, msg, file_cache]
    send_btn.click(handle_send, inputs=send_inputs, outputs=send_outputs)
    msg.submit(handle_send, inputs=send_inputs, outputs=send_outputs)

    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chat_history, file_cache]
    )

    plan_btn.click(
        generate_plan,
        inputs=[chat_history, file_cache],
        outputs=[plan_output]
    )

    # Mirror new uploads into the read-only list.
    file_upload.change(
        upload_file_handler,
        inputs=[file_upload],
        outputs=[uploaded_files_display]
    )

    # Welcome message (shown in the chatbot only; not stored in chat_history).
    app.load(
        lambda: [(None, "👋 Hello! I'm your Data Science Project Agent. I'll help you transform your project ideas into reality.\n\n🚀 **Let's get started!** Tell me about your data science project or what you're trying to achieve.")],
        outputs=[chatbot]
    )

if __name__ == "__main__":
    # Launch the app with error display enabled and share=True.
    # NOTE(review): per Gradio's launch() docs, share=True requests a public
    # share link — confirm that is intended outside local development.
    app.launch(show_error=True, share=True)