"""Ollama adapter.""" import asyncio import os import shutil import subprocess import time from typing import Any, Dict, List import aiohttp from starfish.common.logger import get_logger logger = get_logger(__name__) # Default Ollama connection settings OLLAMA_HOST = "localhost" OLLAMA_PORT = 11434 OLLAMA_BASE_URL = f"http://{OLLAMA_HOST}:{OLLAMA_PORT}" class OllamaError(Exception): """Base exception for Ollama-related errors.""" pass class OllamaNotInstalledError(OllamaError): """Error raised when Ollama is not installed.""" pass class OllamaConnectionError(OllamaError): """Error raised when connection to Ollama server fails.""" pass async def is_ollama_running() -> bool: """Check if Ollama server is running.""" try: async with aiohttp.ClientSession() as session: async with session.get(f"{OLLAMA_BASE_URL}/api/version", timeout=aiohttp.ClientTimeout(total=2)) as response: return response.status == 200 except Exception: return False async def start_ollama_server() -> bool: """Start the Ollama server if it's not already running.""" # Check if already running if await is_ollama_running(): logger.info("Ollama server is already running") return True # Find the ollama executable ollama_bin = shutil.which("ollama") if not ollama_bin: logger.error("Ollama is not installed") raise OllamaNotInstalledError("Ollama is not installed. Please install from https://ollama.com/download") logger.info("Starting Ollama server...") try: # Start Ollama as a detached process if os.name == "nt": # Windows subprocess.Popen([ollama_bin, "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP) else: # Unix/Linux/Mac subprocess.Popen( [ollama_bin, "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, preexec_fn=os.setpgrp, # Run in a new process group ) # Wait for the server to start for _ in range(10): if await is_ollama_running(): logger.info("Ollama server started successfully") return True await asyncio.sleep(0.5) logger.error("Timed out waiting for Ollama server to start") return False except Exception as e: logger.error(f"Failed to start Ollama server: {e}") return False async def list_models() -> List[Dict[str, Any]]: """List available models in Ollama using the API.""" try: async with aiohttp.ClientSession() as session: async with session.get(f"{OLLAMA_BASE_URL}/api/tags") as response: if response.status != 200: logger.error(f"Error listing models: {response.status}") return [] data = await response.json() return data.get("models", []) except Exception as e: logger.error(f"Error listing models: {e}") return [] async def is_model_available(model_name: str) -> bool: """Check if model is available using the CLI command.""" try: # Use CLI for more reliable checking ollama_bin = shutil.which("ollama") if not ollama_bin: logger.error("Ollama binary not found") return False process = await asyncio.create_subprocess_exec(ollama_bin, "list", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) stdout, _ = await process.communicate() output = stdout.decode().strip() # Log what models are available logger.debug(f"Available models: {output}") # Look for the model name in the output model_lines = output.split("\n") for line in model_lines: # Check if the line contains the model name if line.strip() and model_name in line.split()[0]: logger.info(f"Found model {model_name}") return True logger.info(f"Model {model_name} not found") return False except Exception as e: logger.error(f"Error checking if model is available: {e}") return False async def 
async def pull_model(model_name: str) -> bool:
    """Pull a model using the Ollama CLI.

    This is more reliable than the API for large downloads.
    """
    # Use the Ollama CLI directly for more reliable downloads
    ollama_bin = shutil.which("ollama")
    if not ollama_bin:
        logger.error("Ollama binary not found")
        return False

    try:
        # Set logging interval
        LOG_INTERVAL = 10  # Only log every 10 seconds

        logger.info(f"Pulling model {model_name}... (progress updates every {LOG_INTERVAL} seconds)")

        # Create the subprocess
        process = await asyncio.create_subprocess_exec(ollama_bin, "pull", model_name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)

        # Track last progress output time to throttle logging
        last_log_time = 0

        # Define a function to read from both stdout and stderr
        async def read_stream(stream, stream_name):
            nonlocal last_log_time
            buffered_content = []
            while True:
                line = await stream.readline()
                if not line:
                    break
                line_text = line.decode().strip()
                if not line_text:
                    continue

                # Fix truncated "ulling manifest" text if present
                if "ulling manifest" in line_text and not line_text.startswith("Pulling"):
                    line_text = "P" + line_text

                current_time = time.time()

                # Always log important messages like errors
                if "error" in line_text.lower() or "fail" in line_text.lower():
                    logger.info(f"Ollama pull ({stream_name}): {line_text}")
                    continue

                # Buffer regular progress messages
                buffered_content.append(line_text)

                # Log throttled progress updates
                if current_time - last_log_time >= LOG_INTERVAL:
                    if buffered_content:
                        logger.info(f"Ollama pull progress: {buffered_content[-1]}")
                        buffered_content = []
                    last_log_time = current_time

        # Read from both stdout and stderr concurrently
        await asyncio.gather(read_stream(process.stdout, "stdout"), read_stream(process.stderr, "stderr"))

        # Wait for the process to complete
        exit_code = await process.wait()

        if exit_code == 0:
            logger.info(f"Successfully pulled model {model_name}")

            # Give a moment for Ollama to finalize the model
            await asyncio.sleep(1)

            # Verify the model is available
            if await is_model_available(model_name):
                logger.info(f"Verified model {model_name} is now available")
                return True
            else:
                logger.error(f"Model pull completed but {model_name} not found in list")
                return False
        else:
            logger.error(f"Failed to pull model {model_name} with exit code {exit_code}")
            return False
    except Exception as e:
        logger.error(f"Error pulling model {model_name}: {e}")
        return False


