DeepMostInnovations committed
Commit 0dc313a · verified · 1 Parent(s): 4834e38

Create opensource_inference.py

Files changed (1)
  1. opensource_inference.py +522 -0
opensource_inference.py ADDED
@@ -0,0 +1,522 @@
+import numpy as np
+import torch
+from typing import Any, Dict, List, Optional
+from transformers import (
+    AutoTokenizer,
+    AutoModel
+)
+from stable_baselines3 import PPO
+from llama_cpp import Llama
+import logging
+
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class SalesConversionPredictor:
+    """Sales conversion prediction class using Hugging Face models and llama.cpp"""
+
+    def __init__(self,
+                 model_path: str,
+                 embedding_model_name: str = "BAAI/bge-large-en-v1.5",
+                 llm_gguf_path: str = "your-org/llama-3.2-1b-instruct-GGUF",  # Hugging Face repo id of a GGUF model
+                 use_gpu: bool = True,
+                 n_gpu_layers: int = -1,  # -1 for all layers on GPU
+                 n_ctx: int = 2048,  # Context window size
+                 use_mini_embeddings: bool = True):
+        """Initialize with Hugging Face embeddings and llama.cpp LLM"""
+
+        # Set device for embeddings
+        self.device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
+        logger.info(f"Using device: {self.device}")
+
+        # Initialize embedding model (default: BAAI/bge-large-en-v1.5)
+        logger.info(f"Loading embedding model: {embedding_model_name}")
+        self.embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
+        self.embedding_model = AutoModel.from_pretrained(embedding_model_name).to(self.device)
+
+        # Check if model was trained with mini embeddings
+        self.use_mini_embeddings = use_mini_embeddings
+        self.embedding_dim = 1024  # BGE-large outputs 1024 dimensions
+
+        # Initialize LLM using llama-cpp (downloads the GGUF file from the Hugging Face repo)
+        logger.info(f"Loading LLM model from GGUF repo: {llm_gguf_path}")
+        self.llm = Llama.from_pretrained(
+            repo_id=llm_gguf_path,
+            filename="*Q4_K_M.gguf",  # glob pattern selecting the Q4_K_M quantization
+            n_gpu_layers=n_gpu_layers if use_gpu else 0,
+            n_ctx=n_ctx,
+            verbose=False,
+            use_mlock=True,  # Keep model in RAM
+            n_threads=None  # Use all available threads
+        )
+
+        # Load the trained PPO model (force CPU for PPO as recommended)
+        ppo_device = "cpu"
+        logger.info(f"Loading PPO model on {ppo_device}")
+        self.ppo_model = PPO.load(model_path, device=ppo_device)
+
+        # Store conversation states
+        self.conversation_states = {}
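+        # Observation layout assumed by the PPO policy (see create_state_vector):
+        # 1024-dim embedding + 5 metrics + 1 turn counter + 10 past probabilities = 1040 dims.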
+
+    def _normalize_history_format(self, history: List[Dict[str, str]]) -> List[Dict[str, str]]:
+        """Normalize history format to ensure consistency.
+
+        e.g. {'role': 'assistant', 'content': 'Hi'} -> {'speaker': 'sales_rep', 'message': 'Hi'}
+        """
+        normalized_history = []
+
+        for msg in history:
+            # Extract role/speaker
+            role = msg.get('role', msg.get('speaker', ''))
+
+            # Extract content/message
+            content = msg.get('content', msg.get('message', ''))
+
+            # Map role to the format expected by the model
+            if role in ['user', 'customer']:
+                speaker = 'user'
+            elif role in ['assistant', 'sales_rep']:
+                speaker = 'sales_rep'
+            else:
+                speaker = role  # Keep as is
+
+            normalized_history.append({
+                'speaker': speaker,
+                'message': content
+            })
+
+        return normalized_history
+
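+    # get_embedding mean-pools the token states into a single 1024-dim vector and
+    # L2-normalizes it, so downstream dot-product and cosine comparisons coincide.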
+    def get_embedding(self, text: str) -> np.ndarray:
+        """Get embedding for text using BAAI/bge-large-en-v1.5"""
+        try:
+            # Tokenize input, capped at what the embedding model supports
+            # (512 tokens for bge-large-en-v1.5, 8192 for bge-m3)
+            inputs = self.embedding_tokenizer(
+                text,
+                padding=True,
+                truncation=True,
+                return_tensors='pt',
+                max_length=min(self.embedding_tokenizer.model_max_length, 8192)
+            ).to(self.device)
+
+            # Get model outputs
+            with torch.no_grad():
+                model_output = self.embedding_model(**inputs)
+                # Get sentence embeddings via mean pooling (note: the reference
+                # BGE recipe pools the [CLS] token instead)
+                embeddings = model_output.last_hidden_state
+                attention_mask = inputs['attention_mask']
+
+                # Apply mean pooling over non-padding tokens
+                input_mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
+                sum_embeddings = torch.sum(embeddings * input_mask_expanded, 1)
+                sum_mask = input_mask_expanded.sum(1)
+
+                # Avoid division by zero
+                sum_mask = torch.clamp(sum_mask, min=1e-9)
+                mean_embeddings = sum_embeddings / sum_mask
+
+                # Normalize embeddings to unit length
+                embeddings = torch.nn.functional.normalize(mean_embeddings, p=2, dim=1)
+
+            # Move to CPU and convert to numpy
+            bge_embedding = embeddings.cpu().numpy()[0].astype(np.float32)
+
+            # BGE-large outputs 1024 dimensions by default
+            logger.info(f"BGE embedding shape: {bge_embedding.shape}")
+
+            # Ensure we have exactly 1024 dimensions
+            if len(bge_embedding) != 1024:
+                logger.warning(f"Expected 1024 dimensions, got {len(bge_embedding)}")
+                # Pad or truncate to 1024
+                if len(bge_embedding) < 1024:
+                    padded = np.zeros(1024, dtype=np.float32)
+                    padded[:len(bge_embedding)] = bge_embedding
+                    bge_embedding = padded
+                else:
+                    bge_embedding = bge_embedding[:1024]
+
+            return bge_embedding
+
+        except Exception as e:
+            logger.error(f"Error getting embedding: {str(e)}")
+            # Return zeros as fallback with expected dimensions
+            return np.zeros(1024, dtype=np.float32)
+
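+    # analyze_conversation_metrics asks the LLM for two 0-1 scores and parses them
+    # with defaults of 0.5, so a malformed LLM reply degrades gracefully.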
+    def analyze_conversation_metrics(self, history: List[Dict[str, str]]) -> Dict[str, float]:
+        """Analyze conversation to extract key metrics using LLM"""
+        try:
+            # Normalize history format first
+            normalized_history = self._normalize_history_format(history)
+
+            # Format conversation for analysis
+            conversation_text = ""
+            for msg in normalized_history:
+                speaker = msg.get('speaker', '')
+                message = msg.get('message', '')
+                conversation_text += f"{speaker}: {message}\n\n"
+
+            # Create prompt for metrics analysis
+            prompt = f"""Analyze this sales conversation and rate each metric from 0.0 to 1.0:
+
+customer_engagement:
+sales_effectiveness:
+
+Respond only with numbers in the format shown above.
+
+Conversation:
+{conversation_text}"""
+
+            # Get analysis from LLM
+            response = self.generate_llm_response(prompt, max_new_tokens=50)
+            logger.debug(f"Raw metrics response: {response}")
+
+            # Parse metrics
+            lines = response.strip().split('\n')
+            logger.debug(f"Parsed lines: {lines}")
+
+            engagement = 0.5
+            effectiveness = 0.5
+
+            for line in lines:
+                if 'customer_engagement' in line.lower():
+                    try:
+                        engagement = float(line.split(':')[-1].strip())
+                        # Clamp to [0, 1]
+                        engagement = max(0.0, min(1.0, engagement))
+                    except ValueError:
+                        pass
+                elif 'sales_effectiveness' in line.lower():
+                    try:
+                        effectiveness = float(line.split(':')[-1].strip())
+                        # Clamp to [0, 1]
+                        effectiveness = max(0.0, min(1.0, effectiveness))
+                    except ValueError:
+                        pass
+
+            return {
+                'customer_engagement': engagement,
+                'sales_effectiveness': effectiveness,
+                'conversation_length': len(normalized_history),
+                'outcome': 0.5,  # Unknown at inference time
+                'progress': min(1.0, len(normalized_history) / 20)
+            }
+
+        except Exception as e:
+            logger.error(f"Error analyzing conversation: {str(e)}")
+            # Return default values
+            return {
+                'customer_engagement': 0.5,
+                'sales_effectiveness': 0.5,
+                'conversation_length': len(history),
+                'outcome': 0.5,
+                'progress': min(1.0, len(history) / 20)
+            }
+
+    def generate_llm_response(self, prompt: str, max_new_tokens: int = 2048) -> str:
+        """Generate response using llama-cpp"""
+        try:
+            # Generate response (near-zero temperature keeps output effectively deterministic)
+            response = self.llm(
+                prompt,
+                max_tokens=max_new_tokens,
+                temperature=0.001,
+                top_p=0.95,
+                repeat_penalty=1.1,
+                stop=["User:", "Assistant:", "\n\n"]
+            )
+
+            # Extract generated text
+            generated_text = response['choices'][0]['text']
+
+            # Clean up the response
+            generated_text = generated_text.strip()
+
+            return generated_text
+
+        except Exception as e:
+            logger.error(f"Error generating LLM response: {str(e)}")
+            return "I apologize, but I encountered an error generating a response."
+
+    def create_state_vector(self,
+                            embedding: np.ndarray,
+                            metrics: Dict[str, float],
+                            turn_number: int,
+                            previous_probs: List[float]) -> np.ndarray:
+        """Create state vector for model input"""
+
+        # Create metric array (ensure all 5 metrics are included)
+        metric_values = np.array([
+            metrics['customer_engagement'],
+            metrics['sales_effectiveness'],
+            metrics['conversation_length'],
+            metrics['outcome'],
+            metrics['progress']
+        ], dtype=np.float32)
+
+        # Create turn info
+        turn_info = np.array([turn_number], dtype=np.float32)
+
+        # Pad probability history
+        padded_probs = np.zeros(10, dtype=np.float32)
+        if previous_probs:
+            # Keep only the 10 most recent probabilities
+            recent_probs = previous_probs[-10:]
+            padded_probs[:len(recent_probs)] = recent_probs
+
+        # Keep original 1024-dimensional embedding without expanding
+        if len(embedding) != 1024:
+            logger.warning(f"Unexpected embedding size: {len(embedding)}. Expected 1024. Creating zero embedding.")
+            embedding = np.zeros(1024, dtype=np.float32)
+
+        # Total expected: 1024 + 5 + 1 + 10 = 1040
+        combined = np.concatenate([
+            embedding,       # 1024 dimensions
+            metric_values,   # 5 dimensions
+            turn_info,       # 1 dimension
+            padded_probs     # 10 dimensions
+        ])
+
+        logger.info(f"State vector shape: {combined.shape} (expected: 1040)")
+        return combined
+
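+    # predict_conversion embeds the full transcript, rebuilds the 1040-dim state,
+    # and reads the deterministic PPO action as the conversion probability.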
+    def predict_conversion(self, conversation_id: str, history: List[Dict[str, str]],
+                           new_response: str) -> float:
+        """Predict conversion probability for a conversation"""
+        logger.info(f"Predicting conversion for conversation {conversation_id}")
+
+        # Normalize history format
+        normalized_history = self._normalize_history_format(history)
+
+        # Update history with new response
+        updated_history = normalized_history.copy()
+        updated_history.append({'speaker': 'sales_rep', 'message': new_response})
+
+        # Get full conversation text for embedding
+        full_text = " ".join([msg.get('message', '') for msg in updated_history])
+
+        # Get embedding (1024 dimensions)
+        embedding = self.get_embedding(full_text)
+        logger.info(f"Embedding shape: {embedding.shape}")
+
+        # Analyze conversation with updated history
+        metrics = self.analyze_conversation_metrics(updated_history)
+        logger.info(f"Metrics: engagement={metrics['customer_engagement']:.2f}, effectiveness={metrics['sales_effectiveness']:.2f}")
+
+        # Get turn number (each conversation turn includes user + assistant)
+        turn = len(updated_history) // 2
+
+        # Get previous probabilities
+        if conversation_id in self.conversation_states:
+            previous_probs = self.conversation_states[conversation_id]['probabilities']
+        else:
+            previous_probs = [0.5]  # Initial probability
+
+        # Create state vector
+        state_vector = self.create_state_vector(embedding, metrics, turn, previous_probs)
+
+        # Convert to numpy array if it's not already
+        if isinstance(state_vector, torch.Tensor):
+            state_vector = state_vector.cpu().numpy()
+
+        # Ensure it's a numpy array
+        state_vector = np.array(state_vector, dtype=np.float32)
+
+        # Log the final shape
+        logger.info(f"Final state vector shape: {state_vector.shape}")
+
+        # Predict using PPO model
+        try:
+            action, _ = self.ppo_model.predict(state_vector, deterministic=True)
+
+            # Extract the scalar value properly (avoids a NumPy deprecation warning)
+            if hasattr(action, 'item'):
+                predicted_prob = float(action.item())
+            elif isinstance(action, np.ndarray):
+                predicted_prob = float(action[0])
+            else:
+                predicted_prob = float(action)
+
+            # Ensure probability is between 0 and 1
+            predicted_prob = max(0.0, min(1.0, predicted_prob))
+
+        except Exception as e:
+            logger.error(f"Error during prediction: {str(e)}")
+            # Fallback prediction
+            predicted_prob = 0.5
+
+        # Update state
+        self.conversation_states[conversation_id] = {
+            'history': updated_history,
+            'probabilities': previous_probs + [predicted_prob]
+        }
+
+        logger.info(f"Predicted conversion probability: {predicted_prob:.4f}")
+        return predicted_prob
+
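+    # generate_response drives one full turn: build a plain-text prompt from the
+    # history, get the LLM reply, then score that reply with predict_conversion.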
+    def generate_response(self, conversation_id: str, history: List[Dict[str, str]],
+                          user_input: str, system_prompt: Optional[str] = None) -> str:
+        """Generate a response using llama-cpp and add conversion probability"""
+
+        # Normalize history format
+        normalized_history = self._normalize_history_format(history)
+
+        # Format conversation for the LLM
+        messages = []
+
+        # Add system prompt if provided
+        if system_prompt:
+            messages.append(f"System: {system_prompt}\n")
+        else:
+            messages.append("System: You are a helpful sales assistant.\n")
+
+        # Add conversation history
+        for msg in normalized_history:
+            speaker = msg.get('speaker', '')
+            message = msg.get('message', '')
+
+            if speaker == 'user':
+                messages.append(f"User: {message}\n")
+            elif speaker == 'sales_rep':
+                messages.append(f"Assistant: {message}\n")
+
+        # Add the latest user input
+        messages.append(f"User: {user_input}\n")
+        messages.append("Assistant: ")
+
+        # Create prompt
+        prompt = "".join(messages)
+
+        # Generate LLM response
+        llm_response = self.generate_llm_response(prompt, max_new_tokens=2048)
+        logger.debug(f"LLM response: {llm_response}")
+
+        # Add user message to history for prediction
+        history_with_user = history.copy()
+        history_with_user.append({'role': 'user', 'content': user_input})
+
+        # Predict conversion probability
+        probability = self.predict_conversion(conversation_id, history_with_user, llm_response)
+
+        # Format response with probability
+        formatted_response = self.format_response_with_probability(llm_response, probability)
+
+        return formatted_response
+
+    def format_response_with_probability(self, response: str, probability: float) -> str:
+        """Format response with conversion probability"""
+        probability_pct = probability * 100
+
+        if probability >= 0.38:
+            indicator = "🟢 Conversion Highly Likely"
+        elif probability >= 0.37:
+            indicator = "🟡 Good Conversion Potential"
+        elif probability >= 0.35:
+            indicator = "🟠 Moderate Conversion Potential"
+        else:
+            indicator = "🔴 Conversion Unlikely"
+
+        formatted_response = (
+            f"{response}\n\n"
+            f"---\n"
+            f"{indicator} ({probability_pct:.1f}%)\n"
+        )
+
+        return formatted_response
+
+    def format_prediction_result(self, probability: float) -> Dict[str, Any]:
+        """Format prediction result with status and suggestion"""
+        probability_pct = probability * 100
+
+        if probability >= 0.38:
+            status = "🟢 Conversion Highly Likely"
+            suggestion = "Follow up with specific next steps or a call to action."
+        elif probability >= 0.37:
+            status = "🟡 Good Conversion Potential"
+            suggestion = "Address any remaining concerns and guide toward a decision."
+        elif probability >= 0.35:
+            status = "🟠 Moderate Conversion Potential"
+            suggestion = "Focus on building value and addressing objections."
+        else:
+            status = "🔴 Conversion Unlikely"
+            suggestion = "Reframe the conversation or qualify needs better."
+
+        return {
+            "probability": probability,
+            "formatted_probability": f"{probability_pct:.1f}%",
+            "status": status,
+            "suggestion": suggestion
+        }
+
+
+# Example usage
+if __name__ == "__main__":
+    # Initialize predictor with GGUF model
+    predictor = SalesConversionPredictor(
+        model_path="/content/sales-conversion-model-reinf-learning/sales_conversion_model",  # path to the trained PPO model
+        embedding_model_name="BAAI/bge-m3",
+        llm_gguf_path="unsloth/gemma-3-4b-it-GGUF",  # Hugging Face GGUF repo id; update to your model
+        use_gpu=True,
+        n_gpu_layers=20,  # Offload 20 layers to the GPU (use -1 for all layers)
+        n_ctx=2048,  # Context window size
+        use_mini_embeddings=True  # Set to match how the model was trained
+    )
+
+    # Test with different conversation scenarios
+    scenarios = [
+        {
+            "id": "negative_outcome",
+            "history": [
+                {"role": "user", "content": "I'm looking for a CRM solution for my startup."},
+                {"role": "assistant", "content": "I'd be happy to help you find the right CRM solution. What's the size of your team and what are your main requirements?"},
+                {"role": "user", "content": "We're a team of 10 and need lead management and email automation."},
+                {"role": "assistant", "content": "Our CRM offers excellent lead management and built-in email automation that would be perfect for a team of 10. Let me show you how it works."},
+                {"role": "user", "content": "not interested, bye"}
+            ],
+            "response": "ok, thank you for the interest"
+        },
+        {
+            "id": "positive_outcome",
+            "history": [
+                {"role": "user", "content": "I need a project management tool urgently."},
+                {"role": "assistant", "content": "I can definitely help you with that! Our tool is designed for quick implementation. What's your main priority?"},
+                {"role": "user", "content": "We need to track tasks and deadlines for 20 people."},
+                {"role": "assistant", "content": "Perfect! Our solution handles that easily with real-time collaboration features. We can get you set up today with a free trial."},
+                {"role": "user", "content": "That sounds great! What's the pricing?"}
+            ],
+            "response": "For a team of 20, it's $299/month with all features included. You get 14 days free to test everything. Shall I send you the signup link?"
+        },
+        {
+            "id": "neutral_outcome",
+            "history": [
+                {"role": "user", "content": "Tell me about your software."},
+                {"role": "assistant", "content": "Our software helps businesses manage their operations more efficiently. What specific area are you looking to improve?"},
+                {"role": "user", "content": "Just browsing for now."}
+            ],
+            "response": "No problem! Feel free to explore our website for more information, and I'm here if you have any questions."
+        }
+    ]
+
+    # Test each scenario
+    for scenario in scenarios:
+        print(f"\n=== Testing Scenario: {scenario['id']} ===")
+
+        # Predict conversion probability
+        probability = predictor.predict_conversion(
+            conversation_id=scenario['id'],
+            history=scenario['history'],
+            new_response=scenario['response']
+        )
+
+        # Get formatted result
+        result = predictor.format_prediction_result(probability)
+
+        # Print results
+        print(f"Response: {scenario['response']}")
+        print(f"Probability: {result['formatted_probability']}")
+        print(f"Status: {result['status']}")
+        print(f"Suggestion: {result['suggestion']}")
+        print("-" * 50)
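+
+    # Illustrative extra (not part of the scenarios above): generate_response wraps
+    # LLM reply generation plus scoring in one call; this user_input is made up.
+    demo_reply = predictor.generate_response(
+        conversation_id="neutral_outcome",
+        history=scenarios[2]["history"],
+        user_input="Actually, could you email me a feature overview?",
+        system_prompt="You are a helpful sales assistant."
+    )
+    print(demo_reply)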