"""Ollama adapter."""
import asyncio
import os
import shutil
import subprocess
import time
from typing import Any, Dict, List
import aiohttp
from starfish.common.logger import get_logger
logger = get_logger(__name__)
# Default Ollama connection settings
OLLAMA_HOST = "localhost"
OLLAMA_PORT = 11434
OLLAMA_BASE_URL = f"http://{OLLAMA_HOST}:{OLLAMA_PORT}"
class OllamaError(Exception):
"""Base exception for Ollama-related errors."""
pass
class OllamaNotInstalledError(OllamaError):
"""Error raised when Ollama is not installed."""
pass
class OllamaConnectionError(OllamaError):
"""Error raised when connection to Ollama server fails."""
pass
async def is_ollama_running() -> bool:
"""Check if Ollama server is running."""
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{OLLAMA_BASE_URL}/api/version", timeout=aiohttp.ClientTimeout(total=2)) as response:
return response.status == 200
except Exception:
return False
async def start_ollama_server() -> bool:
"""Start the Ollama server if it's not already running."""
# Check if already running
if await is_ollama_running():
logger.info("Ollama server is already running")
return True
# Find the ollama executable
ollama_bin = shutil.which("ollama")
if not ollama_bin:
logger.error("Ollama is not installed")
raise OllamaNotInstalledError("Ollama is not installed. Please install from https://ollama.com/download")
logger.info("Starting Ollama server...")
try:
# Start Ollama as a detached process
if os.name == "nt": # Windows
subprocess.Popen([ollama_bin, "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
else: # Unix/Linux/Mac
subprocess.Popen(
[ollama_bin, "serve"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setpgrp, # Run in a new process group
)
        # Poll for up to ~5 seconds (10 attempts, 0.5 s apart) for the server to come up
for _ in range(10):
if await is_ollama_running():
logger.info("Ollama server started successfully")
return True
await asyncio.sleep(0.5)
logger.error("Timed out waiting for Ollama server to start")
return False
except Exception as e:
logger.error(f"Failed to start Ollama server: {e}")
return False
async def list_models() -> List[Dict[str, Any]]:
"""List available models in Ollama using the API."""
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{OLLAMA_BASE_URL}/api/tags") as response:
if response.status != 200:
logger.error(f"Error listing models: {response.status}")
return []
data = await response.json()
return data.get("models", [])
except Exception as e:
logger.error(f"Error listing models: {e}")
return []
async def is_model_available(model_name: str) -> bool:
"""Check if model is available using the CLI command."""
try:
# Use CLI for more reliable checking
ollama_bin = shutil.which("ollama")
if not ollama_bin:
logger.error("Ollama binary not found")
return False
        process = await asyncio.create_subprocess_exec(ollama_bin, "list", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
        stdout, stderr = await process.communicate()
        if process.returncode != 0:
            logger.error(f"'ollama list' failed with exit code {process.returncode}: {stderr.decode().strip()}")
            return False
        output = stdout.decode().strip()
# Log what models are available
logger.debug(f"Available models: {output}")
# Look for the model name in the output
model_lines = output.split("\n")
for line in model_lines:
# Check if the line contains the model name
if line.strip() and model_name in line.split()[0]:
logger.info(f"Found model {model_name}")
return True
logger.info(f"Model {model_name} not found")
return False
except Exception as e:
logger.error(f"Error checking if model is available: {e}")
return False
async def pull_model(model_name: str) -> bool:
"""Pull a model using the Ollama CLI.
This is more reliable than the API for large downloads.
"""
# Use the Ollama CLI directly for more reliable downloads
ollama_bin = shutil.which("ollama")
if not ollama_bin:
logger.error("Ollama binary not found")
return False
try:
# Set logging interval
LOG_INTERVAL = 10 # Only log every 10 seconds
logger.info(f"Pulling model {model_name}... (progress updates every {LOG_INTERVAL} seconds)")
# Create the subprocess
process = await asyncio.create_subprocess_exec(ollama_bin, "pull", model_name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
# Track last progress output time to throttle logging
last_log_time = 0
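        # Starting at 0 means the first progress line is logged right away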
# Define functions to read from both stdout and stderr
async def read_stream(stream, stream_name):
nonlocal last_log_time
buffered_content = []
while True:
line = await stream.readline()
if not line:
break
line_text = line.decode().strip()
if not line_text:
continue
                # Repair progress lines whose leading character was dropped (e.g. "ulling manifest")
                if line_text.startswith("ulling manifest"):
                    line_text = "p" + line_text
current_time = time.time()
# Always log important messages like errors
if "error" in line_text.lower() or "fail" in line_text.lower():
logger.info(f"Ollama pull ({stream_name}): {line_text}")
continue
# Buffer regular progress messages
buffered_content.append(line_text)
# Log throttled progress updates
if current_time - last_log_time >= LOG_INTERVAL:
if buffered_content:
logger.info(f"Ollama pull progress: {buffered_content[-1]}")
buffered_content = []
last_log_time = current_time
# Read from both stdout and stderr concurrently
await asyncio.gather(read_stream(process.stdout, "stdout"), read_stream(process.stderr, "stderr"))
# Wait for process to complete
exit_code = await process.wait()
if exit_code == 0:
logger.info(f"Successfully pulled model {model_name}")
# Give a moment for Ollama to finalize the model
await asyncio.sleep(1)
# Verify model is available
if await is_model_available(model_name):
logger.info(f"Verified model {model_name} is now available")
return True
else:
logger.error(f"Model pull completed but {model_name} not found in list")
return False
else:
logger.error(f"Failed to pull model {model_name} with exit code {exit_code}")
return False
except Exception as e:
logger.error(f"Error pulling model {model_name}: {e}")
return False
async def ensure_model_ready(model_name: str) -> bool:
"""Ensure Ollama server is running and the model is available."""
# Step 1: Make sure Ollama server is running
if not await start_ollama_server():
logger.error("Failed to start Ollama server")
return False
# Step 2: Check if model is already available
if await is_model_available(model_name):
logger.info(f"Model {model_name} is already available")
return True
# Step 3: Pull the model if not available
logger.info(f"Model {model_name} not found, downloading...")
if await pull_model(model_name):
logger.info(f"Model {model_name} successfully pulled and ready")
return True
else:
logger.error(f"Failed to pull model {model_name}")
return False
async def stop_ollama_server() -> bool:
"""Stop the Ollama server."""
try:
# Find the ollama executable (just to check if it's installed)
ollama_bin = shutil.which("ollama")
if not ollama_bin:
logger.error("Ollama is not installed")
return False
logger.info("Stopping Ollama server...")
# Different process termination based on platform
if os.name == "nt": # Windows
# Windows uses taskkill to terminate processes
process = await asyncio.create_subprocess_exec(
"taskkill", "/F", "/IM", "ollama.exe", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
else: # Unix/Linux/Mac
# Use pkill to terminate all Ollama processes
process = await asyncio.create_subprocess_exec("pkill", "-f", "ollama", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
        _, stderr = await process.communicate()
        # Give the process a moment to exit before checking whether it is still running
        await asyncio.sleep(1)
        # Check if the Ollama server is still running
        if await is_ollama_running():
logger.error(f"Failed to stop Ollama server: {stderr.decode() if stderr else 'unknown error'}")
logger.info("Attempting stronger termination...")
# Try one more time with stronger termination if it's still running
if os.name == "nt": # Windows
process = await asyncio.create_subprocess_exec(
"taskkill",
"/F",
"/IM",
"ollama.exe",
"/T", # /T terminates child processes as well
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
else: # Unix/Linux/Mac
process = await asyncio.create_subprocess_exec(
"pkill",
"-9",
"-f",
"ollama", # SIGKILL for force termination
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
_, stderr = await process.communicate()
if await is_ollama_running():
logger.error("Failed to forcefully stop Ollama server")
return False
# Wait a moment to ensure processes are actually terminated
await asyncio.sleep(1)
# Verify the server is no longer running
if not await is_ollama_running():
logger.info("Ollama server stopped successfully")
return True
else:
logger.error("Failed to stop Ollama server: still running after termination attempts")
return False
except Exception as e:
logger.error(f"Error stopping Ollama server: {str(e)}")
cmd = "taskkill /F /IM ollama.exe" if os.name == "nt" else "pkill -f ollama"
logger.error(f"Command attempted: {cmd}")
return False
async def delete_model(model_name: str) -> bool:
"""Delete a model from Ollama.
Args:
model_name: The name of the model to delete
Returns:
bool: True if deletion was successful, False otherwise
"""
try:
# Find the ollama executable
ollama_bin = shutil.which("ollama")
if not ollama_bin:
logger.error("Ollama is not installed")
return False
logger.info(f"Deleting model {model_name} from Ollama...")
process = await asyncio.create_subprocess_exec(ollama_bin, "rm", model_name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
stdout, stderr = await process.communicate()
if process.returncode != 0:
logger.error(f"Failed to delete model {model_name}: {stderr.decode()}")
return False
logger.info(f"Model {model_name} deleted successfully")
return True
except Exception as e:
logger.error(f"Error deleting model {model_name}: {e}")
return False
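

# --- Usage sketch ----------------------------------------------------------
# A minimal example of how these helpers could be wired together: start the
# server, make sure a model is present, list what is installed, then shut the
# server down. The model name below is only an illustrative placeholder, not
# something this adapter requires; substitute the model your setup uses.
async def _demo(model_name: str = "llama3.2") -> None:
    # Start the server (if needed) and pull the model when it is missing
    if not await ensure_model_ready(model_name):
        logger.error(f"Could not prepare model {model_name}")
        return
    # Report what the local Ollama instance has installed
    models = await list_models()
    logger.info(f"Ollama reports {len(models)} installed model(s)")
    # Shut the server back down
    await stop_ollama_server()


if __name__ == "__main__":
    asyncio.run(_demo())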