"""Ollama adapter."""
import asyncio
import os
import shutil
import subprocess
import time
from typing import Any, Dict, List

import aiohttp

from starfish.common.logger import get_logger

logger = get_logger(__name__)

# Default Ollama connection settings
OLLAMA_HOST = "localhost"
OLLAMA_PORT = 11434
OLLAMA_BASE_URL = f"http://{OLLAMA_HOST}:{OLLAMA_PORT}"


class OllamaError(Exception):
    """Base exception for Ollama-related errors."""

    pass


class OllamaNotInstalledError(OllamaError):
    """Error raised when Ollama is not installed."""

    pass


class OllamaConnectionError(OllamaError):
    """Error raised when connection to the Ollama server fails."""

    pass


async def is_ollama_running() -> bool:
    """Check if the Ollama server is running."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{OLLAMA_BASE_URL}/api/version", timeout=aiohttp.ClientTimeout(total=2)) as response:
                return response.status == 200
    except Exception:
        return False


async def start_ollama_server() -> bool:
    """Start the Ollama server if it's not already running."""
    # Check if already running
    if await is_ollama_running():
        logger.info("Ollama server is already running")
        return True

    # Find the ollama executable
    ollama_bin = shutil.which("ollama")
    if not ollama_bin:
        logger.error("Ollama is not installed")
        raise OllamaNotInstalledError("Ollama is not installed. Please install from https://ollama.com/download")

    logger.info("Starting Ollama server...")
    try:
        # Start Ollama as a detached process
        if os.name == "nt":  # Windows
            subprocess.Popen([ollama_bin, "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
        else:  # Unix/Linux/Mac
            subprocess.Popen(
                [ollama_bin, "serve"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                preexec_fn=os.setpgrp,  # Run in a new process group
            )

        # Wait for the server to start
        for _ in range(10):
            if await is_ollama_running():
                logger.info("Ollama server started successfully")
                return True
            await asyncio.sleep(0.5)

        logger.error("Timed out waiting for Ollama server to start")
        return False
    except Exception as e:
        logger.error(f"Failed to start Ollama server: {e}")
        return False


async def list_models() -> List[Dict[str, Any]]:
    """List available models in Ollama using the API."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{OLLAMA_BASE_URL}/api/tags") as response:
                if response.status != 200:
                    logger.error(f"Error listing models: {response.status}")
                    return []
                data = await response.json()
                return data.get("models", [])
    except Exception as e:
        logger.error(f"Error listing models: {e}")
        return []


async def is_model_available(model_name: str) -> bool:
    """Check if a model is available using the CLI command."""
    try:
        # Use the CLI for more reliable checking
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama binary not found")
            return False

        process = await asyncio.create_subprocess_exec(ollama_bin, "list", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
        stdout, _ = await process.communicate()
        output = stdout.decode().strip()

        # Log which models are available
        logger.debug(f"Available models: {output}")

        # Look for the model name in the output
        model_lines = output.split("\n")
        for line in model_lines:
            # Check if the first column of the line contains the model name
            if line.strip() and model_name in line.split()[0]:
                logger.info(f"Found model {model_name}")
                return True

        logger.info(f"Model {model_name} not found")
        return False
    except Exception as e:
        logger.error(f"Error checking if model is available: {e}")
        return False


async def pull_model(model_name: str) -> bool:
    """Pull a model using the Ollama CLI.

    This is more reliable than the API for large downloads.
    """
    # Use the Ollama CLI directly for more reliable downloads
    ollama_bin = shutil.which("ollama")
    if not ollama_bin:
        logger.error("Ollama binary not found")
        return False

    try:
        # Set logging interval
        LOG_INTERVAL = 10  # Only log every 10 seconds
        logger.info(f"Pulling model {model_name}... (progress updates every {LOG_INTERVAL} seconds)")

        # Create the subprocess
        process = await asyncio.create_subprocess_exec(ollama_bin, "pull", model_name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)

        # Track last progress output time to throttle logging
        last_log_time = 0

        # Define a reader used for both stdout and stderr
        async def read_stream(stream, stream_name):
            nonlocal last_log_time
            buffered_content = []
            while True:
                line = await stream.readline()
                if not line:
                    break
                line_text = line.decode().strip()
                if not line_text:
                    continue

                # Fix truncated "ulling manifest" text if present
                if "ulling manifest" in line_text and not line_text.startswith("Pulling"):
                    line_text = "P" + line_text

                current_time = time.time()

                # Always log important messages like errors
                if "error" in line_text.lower() or "fail" in line_text.lower():
                    logger.info(f"Ollama pull ({stream_name}): {line_text}")
                    continue

                # Buffer regular progress messages
                buffered_content.append(line_text)

                # Log throttled progress updates
                if current_time - last_log_time >= LOG_INTERVAL:
                    if buffered_content:
                        logger.info(f"Ollama pull progress: {buffered_content[-1]}")
                        buffered_content = []
                    last_log_time = current_time

        # Read from both stdout and stderr concurrently
        await asyncio.gather(read_stream(process.stdout, "stdout"), read_stream(process.stderr, "stderr"))

        # Wait for the process to complete
        exit_code = await process.wait()

        if exit_code == 0:
            logger.info(f"Successfully pulled model {model_name}")

            # Give Ollama a moment to finalize the model
            await asyncio.sleep(1)

            # Verify the model is available
            if await is_model_available(model_name):
                logger.info(f"Verified model {model_name} is now available")
                return True
            else:
                logger.error(f"Model pull completed but {model_name} not found in list")
                return False
        else:
            logger.error(f"Failed to pull model {model_name} with exit code {exit_code}")
            return False
    except Exception as e:
        logger.error(f"Error pulling model {model_name}: {e}")
        return False


async def ensure_model_ready(model_name: str) -> bool:
    """Ensure the Ollama server is running and the model is available."""
    # Step 1: Make sure the Ollama server is running
    if not await start_ollama_server():
        logger.error("Failed to start Ollama server")
        return False

    # Step 2: Check if the model is already available
    if await is_model_available(model_name):
        logger.info(f"Model {model_name} is already available")
        return True

    # Step 3: Pull the model if not available
    logger.info(f"Model {model_name} not found, downloading...")
    if await pull_model(model_name):
        logger.info(f"Model {model_name} successfully pulled and ready")
        return True
    else:
        logger.error(f"Failed to pull model {model_name}")
        return False


async def stop_ollama_server() -> bool:
    """Stop the Ollama server."""
    try:
        # Find the ollama executable (just to check if it's installed)
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama is not installed")
            return False

        logger.info("Stopping Ollama server...")

        # Different process termination based on platform
        if os.name == "nt":  # Windows
            # Windows uses taskkill to terminate processes
            process = await asyncio.create_subprocess_exec(
                "taskkill", "/F", "/IM", "ollama.exe", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            )
        else:  # Unix/Linux/Mac
            # Use pkill to terminate all Ollama processes
            process = await asyncio.create_subprocess_exec("pkill", "-f", "ollama", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)

        _, stderr = await process.communicate()

        # Check if the Ollama server is still running
        if await is_ollama_running():
            logger.error(f"Failed to stop Ollama server: {stderr.decode() if stderr else 'unknown error'}")
            logger.info("Attempting stronger termination...")

            # Try one more time with stronger termination if it's still running
            if os.name == "nt":  # Windows
                process = await asyncio.create_subprocess_exec(
                    "taskkill",
                    "/F",
                    "/IM",
                    "ollama.exe",
                    "/T",  # /T terminates child processes as well
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )
            else:  # Unix/Linux/Mac
                process = await asyncio.create_subprocess_exec(
                    "pkill",
                    "-9",
                    "-f",
                    "ollama",  # SIGKILL for force termination
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

            _, stderr = await process.communicate()
            if await is_ollama_running():
                logger.error("Failed to forcefully stop Ollama server")
                return False

        # Wait a moment to ensure processes are actually terminated
        await asyncio.sleep(1)

        # Verify the server is no longer running
        if not await is_ollama_running():
            logger.info("Ollama server stopped successfully")
            return True
        else:
            logger.error("Failed to stop Ollama server: still running after termination attempts")
            return False
    except Exception as e:
        logger.error(f"Error stopping Ollama server: {str(e)}")
        cmd = "taskkill /F /IM ollama.exe" if os.name == "nt" else "pkill -f ollama"
        logger.error(f"Command attempted: {cmd}")
        return False


async def delete_model(model_name: str) -> bool:
    """Delete a model from Ollama.

    Args:
        model_name: The name of the model to delete

    Returns:
        bool: True if deletion was successful, False otherwise
    """
    try:
        # Find the ollama executable
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama is not installed")
            return False

        logger.info(f"Deleting model {model_name} from Ollama...")

        process = await asyncio.create_subprocess_exec(ollama_bin, "rm", model_name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
        stdout, stderr = await process.communicate()

        if process.returncode != 0:
            logger.error(f"Failed to delete model {model_name}: {stderr.decode()}")
            return False

        logger.info(f"Model {model_name} deleted successfully")
        return True
    except Exception as e:
        logger.error(f"Error deleting model {model_name}: {e}")
        return False
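

# Illustrative usage: a minimal sketch of how the adapter's helpers fit together,
# not part of the adapter itself. The tag "llama3.2:1b" is only an example and
# assumes that model exists in the Ollama library; substitute any tag you use.
if __name__ == "__main__":

    async def _demo() -> None:
        model = "llama3.2:1b"  # hypothetical example tag, replace as needed
        # Start the server (if needed) and pull the model when it is missing
        if await ensure_model_ready(model):
            installed = await list_models()
            logger.info(f"Installed models: {[m.get('name') for m in installed]}")
        # Shut the server down again once we are done
        await stop_ollama_server()

    asyncio.run(_demo())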