"""Ollama adapter."""
import asyncio
import os
import shutil
import subprocess
import time
from typing import Any, Dict, List
import aiohttp
from starfish.common.logger import get_logger
logger = get_logger(__name__)
# Default Ollama connection settings
OLLAMA_HOST = "localhost"
OLLAMA_PORT = 11434
OLLAMA_BASE_URL = f"http://{OLLAMA_HOST}:{OLLAMA_PORT}"
class OllamaError(Exception):
"""Base exception for Ollama-related errors."""
pass
class OllamaNotInstalledError(OllamaError):
"""Error raised when Ollama is not installed."""
pass
class OllamaConnectionError(OllamaError):
"""Error raised when connection to Ollama server fails."""
pass


async def is_ollama_running() -> bool:
    """Check if the Ollama server is running."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{OLLAMA_BASE_URL}/api/version", timeout=aiohttp.ClientTimeout(total=2)) as response:
                return response.status == 200
    except Exception:
        return False


async def start_ollama_server() -> bool:
    """Start the Ollama server if it's not already running."""
    # Check if already running
    if await is_ollama_running():
        logger.info("Ollama server is already running")
        return True

    # Find the ollama executable
    ollama_bin = shutil.which("ollama")
    if not ollama_bin:
        logger.error("Ollama is not installed")
        raise OllamaNotInstalledError("Ollama is not installed. Please install from https://ollama.com/download")

    logger.info("Starting Ollama server...")
    try:
        # Start Ollama as a detached process
        if os.name == "nt":  # Windows
            subprocess.Popen([ollama_bin, "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
        else:  # Unix/Linux/Mac
            subprocess.Popen(
                [ollama_bin, "serve"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                preexec_fn=os.setpgrp,  # Run in a new process group
            )

        # Wait for the server to start
        for _ in range(10):
            if await is_ollama_running():
                logger.info("Ollama server started successfully")
                return True
            await asyncio.sleep(0.5)

        logger.error("Timed out waiting for Ollama server to start")
        return False
    except Exception as e:
        logger.error(f"Failed to start Ollama server: {e}")
        return False


async def list_models() -> List[Dict[str, Any]]:
    """List available models in Ollama using the API."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{OLLAMA_BASE_URL}/api/tags") as response:
                if response.status != 200:
                    logger.error(f"Error listing models: {response.status}")
                    return []
                data = await response.json()
                return data.get("models", [])
    except Exception as e:
        logger.error(f"Error listing models: {e}")
        return []


async def is_model_available(model_name: str) -> bool:
    """Check if a model is available using the CLI command."""
    try:
        # Use CLI for more reliable checking
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama binary not found")
            return False

        process = await asyncio.create_subprocess_exec(ollama_bin, "list", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
        stdout, _ = await process.communicate()
        output = stdout.decode().strip()

        # Log what models are available
        logger.debug(f"Available models: {output}")

        # Look for the model name in the output
        model_lines = output.split("\n")
        for line in model_lines:
            # Check if the line contains the model name
            if line.strip() and model_name in line.split()[0]:
                logger.info(f"Found model {model_name}")
                return True

        logger.info(f"Model {model_name} not found")
        return False
    except Exception as e:
        logger.error(f"Error checking if model is available: {e}")
        return False


async def pull_model(model_name: str) -> bool:
    """Pull a model using the Ollama CLI.

    This is more reliable than the API for large downloads.
    """
    # Use the Ollama CLI directly for more reliable downloads
    ollama_bin = shutil.which("ollama")
    if not ollama_bin:
        logger.error("Ollama binary not found")
        return False

    try:
        # Set logging interval
        LOG_INTERVAL = 10  # Only log every 10 seconds
        logger.info(f"Pulling model {model_name}... (progress updates every {LOG_INTERVAL} seconds)")

        # Create the subprocess
        process = await asyncio.create_subprocess_exec(ollama_bin, "pull", model_name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)

        # Track last progress output time to throttle logging
        last_log_time = 0

        # Helper to read from either stdout or stderr
        async def read_stream(stream, stream_name):
            nonlocal last_log_time
            buffered_content = []
            while True:
                line = await stream.readline()
                if not line:
                    break
                line_text = line.decode().strip()
                if not line_text:
                    continue

                # Fix truncated "ulling manifest" text if present
                if "ulling manifest" in line_text and not line_text.startswith("Pulling"):
                    line_text = "P" + line_text

                current_time = time.time()

                # Always log important messages like errors
                if "error" in line_text.lower() or "fail" in line_text.lower():
                    logger.info(f"Ollama pull ({stream_name}): {line_text}")
                    continue

                # Buffer regular progress messages
                buffered_content.append(line_text)

                # Log throttled progress updates
                if current_time - last_log_time >= LOG_INTERVAL:
                    if buffered_content:
                        logger.info(f"Ollama pull progress: {buffered_content[-1]}")
                        buffered_content = []
                    last_log_time = current_time

        # Read from both stdout and stderr concurrently
        await asyncio.gather(read_stream(process.stdout, "stdout"), read_stream(process.stderr, "stderr"))

        # Wait for process to complete
        exit_code = await process.wait()

        if exit_code == 0:
            logger.info(f"Successfully pulled model {model_name}")

            # Give a moment for Ollama to finalize the model
            await asyncio.sleep(1)

            # Verify model is available
            if await is_model_available(model_name):
                logger.info(f"Verified model {model_name} is now available")
                return True
            else:
                logger.error(f"Model pull completed but {model_name} not found in list")
                return False
        else:
            logger.error(f"Failed to pull model {model_name} with exit code {exit_code}")
            return False
    except Exception as e:
        logger.error(f"Error pulling model {model_name}: {e}")
        return False


async def ensure_model_ready(model_name: str) -> bool:
    """Ensure the Ollama server is running and the model is available."""
    # Step 1: Make sure Ollama server is running
    if not await start_ollama_server():
        logger.error("Failed to start Ollama server")
        return False

    # Step 2: Check if model is already available
    if await is_model_available(model_name):
        logger.info(f"Model {model_name} is already available")
        return True

    # Step 3: Pull the model if not available
    logger.info(f"Model {model_name} not found, downloading...")
    if await pull_model(model_name):
        logger.info(f"Model {model_name} successfully pulled and ready")
        return True
    else:
        logger.error(f"Failed to pull model {model_name}")
        return False


async def stop_ollama_server() -> bool:
    """Stop the Ollama server."""
    try:
        # Find the ollama executable (just to check if it's installed)
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama is not installed")
            return False

        logger.info("Stopping Ollama server...")

        # Different process termination based on platform
        if os.name == "nt":  # Windows
            # Windows uses taskkill to terminate processes
            process = await asyncio.create_subprocess_exec(
                "taskkill", "/F", "/IM", "ollama.exe", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            )
        else:  # Unix/Linux/Mac
            # Use pkill to terminate all Ollama processes
            process = await asyncio.create_subprocess_exec("pkill", "-f", "ollama", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)

        _, stderr = await process.communicate()

        # Check if the Ollama server is still running
        if await is_ollama_running():
            logger.error(f"Failed to stop Ollama server: {stderr.decode() if stderr else 'unknown error'}")
            logger.info("Attempting stronger termination...")

            # Try one more time with stronger termination if it's still running
            if os.name == "nt":  # Windows
                process = await asyncio.create_subprocess_exec(
                    "taskkill",
                    "/F",
                    "/IM",
                    "ollama.exe",
                    "/T",  # /T terminates child processes as well
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )
            else:  # Unix/Linux/Mac
                process = await asyncio.create_subprocess_exec(
                    "pkill",
                    "-9",
                    "-f",
                    "ollama",  # SIGKILL for force termination
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

            _, stderr = await process.communicate()

            if await is_ollama_running():
                logger.error("Failed to forcefully stop Ollama server")
                return False

        # Wait a moment to ensure processes are actually terminated
        await asyncio.sleep(1)

        # Verify the server is no longer running
        if not await is_ollama_running():
            logger.info("Ollama server stopped successfully")
            return True
        else:
            logger.error("Failed to stop Ollama server: still running after termination attempts")
            return False
    except Exception as e:
        logger.error(f"Error stopping Ollama server: {str(e)}")
        cmd = "taskkill /F /IM ollama.exe" if os.name == "nt" else "pkill -f ollama"
        logger.error(f"Command attempted: {cmd}")
        return False


async def delete_model(model_name: str) -> bool:
    """Delete a model from Ollama.

    Args:
        model_name: The name of the model to delete

    Returns:
        bool: True if deletion was successful, False otherwise
    """
    try:
        # Find the ollama executable
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama is not installed")
            return False

        logger.info(f"Deleting model {model_name} from Ollama...")
        process = await asyncio.create_subprocess_exec(ollama_bin, "rm", model_name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
        stdout, stderr = await process.communicate()

        if process.returncode != 0:
            logger.error(f"Failed to delete model {model_name}: {stderr.decode()}")
            return False

        logger.info(f"Model {model_name} deleted successfully")
        return True
    except Exception as e:
        logger.error(f"Error deleting model {model_name}: {e}")
        return False
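

# --- Usage sketch (illustrative; not part of the original adapter) ---
# A minimal example of how these helpers could be driven end to end: start the
# server, ensure a model is pulled, list what Ollama reports, then shut the
# server down. The model name "llama3.2" is a placeholder assumption, not
# something this module requires.
if __name__ == "__main__":

    async def _demo() -> None:
        # Bring the server up and make sure the model is available before use
        if await ensure_model_ready("llama3.2"):
            models = await list_models()
            logger.info(f"Models reported by Ollama: {[m.get('name') for m in models]}")
        # Shut the server back down when finished
        await stop_ollama_server()

    asyncio.run(_demo())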