# Optimized Devstral Inference Client
# Connects to the deployed model server without restarting it for each request
# Implements Modal best practices for lowest latency
import modal
import asyncio
import time
from typing import List, Dict, Any, Optional, AsyncGenerator
import json
# Connect to the deployed app
app = modal.App("devstral-inference-client")
# Image with OpenAI client for making requests
client_image = modal.Image.debian_slim(python_version="3.12").pip_install(
"openai>=1.76.0",
"aiohttp>=3.9.0",
"asyncio-throttle>=1.0.0"
)
class DevstralClient:
"""Optimized client for Devstral model with persistent connections and caching"""
def __init__(self, base_url: str, api_key: str):
self.base_url = base_url
self.api_key = api_key
self._session = None
self._response_cache = {}
self._conversation_cache = {}
async def __aenter__(self):
"""Async context manager entry - create persistent HTTP session"""
import aiohttp
connector = aiohttp.TCPConnector(
limit=100, # Connection pool size
keepalive_timeout=300, # Keep connections alive
enable_cleanup_closed=True
)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=120)
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Clean up session on exit"""
if self._session:
await self._session.close()
def _get_cache_key(self, messages: List[Dict], **kwargs) -> str:
"""Generate cache key for deterministic requests"""
key_data = {
"messages": json.dumps(messages, sort_keys=True),
"temperature": kwargs.get("temperature", 0.1),
"max_tokens": kwargs.get("max_tokens", 500),
"top_p": kwargs.get("top_p", 0.95)
}
return str(hash(json.dumps(key_data, sort_keys=True)))
async def generate_response(
self,
prompt: str,
system_prompt: Optional[str] = None,
temperature: float = 0.1,
max_tokens: int = 10000,
stream: bool = False,
use_cache: bool = True
) -> Any:  # returns str, or an async generator of text chunks when stream=True
"""Generate response from Devstral model with optimizations"""
# Build messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# Check cache for deterministic requests
if use_cache and temperature == 0.0:
cache_key = self._get_cache_key(messages, temperature=temperature, max_tokens=max_tokens)
if cache_key in self._response_cache:
print("π Cache hit - returning cached response")
return self._response_cache[cache_key]
# Prepare request payload
payload = {
"model": "mistralai/Devstral-Small-2505",
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
# "top_p": 0.95,
"stream": stream
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
start_time = time.perf_counter()
if stream:
# _stream_response is an async generator; return it without awaiting
return self._stream_response(payload, headers)
else:
return await self._complete_response(payload, headers, use_cache, start_time)
async def _complete_response(self, payload: Dict, headers: Dict, use_cache: bool, start_time: float) -> str:
"""Handle complete (non-streaming) response"""
async with self._session.post(
f"{self.base_url}/v1/chat/completions",
json=payload,
headers=headers
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"API Error {response.status}: {error_text}")
result = await response.json()
latency = (time.perf_counter() - start_time) * 1000
generated_text = result["choices"][0]["message"]["content"]
# Cache deterministic responses
if use_cache and payload["temperature"] == 0.0:
cache_key = self._get_cache_key(payload["messages"], temperature=payload["temperature"], max_tokens=payload["max_tokens"])
self._response_cache[cache_key] = generated_text
print(f"β‘ Response generated in {latency:.2f}ms")
return generated_text
async def _stream_response(self, payload: Dict, headers: Dict) -> AsyncGenerator[str, None]:
"""Handle streaming response"""
payload["stream"] = True
async with self._session.post(
f"{self.base_url}/v1/chat/completions",
json=payload,
headers=headers
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"API Error {response.status}: {error_text}")
buffer = ""
async for chunk in response.content.iter_chunks():
if chunk[0]:
buffer += chunk[0].decode()
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
return
try:
json_data = json.loads(data)
if "choices" in json_data and json_data["choices"]:
delta = json_data["choices"][0].get("delta", {})
if "content" in delta:
yield delta["content"]
except json.JSONDecodeError:
continue
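# The loop above parses Server-Sent Events as emitted by OpenAI-compatible servers
# such as vLLM. An illustrative (not captured) event line looks like:
#   data: {"choices":[{"delta":{"content":"Hello"},"index":0}]}
# and the stream is terminated by a final "data: [DONE]" line.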
async def batch_generate(
self,
prompts: List[str],
system_prompt: Optional[str] = None,
temperature: float = 0.1,
max_tokens: int = 500,
max_concurrent: int = 5
) -> List[str]:
"""Generate responses for multiple prompts with concurrency control"""
from asyncio_throttle import Throttler
# Throttle requests to avoid overwhelming the server
throttler = Throttler(rate_limit=max_concurrent, period=1.0)
async def generate_single(prompt: str) -> str:
async with throttler:
return await self.generate_response(
prompt=prompt,
system_prompt=system_prompt,
temperature=temperature,
max_tokens=max_tokens
)
# Execute all requests concurrently
tasks = [generate_single(prompt) for prompt in prompts]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions
processed_results = []
for result in results:
if isinstance(result, Exception):
processed_results.append(f"Error: {str(result)}")
else:
processed_results.append(result)
return processed_results
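# Illustrative standalone usage of DevstralClient from any asyncio program
# (assumes BASE_URL / API_KEY point at an already-deployed OpenAI-compatible server):
#
#   async def demo():
#       async with DevstralClient(BASE_URL, API_KEY) as client:
#           # One-shot completion
#           text = await client.generate_response("Summarize asyncio in two sentences.")
#           # Streaming: generate_response(..., stream=True) returns an async generator
#           stream = await client.generate_response("Count to five.", stream=True)
#           async for chunk in stream:
#               print(chunk, end="", flush=True)
#
#   asyncio.run(demo())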
@app.function(
image=client_image,
timeout=600, # 10 minutes
)
async def run_devstral_inference(
base_url: str,
api_key: str,
prompts: List[str],
system_prompt: Optional[str] = None,
mode: str = "single" # "single", "batch", "stream"
):
"""Main function to run optimized Devstral inference"""
async with DevstralClient(base_url, api_key) as client:
if mode == "single":
# Single prompt inference
if len(prompts) > 0:
response = await client.generate_response(
prompt=prompts[0],
system_prompt=system_prompt,
temperature=0.1,
max_tokens=10000
)
return {"response": response}
elif mode == "batch":
# Batch inference for multiple prompts
responses = await client.batch_generate(
prompts=prompts,
system_prompt=system_prompt,
temperature=0.1,
max_tokens=10000,
max_concurrent=5
)
return {"responses": responses}
elif mode == "stream":
# Streaming inference
if len(prompts) > 0:
full_response = ""
async for chunk in await client.generate_response(  # stream=True yields an async generator
prompt=prompts[0],
system_prompt=system_prompt,
temperature=0.1,
stream=True
):
full_response += chunk
print(chunk, end="", flush=True)
print() # New line after streaming
return {"response": full_response}
return {"error": "No prompts provided"}
# Convenient wrapper functions for different use cases
@app.function(image=client_image)
async def code_generation(prompt: str, base_url: str, api_key: str) -> str:
"""Optimized for code generation tasks"""
system_prompt = """You are an expert software engineer. Generate clean, efficient, and well-documented code.
Focus on best practices, performance, and maintainability. Include brief explanations for complex logic."""
async with DevstralClient(base_url, api_key) as client:
return await client.generate_response(
prompt=prompt,
system_prompt=system_prompt,
temperature=0.0, # Deterministic for code
max_tokens=10000,
use_cache=True # Cache code responses
)
@app.function(image=client_image)
async def chat_response(prompt: str, base_url: str, api_key: str) -> str:
"""Optimized for conversational responses"""
system_prompt = """You are a helpful, knowledgeable AI assistant. Provide clear, concise, and accurate responses.
Be conversational but professional."""
async with DevstralClient(base_url, api_key) as client:
return await client.generate_response(
prompt=prompt,
system_prompt=system_prompt,
temperature=0.3, # Slightly creative
max_tokens=10000
)
@app.function(image=client_image)
async def document_analysis(prompt: str, base_url: str, api_key: str) -> str:
"""Optimized for document analysis and summarization"""
system_prompt = """You are an expert document analyst. Provide thorough, structured analysis with key insights,
summaries, and actionable recommendations. Use clear formatting and bullet points."""
async with DevstralClient(base_url, api_key) as client:
return await client.generate_response(
prompt=prompt,
system_prompt=system_prompt,
temperature=0.1, # Factual and consistent
max_tokens=800
)
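# Once this app is deployed, these wrappers can also be invoked from other Modal code.
# A sketch (BASE_URL/API_KEY are placeholders; Function.from_name is available in
# recent Modal releases, older versions use Function.lookup):
#
#   fn = modal.Function.from_name("devstral-inference-client", "code_generation")
#   result = fn.remote(prompt="Write a quicksort in Python.", base_url=BASE_URL, api_key=API_KEY)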
# Local test client for development
@app.local_entrypoint()
def main(
base_url: str = "https://abhinav-bhatnagar--devstral-vllm-deployment-serve.modal.run",
api_key: str = "ak-zMwhIPjqvBj30jbm1DmKqx",
mode: str = "single"
):
"""Test the optimized Devstral inference client"""
test_prompts = [
"Write a Python function to calculate the Fibonacci sequence using memoization.",
"Explain the difference between REST and GraphQL APIs.",
"What are the key benefits of using Docker containers?",
"How does machine learning differ from traditional programming?",
"Write a SQL query to find the top 5 customers by total order value."
]
print(f"π Testing Devstral inference in {mode} mode...")
print(f"π‘ Connecting to: {base_url}")
if mode == "single":
# Test single inference
result = run_devstral_inference.remote(
base_url=base_url,
api_key=api_key,
prompts=[test_prompts[0]],
system_prompt="You are a helpful coding assistant.",
mode="single"
)
print("β
Single inference result:")
print(result["response"])
elif mode == "batch":
# Test batch inference
result = run_devstral_inference.remote(
base_url=base_url,
api_key=api_key,
prompts=test_prompts[:3], # Test with 3 prompts
system_prompt="You are a knowledgeable AI assistant.",
mode="batch"
)
print("β
Batch inference results:")
for i, response in enumerate(result["responses"]):
print(f"\nPrompt {i+1}: {test_prompts[i]}")
print(f"Response: {response}")
elif mode == "specialized":
# Test specialized functions
print("\nπ Testing Code Generation:")
code_result = code_generation.remote(
prompt="Create a Python class for a binary search tree with insert, search, and delete methods.",
base_url=base_url,
api_key=api_key
)
print(code_result)
print("\n㪠Testing Chat Response:")
chat_result = chat_response.remote(
prompt="What's the best way to learn machine learning for beginners?",
base_url=base_url,
api_key=api_key
)
print(chat_result)
print("\nπ Testing completed!")
if __name__ == "__main__":
# This allows running the client locally for testing
import sys
mode = sys.argv[1] if len(sys.argv) > 1 else "single"
main(mode=mode)
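# Typical invocations (assuming this file is saved as devstral_client.py):
#   modal run devstral_client.py                      # single-prompt test
#   modal run devstral_client.py --mode batch         # batch test
#   modal run devstral_client.py --mode specialized   # code/chat wrappers
# The local_entrypoint parameters (base_url, api_key, mode) become `modal run` CLI options.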