async def ensure_model_ready(model_name: str) -> bool:
    """Ensure the Ollama server is running and the model is available."""
    # Step 1: Make sure the Ollama server is running
    if not await start_ollama_server():
        logger.error("Failed to start Ollama server")
        return False

    # Step 2: Check if the model is already available
    if await is_model_available(model_name):
        logger.info(f"Model {model_name} is already available")
        return True

    # Step 3: Pull the model if not available
    logger.info(f"Model {model_name} not found, downloading...")
    if await pull_model(model_name):
        logger.info(f"Model {model_name} successfully pulled and ready")
        return True
    else:
        logger.error(f"Failed to pull model {model_name}")
        return False


async def stop_ollama_server() -> bool:
    """Stop the Ollama server."""
    try:
        # Find the ollama executable (just to check if it's installed)
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama is not installed")
            return False

        logger.info("Stopping Ollama server...")

        # Different process termination based on platform
        if os.name == "nt":  # Windows
            # Windows uses taskkill to terminate processes
            process = await asyncio.create_subprocess_exec(
"taskkill", "/F", "/IM", "ollama.exe", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) else: # Unix/Linux/Mac # Use pkill to terminate all Ollama processes process = await asyncio.create_subprocess_exec("pkill", "-f", "ollama", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) _, stderr = await process.communicate() # Check if the Ollama server is still running if await is_ollama_running(): logger.error(f"Failed to stop Ollama server: {stderr.decode() if stderr else 'unknown error'}") logger.info("Attempting stronger termination...") # Try one more time with stronger termination if it's still running if os.name == "nt": # Windows process = await asyncio.create_subprocess_exec( "taskkill", "/F", "/IM", "ollama.exe", "/T", # /T terminates child processes as well stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) else: # Unix/Linux/Mac process = await asyncio.create_subprocess_exec( "pkill", "-9", "-f", "ollama", # SIGKILL for force termination stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) _, stderr = await process.communicate() if await is_ollama_running(): logger.error("Failed to forcefully stop Ollama server") return False # Wait a moment to ensure processes are actually terminated await asyncio.sleep(1) # Verify the server is no longer running if not await is_ollama_running(): logger.info("Ollama server stopped successfully") return True else: logger.error("Failed to stop Ollama server: still running after termination attempts") return False except Exception as e: logger.error(f"Error stopping Ollama server: {str(e)}") cmd = "taskkill /F /IM ollama.exe" if os.name == "nt" else "pkill -f ollama" logger.error(f"Command attempted: {cmd}") return False async def delete_model(model_name: str) -> bool: """Delete a model from Ollama. Args: model_name: The name of the model to delete Returns: bool: True if deletion was successful, False otherwise """ try: # Find the ollama executable ollama_bin = shutil.which("ollama") if not ollama_bin: logger.error("Ollama is not installed") return False logger.info(f"Deleting model {model_name} from Ollama...") process = await asyncio.create_subprocess_exec(ollama_bin, "rm", model_name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) stdout, stderr = await process.communicate() if process.returncode != 0: logger.error(f"Failed to delete model {model_name}: {stderr.decode()}") return False logger.info(f"Model {model_name} deleted successfully") return True except Exception as e: logger.error(f"Error deleting model {model_name}: {e}") return False