# Optimized Devstral Inference Client
# Connects to the deployed model without restarting the server for each request.
# Implements Modal best practices for low-latency inference.
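#
# Example invocation (a sketch; assumes this file is saved as
# devstral_inference_client.py and that the Devstral vLLM server is already
# deployed at the URL passed via --base-url or the default below):
#
#   modal run devstral_inference_client.py --mode single
#   modal run devstral_inference_client.py --mode batch
#   modal run devstral_inference_client.py --mode specialized
#
# The --mode flag maps onto the `mode` parameter of the local entrypoint at
# the bottom of this file.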

import modal
import asyncio
import time
from typing import List, Dict, Any, Optional, AsyncGenerator
import json

# Connect to the deployed app
app = modal.App("devstral-inference-client")

# Image with OpenAI client for making requests
client_image = modal.Image.debian_slim(python_version="3.12").pip_install(
    "openai>=1.76.0",
    "aiohttp>=3.9.0",
    "asyncio-throttle>=1.0.0"
)

class DevstralClient:
    """Optimized client for Devstral model with persistent connections and caching"""
    
    def __init__(self, base_url: str, api_key: str):
        self.base_url = base_url
        self.api_key = api_key
        self._session = None
        self._response_cache = {}
        self._conversation_cache = {}
        
    async def __aenter__(self):
        """Async context manager entry - create persistent HTTP session"""
        import aiohttp
        connector = aiohttp.TCPConnector(
            limit=100,  # Connection pool size
            keepalive_timeout=300,  # Keep connections alive
            enable_cleanup_closed=True
        )
        self._session = aiohttp.ClientSession(
            connector=connector,
            timeout=aiohttp.ClientTimeout(total=120)
        )
        return self
        
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clean up session on exit"""
        if self._session:
            await self._session.close()

    def _get_cache_key(self, messages: List[Dict], **kwargs) -> str:
        """Generate cache key for deterministic requests"""
        key_data = {
            "messages": json.dumps(messages, sort_keys=True),
            "temperature": kwargs.get("temperature", 0.1),
            "max_tokens": kwargs.get("max_tokens", 500),
            "top_p": kwargs.get("top_p", 0.95)
        }
        return str(hash(json.dumps(key_data, sort_keys=True)))

    async def generate_response(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 10000,
        stream: bool = False,
        use_cache: bool = True
    ):
        """Generate a response from the Devstral model with optimizations.

        Returns the full generated text, or an async generator of text chunks
        when stream=True.
        """
        
        # Build messages
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        
        # Check cache for deterministic requests
        if use_cache and temperature == 0.0:
            cache_key = self._get_cache_key(messages, temperature=temperature, max_tokens=max_tokens)
            if cache_key in self._response_cache:
                print("πŸ“Ž Cache hit - returning cached response")
                return self._response_cache[cache_key]
        
        # Prepare request payload
        payload = {
            "model": "mistralai/Devstral-Small-2505",
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            # "top_p": 0.95,
            "stream": stream
        }
        
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        
        start_time = time.perf_counter()
        
        if stream:
            # _stream_response is an async generator; return it directly so the
            # caller can iterate it with `async for` (it cannot be awaited).
            return self._stream_response(payload, headers)
        else:
            return await self._complete_response(payload, headers, use_cache, start_time)

    async def _complete_response(self, payload: Dict, headers: Dict, use_cache: bool, start_time: float) -> str:
        """Handle complete (non-streaming) response"""
        async with self._session.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"API Error {response.status}: {error_text}")
            
            result = await response.json()
            latency = (time.perf_counter() - start_time) * 1000
            
            generated_text = result["choices"][0]["message"]["content"]
            
            # Cache deterministic responses
            if use_cache and payload["temperature"] == 0.0:
                cache_key = self._get_cache_key(
                    payload["messages"],
                    temperature=payload["temperature"],
                    max_tokens=payload["max_tokens"],
                )
                self._response_cache[cache_key] = generated_text
            
            print(f"⚑ Response generated in {latency:.2f}ms")
            return generated_text

    async def _stream_response(self, payload: Dict, headers: Dict) -> AsyncGenerator[str, None]:
        """Handle streaming response"""
        payload["stream"] = True
        
        async with self._session.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"API Error {response.status}: {error_text}")
            
            buffer = ""
            async for chunk in response.content.iter_chunks():
                if chunk[0]:
                    buffer += chunk[0].decode()
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)
                        if line.startswith("data: "):
                            data = line[6:]
                            if data == "[DONE]":
                                return
                            try:
                                json_data = json.loads(data)
                                if "choices" in json_data and json_data["choices"]:
                                    delta = json_data["choices"][0].get("delta", {})
                                    if "content" in delta:
                                        yield delta["content"]
                            except json.JSONDecodeError:
                                continue

    async def batch_generate(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 500,
        max_concurrent: int = 5
    ) -> List[str]:
        """Generate responses for multiple prompts with concurrency control"""
        from asyncio_throttle import Throttler
        
        # Throttle requests to avoid overwhelming the server
        throttler = Throttler(rate_limit=max_concurrent, period=1.0)
        
        async def generate_single(prompt: str) -> str:
            async with throttler:
                return await self.generate_response(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    temperature=temperature,
                    max_tokens=max_tokens
                )
        
        # Execute all requests concurrently
        tasks = [generate_single(prompt) for prompt in prompts]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        # Handle any exceptions
        processed_results = []
        for result in results:
            if isinstance(result, Exception):
                processed_results.append(f"Error: {str(result)}")
            else:
                processed_results.append(result)
                
        return processed_results
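
# Minimal usage sketch for DevstralClient outside of a Modal function
# (illustrative only; BASE_URL and API_KEY below are placeholders, not values
# defined in this file):
#
#   async def _example():
#       async with DevstralClient(BASE_URL, API_KEY) as client:
#           # Non-streaming call returns the full text.
#           text = await client.generate_response("Summarize asyncio in two sentences.")
#           print(text)
#           # Streaming call: awaiting yields an async generator of chunks.
#           async for chunk in await client.generate_response("Count to five.", stream=True):
#               print(chunk, end="", flush=True)
#
#   asyncio.run(_example())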

@app.function(
    image=client_image,
    timeout=600,  # 10 minutes
)
async def run_devstral_inference(
    base_url: str,
    api_key: str,
    prompts: List[str],
    system_prompt: Optional[str] = None,
    mode: str = "single"  # "single", "batch", "stream"
):
    """Main function to run optimized Devstral inference"""
    
    async with DevstralClient(base_url, api_key) as client:
        
        if mode == "single":
            # Single prompt inference
            if len(prompts) > 0:
                response = await client.generate_response(
                    prompt=prompts[0],
                    system_prompt=system_prompt,
                    temperature=0.1,
                    max_tokens=10000
                )
                return {"response": response}
                
        elif mode == "batch":
            # Batch inference for multiple prompts
            responses = await client.batch_generate(
                prompts=prompts,
                system_prompt=system_prompt,
                temperature=0.1,
                max_tokens=10000,
                max_concurrent=5
            )
            return {"responses": responses}
            
        elif mode == "stream":
            # Streaming inference
            if len(prompts) > 0:
                full_response = ""
                # Await the coroutine to obtain the async generator, then
                # iterate over the streamed chunks.
                stream_gen = await client.generate_response(
                    prompt=prompts[0],
                    system_prompt=system_prompt,
                    temperature=0.1,
                    stream=True
                )
                async for chunk in stream_gen:
                    full_response += chunk
                    print(chunk, end="", flush=True)
                print()  # New line after streaming
                return {"response": full_response}
    
    return {"error": "No prompts provided"}

# Convenient wrapper functions for different use cases
@app.function(image=client_image)
async def code_generation(prompt: str, base_url: str, api_key: str) -> str:
    """Optimized for code generation tasks"""
    system_prompt = """You are an expert software engineer. Generate clean, efficient, and well-documented code. 
    Focus on best practices, performance, and maintainability. Include brief explanations for complex logic."""
    
    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.0,  # Deterministic for code
            max_tokens=10000,
            use_cache=True  # Cache code responses
        )

@app.function(image=client_image)
async def chat_response(prompt: str, base_url: str, api_key: str) -> str:
    """Optimized for conversational responses"""
    system_prompt = """You are a helpful, knowledgeable AI assistant. Provide clear, concise, and accurate responses. 
    Be conversational but professional."""
    
    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.3,  # Slightly creative
            max_tokens=10000
        )

@app.function(image=client_image)
async def document_analysis(prompt: str, base_url: str, api_key: str) -> str:
    """Optimized for document analysis and summarization"""
    system_prompt = """You are an expert document analyst. Provide thorough, structured analysis with key insights, 
    summaries, and actionable recommendations. Use clear formatting and bullet points."""
    
    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.1,  # Factual and consistent
            max_tokens=800
        )
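
# Sketch of calling a specialized wrapper remotely once the app is deployed
# (the base_url/api_key values are placeholders, not real credentials):
#
#   summary = document_analysis.remote(
#       prompt="Analyze this design doc: ...",
#       base_url="https://<your-deployment>.modal.run",
#       api_key="<your-api-key>",
#   )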

# Local test client for development
@app.local_entrypoint()
def main(
    base_url: str = "https://abhinav-bhatnagar--devstral-vllm-deployment-serve.modal.run",
    api_key: str = "ak-zMwhIPjqvBj30jbm1DmKqx",
    mode: str = "single"
):
    """Test the optimized Devstral inference client"""
    
    test_prompts = [
        "Write a Python function to calculate the Fibonacci sequence using memoization.",
        "Explain the difference between REST and GraphQL APIs.",
        "What are the key benefits of using Docker containers?",
        "How does machine learning differ from traditional programming?",
        "Write a SQL query to find the top 5 customers by total order value."
    ]
    
    print(f"πŸš€ Testing Devstral inference in {mode} mode...")
    print(f"πŸ“‘ Connecting to: {base_url}")
    
    if mode == "single":
        # Test single inference
        result = run_devstral_inference.remote(
            base_url=base_url,
            api_key=api_key,
            prompts=[test_prompts[0]],
            system_prompt="You are a helpful coding assistant.",
            mode="single"
        )
        print("βœ… Single inference result:")
        print(result["response"])
        
    elif mode == "batch":
        # Test batch inference
        result = run_devstral_inference.remote(
            base_url=base_url,
            api_key=api_key,
            prompts=test_prompts[:3],  # Test with 3 prompts
            system_prompt="You are a knowledgeable AI assistant.",
            mode="batch"
        )
        print("βœ… Batch inference results:")
        for i, response in enumerate(result["responses"]):
            print(f"\nPrompt {i+1}: {test_prompts[i]}")
            print(f"Response: {response}")
            
    elif mode == "specialized":
        # Test specialized functions
        print("\nπŸ“ Testing Code Generation:")
        code_result = code_generation.remote(
            prompt="Create a Python class for a binary search tree with insert, search, and delete methods.",
            base_url=base_url,
            api_key=api_key
        )
        print(code_result)
        
        print("\nπŸ’¬ Testing Chat Response:")
        chat_result = chat_response.remote(
            prompt="What's the best way to learn machine learning for beginners?",
            base_url=base_url,
            api_key=api_key
        )
        print(chat_result)
    
    print("\nπŸŽ‰ Testing completed!")

if __name__ == "__main__":
    # Allows running the client directly for quick testing; in normal use the
    # entrypoint is invoked via `modal run`, which provides the app context
    # that the .remote() calls rely on.
    import sys
    mode = sys.argv[1] if len(sys.argv) > 1 else "single"
    main(mode=mode)