import asyncio
import json
import time
from typing import AsyncGenerator, Dict, List, Optional, Union

import modal

app = modal.App("devstral-inference-client")

client_image = modal.Image.debian_slim(python_version="3.12").pip_install(
    "openai>=1.76.0",
    "aiohttp>=3.9.0",
    "asyncio-throttle>=1.0.0",
)
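
# Note: this image only needs client-side HTTP libraries; the model itself is served by
# the separate vLLM deployment referenced by the default base_url in main() below.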


class DevstralClient:
    """Optimized client for the Devstral model with persistent connections and response caching."""

    def __init__(self, base_url: str, api_key: str):
        self.base_url = base_url
        self.api_key = api_key
        self._session = None
        self._response_cache = {}
        self._conversation_cache = {}

    async def __aenter__(self):
        """Async context manager entry - create a persistent HTTP session."""
        import aiohttp

        # Reuse one pooled session for all requests to avoid per-call connection setup.
        connector = aiohttp.TCPConnector(
            limit=100,                 # max pooled connections
            keepalive_timeout=300,     # keep idle connections open for 5 minutes
            enable_cleanup_closed=True,
        )
        self._session = aiohttp.ClientSession(
            connector=connector,
            timeout=aiohttp.ClientTimeout(total=120),
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clean up the session on exit."""
        if self._session:
            await self._session.close()

    def _get_cache_key(self, messages: List[Dict], **kwargs) -> str:
        """Generate a cache key for deterministic requests."""
        key_data = {
            "messages": json.dumps(messages, sort_keys=True),
            "temperature": kwargs.get("temperature", 0.1),
            "max_tokens": kwargs.get("max_tokens", 500),
            "top_p": kwargs.get("top_p", 0.95),
        }
        # hash() returns an int; stringify it so the annotated return type holds.
        return str(hash(json.dumps(key_data, sort_keys=True)))

    async def generate_response(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 10000,
        stream: bool = False,
        use_cache: bool = True,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Generate a response from the Devstral model.

        Returns the full completion as a string, or an async generator of text
        chunks when stream=True.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Only cache fully deterministic requests (temperature == 0.0).
        if use_cache and temperature == 0.0:
            cache_key = self._get_cache_key(messages, temperature=temperature, max_tokens=max_tokens)
            if cache_key in self._response_cache:
                print("Cache hit - returning cached response")
                return self._response_cache[cache_key]

        payload = {
            "model": "mistralai/Devstral-Small-2505",
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": stream,
        }

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        start_time = time.perf_counter()

        if stream:
            # Return the async generator directly; awaiting it would raise a TypeError.
            return self._stream_response(payload, headers)
        else:
            return await self._complete_response(payload, headers, use_cache, start_time)

    async def _complete_response(self, payload: Dict, headers: Dict, use_cache: bool, start_time: float) -> str:
        """Handle a complete (non-streaming) response."""
        async with self._session.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=headers,
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"API Error {response.status}: {error_text}")

            result = await response.json()
            latency = (time.perf_counter() - start_time) * 1000

            generated_text = result["choices"][0]["message"]["content"]

            # Store deterministic responses using the same key construction as the read path.
            if use_cache and payload["temperature"] == 0.0:
                cache_key = self._get_cache_key(
                    payload["messages"],
                    temperature=payload["temperature"],
                    max_tokens=payload["max_tokens"],
                )
                self._response_cache[cache_key] = generated_text

            print(f"Response generated in {latency:.2f}ms")
            return generated_text

    async def _stream_response(self, payload: Dict, headers: Dict) -> AsyncGenerator[str, None]:
        """Handle a streaming response."""
        payload["stream"] = True

        async with self._session.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=headers,
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"API Error {response.status}: {error_text}")

            # The server streams OpenAI-style server-sent events: "data: {...}" lines
            # terminated by "data: [DONE]".
            buffer = ""
            async for chunk in response.content.iter_chunks():
                if chunk[0]:
                    buffer += chunk[0].decode()
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)
                        if line.startswith("data: "):
                            data = line[6:]
                            if data == "[DONE]":
                                return
                            try:
                                json_data = json.loads(data)
                                if "choices" in json_data and json_data["choices"]:
                                    delta = json_data["choices"][0].get("delta", {})
                                    if "content" in delta:
                                        yield delta["content"]
                            except json.JSONDecodeError:
                                continue

    async def batch_generate(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 500,
        max_concurrent: int = 5,
    ) -> List[str]:
        """Generate responses for multiple prompts, throttled to max_concurrent requests per second."""
        from asyncio_throttle import Throttler

        throttler = Throttler(rate_limit=max_concurrent, period=1.0)

        async def generate_single(prompt: str) -> str:
            async with throttler:
                return await self.generate_response(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )

        tasks = [generate_single(prompt) for prompt in prompts]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Convert raised exceptions into error strings so callers get one result per prompt.
        processed_results = []
        for result in results:
            if isinstance(result, Exception):
                processed_results.append(f"Error: {str(result)}")
            else:
                processed_results.append(result)

        return processed_results
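

# Standalone usage sketch (outside Modal), assuming a reachable OpenAI-compatible server;
# BASE_URL and API_KEY here are placeholders, not values taken from this deployment:
#
#     async def _demo():
#         async with DevstralClient(BASE_URL, API_KEY) as client:
#             print(await client.generate_response("Say hello", temperature=0.0))
#
#     asyncio.run(_demo())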


@app.function(
    image=client_image,
    timeout=600,
)
async def run_devstral_inference(
    base_url: str,
    api_key: str,
    prompts: List[str],
    system_prompt: Optional[str] = None,
    mode: str = "single",
):
    """Main function to run optimized Devstral inference."""

    async with DevstralClient(base_url, api_key) as client:

        if mode == "single":
            if len(prompts) > 0:
                response = await client.generate_response(
                    prompt=prompts[0],
                    system_prompt=system_prompt,
                    temperature=0.1,
                    max_tokens=10000,
                )
                return {"response": response}

        elif mode == "batch":
            responses = await client.batch_generate(
                prompts=prompts,
                system_prompt=system_prompt,
                temperature=0.1,
                max_tokens=10000,
                max_concurrent=5,
            )
            return {"responses": responses}

        elif mode == "stream":
            if len(prompts) > 0:
                full_response = ""
                # generate_response returns an async generator when stream=True.
                stream_gen = await client.generate_response(
                    prompt=prompts[0],
                    system_prompt=system_prompt,
                    temperature=0.1,
                    stream=True,
                )
                async for chunk in stream_gen:
                    full_response += chunk
                    print(chunk, end="", flush=True)
                print()
                return {"response": full_response}

        return {"error": "No prompts provided or unrecognized mode"}


@app.function(image=client_image)
async def code_generation(prompt: str, base_url: str, api_key: str) -> str:
    """Optimized for code generation tasks."""
    system_prompt = """You are an expert software engineer. Generate clean, efficient, and well-documented code.
    Focus on best practices, performance, and maintainability. Include brief explanations for complex logic."""

    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.0,
            max_tokens=10000,
            use_cache=True,
        )


@app.function(image=client_image)
async def chat_response(prompt: str, base_url: str, api_key: str) -> str:
    """Optimized for conversational responses."""
    system_prompt = """You are a helpful, knowledgeable AI assistant. Provide clear, concise, and accurate responses.
    Be conversational but professional."""

    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.3,
            max_tokens=10000,
        )


@app.function(image=client_image)
async def document_analysis(prompt: str, base_url: str, api_key: str) -> str:
    """Optimized for document analysis and summarization."""
    system_prompt = """You are an expert document analyst. Provide thorough, structured analysis with key insights,
    summaries, and actionable recommendations. Use clear formatting and bullet points."""

    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.1,
            max_tokens=800,
        )


@app.local_entrypoint()
def main(
    base_url: str = "https://abhinav-bhatnagar--devstral-vllm-deployment-serve.modal.run",
    api_key: str = "ak-zMwhIPjqvBj30jbm1DmKqx",
    mode: str = "single",
):
    """Test the optimized Devstral inference client."""

    test_prompts = [
        "Write a Python function to calculate the Fibonacci sequence using memoization.",
        "Explain the difference between REST and GraphQL APIs.",
        "What are the key benefits of using Docker containers?",
        "How does machine learning differ from traditional programming?",
        "Write a SQL query to find the top 5 customers by total order value.",
    ]

    print(f"Testing Devstral inference in {mode} mode...")
    print(f"Connecting to: {base_url}")

    if mode == "single":
        result = run_devstral_inference.remote(
            base_url=base_url,
            api_key=api_key,
            prompts=[test_prompts[0]],
            system_prompt="You are a helpful coding assistant.",
            mode="single",
        )
        print("Single inference result:")
        print(result["response"])

    elif mode == "batch":
        result = run_devstral_inference.remote(
            base_url=base_url,
            api_key=api_key,
            prompts=test_prompts[:3],
            system_prompt="You are a knowledgeable AI assistant.",
            mode="batch",
        )
        print("Batch inference results:")
        for i, response in enumerate(result["responses"]):
            print(f"\nPrompt {i+1}: {test_prompts[i]}")
            print(f"Response: {response}")

    elif mode == "specialized":
        print("\nTesting Code Generation:")
        code_result = code_generation.remote(
            prompt="Create a Python class for a binary search tree with insert, search, and delete methods.",
            base_url=base_url,
            api_key=api_key,
        )
        print(code_result)

        print("\nTesting Chat Response:")
        chat_result = chat_response.remote(
            prompt="What's the best way to learn machine learning for beginners?",
            base_url=base_url,
            api_key=api_key,
        )
        print(chat_result)

    print("\nTesting completed!")


if __name__ == "__main__":
    import sys

    mode = sys.argv[1] if len(sys.argv) > 1 else "single"
    main(mode=mode)
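
# The local entrypoint is typically invoked through the Modal CLI, along the lines of
# (the file name below is a placeholder for wherever this script is saved):
#   modal run devstral_client.py --mode batch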