File size: 17,672 Bytes
57c13e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f253a30
57c13e3
 
f253a30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57c13e3
 
 
 
 
 
 
 
f253a30
 
57c13e3
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
"""
Interactive setup and chat interface for DeepDrone.
"""

import os
import sys
import asyncio
from typing import Dict, Optional, Tuple, List
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.align import Align
from rich.prompt import Prompt, Confirm
from rich.table import Table
from rich.live import Live
from rich.layout import Layout
from rich.spinner import Spinner
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import radiolist_dialog, input_dialog, message_dialog
from prompt_toolkit.styles import Style
import getpass

from .config import ModelConfig
from .drone_chat_interface import DroneChatInterface

# Shared Rich console used for all terminal output in this module.
console = Console()

# Provider configurations
# Maps the display name shown to the user to that provider's settings. Entries are
# numbered in insertion order by select_provider(), so reordering them changes the
# numbers users type. Per-entry keys:
#   "name"        - passed to ModelConfig as `provider` when the session is built
#   "models"      - suggested model IDs listed for selection (users may also type
#                   any other model name); the first two are shown in the table
#   "api_key_url" - hint URL printed on the model-selection screen
#   "description" - text shown in the provider selection table
PROVIDERS = {
    "OpenAI": {
        "name": "openai",
        "models": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"],
        "api_key_url": "https://platform.openai.com/api-keys",
        "description": "GPT models from OpenAI"
    },
    "Anthropic": {
        "name": "anthropic",
        "models": ["claude-3-5-sonnet-20241022", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"],
        "api_key_url": "https://console.anthropic.com/",
        "description": "Claude models from Anthropic"
    },
    "Google": {
        "name": "vertex_ai",
        "models": ["gemini-1.5-pro", "gemini-1.5-flash", "gemini-pro"],
        "api_key_url": "https://console.cloud.google.com/",
        "description": "Gemini models from Google"
    },
    "Meta": {
        # NOTE(review): this entry specifies no base URL, so requests may go to the
        # "openai" provider's default endpoint rather than Together.ai/Replicate —
        # confirm how LLMInterface/ModelConfig route these model IDs.
        "name": "openai",  # Using OpenAI format for Llama models via providers
        "models": ["meta-llama/Meta-Llama-3.1-70B-Instruct", "meta-llama/Meta-Llama-3.1-8B-Instruct"],
        "api_key_url": "https://together.ai/ or https://replicate.com/",
        "description": "Llama models from Meta (via Together.ai/Replicate)"
    },
    "Mistral": {
        "name": "mistral",
        "models": ["mistral-large-latest", "mistral-medium-latest", "mistral-small-latest"],
        "api_key_url": "https://console.mistral.ai/",
        "description": "Mistral AI models"
    },
    # Local provider: the setup flow skips the API-key prompt for it and points the
    # session at a local server.
    "Ollama": {
        "name": "ollama",
        "models": ["llama3.1:latest", "codestral:latest", "qwen2.5-coder:latest", "phi3:latest"],
        "api_key_url": "https://ollama.ai/ (No API key needed - runs locally)",
        "description": "Local models via Ollama (no API key required)"
    }
}

def show_welcome_banner():
    """Print the DeepDrone welcome banner inside a green-bordered Rich panel.

    The banner is a plain-text box-drawing frame that is stripped of its
    surrounding blank lines and centred in the panel. The top and bottom
    border rows must stay as wide as the side rows (58 inner columns).
    """
    banner = """
╔══════════════════════════════════════════════════════════╗
║                                                          ║
║           🚁 DEEPDRONE AI CONTROL SYSTEM 🚁              ║
║                                                          ║
║        Advanced Drone Control with AI Integration        ║
║                                                          ║
╚══════════════════════════════════════════════════════════╝
    """

    console.print(Panel(
        Align.center(Text(banner.strip(), style="bold green")),
        border_style="bright_green",
        padding=(1, 2)
    ))

def select_provider() -> Optional[Tuple[str, Dict]]:
    """Show the provider table and let the user pick a provider by number.

    Providers are listed from ``PROVIDERS`` in insertion order, numbered from 1,
    with up to two example models each.

    Returns:
        ``(display_name, config)`` for the chosen ``PROVIDERS`` entry, or
        ``None`` if the user cancels with Ctrl-C or closes input (Ctrl-D/EOF).
    """
    from rich.prompt import IntPrompt

    console.print("\n[bold cyan]📡 Select AI Provider[/bold cyan]\n")

    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("№", style="bright_green", width=3)
    table.add_column("Provider", style="cyan", width=12)
    table.add_column("Description", style="white")
    table.add_column("Example Models", style="yellow")

    provider_list = list(PROVIDERS.items())

    for number, (name, config) in enumerate(provider_list, 1):
        example_models = ", ".join(config["models"][:2])
        if len(config["models"]) > 2:
            example_models += "..."
        table.add_row(str(number), name, config["description"], example_models)

    console.print(table)
    console.print()

    try:
        choice = IntPrompt.ask(
            "Select provider by number",
            choices=[str(i) for i in range(1, len(provider_list) + 1)],
            default=1,
        )
    except (KeyboardInterrupt, EOFError):
        # EOF (Ctrl-D / closed stdin) is a cancel too, not a session error.
        console.print("\n[yellow]Selection cancelled[/yellow]")
        return None

    # `choices` restricts the answer to 1..len(provider_list), so this index is valid.
    return provider_list[choice - 1]

def get_available_ollama_models() -> List[str]:
    """Return the names of models installed on the local Ollama server.

    This is a best-effort probe. It returns an empty list instead of raising
    in any of these cases:

    - the ``ollama`` package is missing;
    - the server cannot be reached;
    - the response has an unexpected shape.

    ``ollama.list()`` may return either shape:

    - an object with a ``.models`` attribute whose entries expose ``.model``;
    - a ``{"models": [...]}`` dict whose entries carry ``"model"`` or
      ``"name"`` keys.

    Both are accepted.
    """
    try:
        import ollama
    except ImportError:
        return []

    try:
        return _extract_ollama_model_names(ollama.list())
    except Exception:
        # Deliberate best-effort: server down/unreachable means "no local models".
        return []

def _extract_ollama_model_names(response) -> List[str]:
    """Collect non-empty model names from an ``ollama.list()`` object or dict response."""
    if isinstance(response, dict):
        entries = response.get("models") or []
    else:
        entries = getattr(response, "models", None) or []

    names = []
    for entry in entries:
        if isinstance(entry, dict):
            name = entry.get("model") or entry.get("name")
        else:
            name = getattr(entry, "model", None) or getattr(entry, "name", None)
        if name:
            names.append(name)
    return names

def install_ollama_model(model_name: str) -> bool:
    """Download ``model_name`` with ``ollama.pull`` while showing a spinner.

    Args:
        model_name: Ollama model reference to pull (e.g. ``"phi3:latest"``).

    Returns:
        ``True`` if the pull completed; ``False`` if the ``ollama`` package is
        not installed or the pull raised (the error is printed).
    """
    from rich.markup import escape

    try:
        import ollama
    except ImportError:
        console.print("[red]❌ Ollama package not installed[/red]")
        return False

    console.print(f"[yellow]📥 Installing {model_name}... This may take a few minutes.[/yellow]")

    try:
        # Leaving the Live context stops and clears the transient spinner.
        with Live(
            Spinner("dots", text=f"Installing {model_name}..."),
            console=console,
            transient=True,
        ):
            ollama.pull(model_name)
    except Exception as exc:
        # Escape so brackets in the error text are not parsed as Rich markup.
        console.print(f"[red]❌ Failed to install {model_name}: {escape(str(exc))}[/red]")
        return False

    console.print(f"[green]✅ Successfully installed {model_name}[/green]")
    return True

def get_model_name(provider_name: str, provider_config: Dict) -> Optional[str]:
    """Ask the user which model to use with the chosen provider.

    The user may enter a 1-based number from the printed list or type any
    model name directly. An out-of-range number is rejected and the prompt is
    repeated. For Ollama, locally installed models are listed first, and a
    model that is not installed locally is offered for download.

    Args:
        provider_name: Display name of the provider (a key of ``PROVIDERS``).
        provider_config: The matching ``PROVIDERS`` entry; its ``"models"`` and
            ``"api_key_url"`` values are shown to the user.

    Returns:
        The chosen model name, or ``None`` in any of these cases:

        - the user cancels (Ctrl-C, EOF, or blank input);
        - an Ollama download is declined;
        - an Ollama download fails.
    """
    console.print(f"\n[bold cyan]🤖 Select Model for {provider_name}[/bold cyan]\n")

    try:
        if provider_name.lower() == "ollama":
            return _select_ollama_model(provider_config)
        return _select_listed_model(provider_config)
    except (KeyboardInterrupt, EOFError):
        console.print("\n[yellow]Input cancelled[/yellow]")
        return None

def _select_listed_model(provider_config: Dict) -> Optional[str]:
    """List a hosted provider's suggested models and return the user's choice."""
    options = provider_config["models"]

    console.print("[bold]Popular models for this provider:[/bold]")
    for number, model in enumerate(options, 1):
        console.print(f"  {number}. [green]{model}[/green]")

    console.print(f"\n[dim]Get API key from: {provider_config['api_key_url']}[/dim]\n")
    return _ask_model_choice(options)

def _select_ollama_model(provider_config: Dict) -> Optional[str]:
    """List local and suggested Ollama models, then ensure the chosen one is installed."""
    local_models = get_available_ollama_models()
    # Don't re-offer an already-installed suggestion as "will be downloaded".
    suggested = [
        model for model in provider_config["models"]
        if not _ollama_model_installed(model, local_models)
    ]

    if local_models:
        console.print("[bold green]✅ Local Ollama models found:[/bold green]")
        for number, model in enumerate(local_models, 1):
            console.print(f"  {number}. [green]{model}[/green]")
        if suggested:
            console.print("\n[bold]Popular models (if not installed locally):[/bold]")
    else:
        console.print("[yellow]⚠️  No local Ollama models found or Ollama not running[/yellow]")
        console.print("Make sure Ollama is running: [cyan]ollama serve[/cyan]\n")
        console.print("[bold]Popular models (will be downloaded):[/bold]")

    # Suggestions continue the numbering after the local models.
    for number, model in enumerate(suggested, len(local_models) + 1):
        console.print(f"  {number}. [blue]{model}[/blue] [dim](will be downloaded)[/dim]")

    console.print(f"\n[dim]Download from: {provider_config['api_key_url']}[/dim]\n")

    model_name = _ask_model_choice(local_models + suggested)
    if model_name is None:
        return None
    if _ollama_model_installed(model_name, local_models):
        return model_name
    return _offer_ollama_install(model_name)

def _ask_model_choice(options: List[str]) -> Optional[str]:
    """Prompt for a list number or a model name; re-ask on an out-of-range number, None on blank."""
    while True:
        answer = Prompt.ask(
            "Enter model name or number from list above",
            default="1"
        ).strip()
        if not answer:
            return None
        if not answer.isdecimal():
            # Anything that isn't a plain number is taken as a model name.
            return answer
        number = int(answer)
        if 1 <= number <= len(options):
            return options[number - 1]
        console.print(f"[red]Please enter a number between 1 and {len(options)}, or a model name.[/red]")

def _ollama_model_installed(model_name: str, local_models: List[str]) -> bool:
    """True if model_name, or its implicit ':latest' tag when untagged, is in local_models."""
    if model_name in local_models:
        return True
    return ":" not in model_name and f"{model_name}:latest" in local_models

def _offer_ollama_install(model_name: str) -> Optional[str]:
    """Offer to download a missing Ollama model; return its name only if the install succeeds."""
    console.print(f"[yellow]Model '{model_name}' not found locally.[/yellow]")
    if not Confirm.ask(f"Would you like to install {model_name}?", default=True):
        console.print("[yellow]Model installation cancelled[/yellow]")
        return None
    return model_name if install_ollama_model(model_name) else None

def get_api_key(provider_name: str, model_name: str) -> Optional[str]:
    """Read the provider API key from the terminal without echoing it.

    Ollama needs no key, so the placeholder ``"local"`` is returned for it
    without prompting.

    Args:
        provider_name: Display name of the provider (a key of ``PROVIDERS``).
        model_name: Model chosen in the previous step; shown for context only.

    Returns:
        The stripped API key, ``"local"`` for Ollama, or ``None`` in any of
        these cases:

        - the input is blank;
        - the user presses Ctrl-C;
        - input is closed (Ctrl-D/EOF).
    """
    console.print(f"\n[bold cyan]🔑 API Key for {provider_name}[/bold cyan]\n")
    console.print(f"Model: [green]{model_name}[/green]")
    console.print(f"Provider: [blue]{provider_name}[/blue]\n")

    if provider_name.lower() == "ollama":
        console.print("[green]✅ Ollama runs locally - no API key required![/green]")
        console.print("[dim]Make sure Ollama is running: ollama serve[/dim]\n")
        return "local"  # Placeholder so callers treating "no key" as cancel still proceed.

    try:
        # getpass hides the typed characters.
        api_key = getpass.getpass("Enter your API key (hidden): ").strip()
    except (KeyboardInterrupt, EOFError):
        console.print("\n[yellow]Input cancelled[/yellow]")
        return None

    if not api_key:
        console.print("[yellow]No API key provided[/yellow]")
        return None
    return api_key

def test_model_connection(model_config: ModelConfig) -> bool:
    """Send a test request through ``LLMInterface`` and report the outcome.

    The result of ``LLMInterface.test_connection()`` is read as a mapping with
    a ``"success"`` flag plus ``"response"`` (on success) or ``"error"`` (on
    failure). Missing ``"response"``/``"error"`` values are tolerated.

    Args:
        model_config: The session's model configuration to test.

    Returns:
        ``True`` if the test reported success. ``False`` if it reported
        failure, or if building or calling the interface raised.
    """
    from rich.markup import escape

    console.print(f"\n[yellow]🔍 Testing connection to {model_config.name}...[/yellow]")

    try:
        from .llm_interface import LLMInterface

        # Only the API call runs under the spinner; the transient Live clears it on exit.
        with Live(
            Spinner("dots", text="Testing API connection..."),
            console=console,
            transient=True,
        ):
            result = LLMInterface(model_config).test_connection()

        if not result.get("success"):
            error = str(result.get("error") or "unknown error")
            # Escape model/server text so brackets in it are not parsed as Rich markup.
            console.print(f"[red]❌ Connection failed: {escape(error)}[/red]\n")
            return False

        response = str(result.get("response") or "")
        # Only mark the preview as truncated when it actually is.
        preview = response[:100] + ("..." if len(response) > 100 else "")
        console.print("[green]✅ Connection successful![/green]")
        console.print(f"[dim]Response: {escape(preview)}[/dim]\n")
        return True

    except Exception as exc:
        console.print(f"[red]❌ Error testing connection: {escape(str(exc))}[/red]\n")
        return False

def start_interactive_session():
    """Run the interactive setup wizard, then hand off to the drone chat UI.

    Steps:

    1. Pick a provider.
    2. Pick a model.
    3. Enter the API key.
    4. Test the model connection (the user may continue after a failure).
    5. Choose a drone connection string.
    6. Start ``DroneChatInterface``.

    A cancelled step prints "Setup cancelled" and returns. Ctrl-C or EOF
    exits with status 0; any other exception is printed and exits with
    status 1.
    """
    import time
    from rich.markup import escape

    cancelled_message = "[yellow]Setup cancelled. Goodbye![/yellow]"

    try:
        show_welcome_banner()

        # Step 1: Select provider
        console.print("[bold]Step 1: Choose your AI provider[/bold]\n")
        provider_result = select_provider()
        if not provider_result:
            console.print(cancelled_message)
            return
        provider_name, provider_config = provider_result

        # Step 2: Get model name
        console.print(f"[bold]Step 2: Select model for {provider_name}[/bold]")
        model_name = get_model_name(provider_name, provider_config)
        if not model_name:
            console.print(cancelled_message)
            return

        # Step 3: Get API key
        console.print("[bold]Step 3: Enter API key[/bold]")
        api_key = get_api_key(provider_name, model_name)
        if not api_key:
            console.print(cancelled_message)
            return

        # Ollama talks to the local server; other providers may optionally carry a
        # "base_url" in their PROVIDERS entry (absent -> None, the provider default).
        if provider_name.lower() == "ollama":
            base_url = "http://localhost:11434"
        else:
            base_url = provider_config.get("base_url")

        model_config = ModelConfig(
            name=f"{provider_name.lower()}-session",
            provider=provider_config["name"],
            model_id=model_name,
            api_key=api_key,
            base_url=base_url,
            max_tokens=2048,
            temperature=0.7
        )

        # Step 4: Test connection
        console.print("[bold]Step 4: Testing connection[/bold]")
        if not test_model_connection(model_config):
            if not Confirm.ask("Connection test failed. Continue anyway?"):
                console.print(cancelled_message)
                return

        # Step 5: Get drone connection string
        connection_string = _prompt_drone_connection()
        console.print(f"[dim]Will connect to: {connection_string}[/dim]\n")

        # Step 6: Start chat
        console.print("[bold green]🚀 Starting DeepDrone chat session...[/bold green]\n")
        time.sleep(1)  # Small pause so the message is visible before the chat UI takes over.
        chat_interface = DroneChatInterface(model_config, connection_string)
        chat_interface.start()

    except (KeyboardInterrupt, EOFError):
        console.print("\n[yellow]🚁 DeepDrone session interrupted. Goodbye![/yellow]")
        sys.exit(0)
    except Exception as exc:
        # Escape so brackets in the exception text are not parsed as Rich markup.
        console.print(f"[red]❌ Error in interactive session: {escape(str(exc))}[/red]")
        sys.exit(1)

def _detect_simulator() -> Optional[bool]:
    """Best-effort scan of `ps aux` for mavproxy/sitl; None when the process list is unavailable."""
    import subprocess

    try:
        completed = subprocess.run(
            ['ps', 'aux'],
            capture_output=True,
            text=True,
            errors="replace",  # undecodable bytes in process names must not abort the check
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError, ValueError):
        # e.g. no `ps` on Windows, or it timed out: detection is simply unavailable.
        return None

    listing = completed.stdout.lower()
    return 'mavproxy' in listing or 'sitl' in listing

def _prompt_drone_connection() -> str:
    """Show connection examples and ask for the drone connection string (falls back to the default)."""
    default_connection = "udp:127.0.0.1:14550"

    console.print("[bold yellow]🚁 Drone Connection Setup[/bold yellow]\n")

    simulator_running = _detect_simulator()
    if simulator_running is True:
        console.print("[green]✅ Detected running drone simulator![/green]")
    elif simulator_running is False:
        console.print("[yellow]⚠️  No simulator detected[/yellow]")

    console.print("Connection options:")
    console.print("  • [green]Simulator[/green]: [cyan]udp:127.0.0.1:14550[/cyan] (default)")
    console.print("  • [blue]Real Drone USB[/blue]: [cyan]/dev/ttyACM0[/cyan] (Linux) or [cyan]COM3[/cyan] (Windows)")
    console.print("  • [blue]Real Drone TCP[/blue]: [cyan]tcp:192.168.1.100:5760[/cyan]")
    console.print("  • [blue]Real Drone UDP[/blue]: [cyan]udp:192.168.1.100:14550[/cyan]\n")

    connection_string = Prompt.ask(
        "Enter drone connection string",
        default=default_connection
    )

    if not connection_string:
        console.print(f"[yellow]Using default connection: {default_connection}[/yellow]")
        connection_string = default_connection

    return connection_string