# =============================================================================
# routing/tlm_manager.py
# =============================================================================
import torch
import torch.nn as nn
from typing import List, Dict, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from core.model import MambaModel
from core.config import MambaConfig
from utils.domain_configs import DomainConfigs
class SpecialistTLM:
    """Individual specialist Mamba TLM."""

    def __init__(self, specialist_id: int, config: MambaConfig, domain_info: Dict):
        self.specialist_id = specialist_id
        self.config = config
        self.domain_info = domain_info
        self.model = MambaModel(config)
        self.device = config.device

        # Move to device
        self.model.to(self.device)

    def encode(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Encode input ids and return a pooled hidden-state representation."""
        self.model.eval()
        with torch.no_grad():
            # Get embeddings
            x = self.model.embedding(input_ids)

            # Pass through Mamba layers
            for layer in self.model.layers:
                x = layer(x)

            # Apply final norm
            x = self.model.norm_f(x)

            # Return pooled representation (mean over sequence length)
            return x.mean(dim=1)  # [batch, d_model]

    def get_memory_usage(self) -> int:
        """Get model parameter memory usage in bytes."""
        return sum(p.numel() * p.element_size() for p in self.model.parameters())
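
# Illustrative usage of SpecialistTLM.encode (a sketch only; the MambaConfig
# fields and domain_info keys shown here are assumed from how they are used in
# this file, not taken from core/config.py):
#
#   config = MambaConfig()
#   specialist = SpecialistTLM(0, config, {"id": 0, "name": "mathematics"})
#   pooled = specialist.encode(torch.randint(0, config.vocab_size, (2, 64)))
#   # pooled has shape [2, config.d_model]
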
class TLMManager:
    """Manages the pool of 100 specialist Mamba TLMs."""

    def __init__(self, config: MambaConfig):
        self.config = config
        self.device = config.device

        # Create domain configurations
        self.domain_configs = DomainConfigs.get_domain_configs(config.num_specialists)

        # Shared components (created before the specialists so that weight
        # sharing during initialization can reference them)
        self.shared_embedding = None
        if config.shared_embedding:
            self.shared_embedding = nn.Embedding(config.vocab_size, config.d_model)
            self.shared_embedding.to(self.device)

        # Initialize specialists
        self.specialists: Dict[int, SpecialistTLM] = {}
        self._initialize_specialists()

        # Thread pool for parallel specialist encoding
        self.executor = ThreadPoolExecutor(max_workers=min(32, config.num_specialists))
    def _initialize_specialists(self):
        """Initialize all specialist TLMs."""
        print(f"Initializing {self.config.num_specialists} specialist TLMs...")

        for domain_config in self.domain_configs:
            specialist_id = domain_config["id"]

            # Create specialist-specific config
            specialist_config = DomainConfigs.create_specialist_config(
                self.config, specialist_id
            )

            # Create specialist TLM
            specialist = SpecialistTLM(
                specialist_id=specialist_id,
                config=specialist_config,
                domain_info=domain_config
            )

            self.specialists[specialist_id] = specialist

            if (specialist_id + 1) % 10 == 0:
                print(f"Initialized {specialist_id + 1}/{self.config.num_specialists} specialists")

        print("All specialists initialized!")

        # Apply weight sharing if enabled
        if self.config.hierarchical_sharing:
            self._apply_weight_sharing()
    def _apply_weight_sharing(self):
        """Apply hierarchical weight sharing between specialists."""
        print("Applying hierarchical weight sharing...")

        # Share embedding layers
        if self.shared_embedding is not None:
            for specialist in self.specialists.values():
                specialist.model.embedding.token_embedding = self.shared_embedding

        # Group specialists by domain similarity and share lower layers
        domain_groups = self._group_domains_by_similarity()

        for group in domain_groups:
            if len(group) > 1:
                # Use the first specialist's lower layers as the shared layers for the group
                reference_specialist = self.specialists[group[0]]
                shared_layers = reference_specialist.model.layers[:self.config.n_layers // 2]

                for specialist_id in group[1:]:
                    specialist = self.specialists[specialist_id]
                    for i, layer in enumerate(shared_layers):
                        specialist.model.layers[i] = layer
    def _group_domains_by_similarity(self) -> List[List[int]]:
        """Group domains by similarity for weight sharing."""
        # Simple grouping based on domain categories
        groups = {
            'stem': [],
            'programming': [],
            'language': [],
            'business': [],
            'other': []
        }

        for domain_config in self.domain_configs:
            domain_name = domain_config["name"].lower()
            specialist_id = domain_config["id"]

            if any(x in domain_name for x in ['math', 'physics', 'chemistry', 'biology']):
                groups['stem'].append(specialist_id)
            elif any(x in domain_name for x in ['python', 'javascript', 'systems']):
                groups['programming'].append(specialist_id)
            elif any(x in domain_name for x in ['writing', 'translation']):
                groups['language'].append(specialist_id)
            elif any(x in domain_name for x in ['business', 'legal']):
                groups['business'].append(specialist_id)
            else:
                groups['other'].append(specialist_id)

        return [group for group in groups.values() if len(group) > 1]
    def encode_parallel(self, routing_results: List[Dict]) -> Dict[int, List[Dict]]:
        """
        Encode chunks in parallel using the appropriate specialists.

        Args:
            routing_results: List of routing results from the router. Each entry
                contains the chunk text, a list of (specialist_id, confidence)
                pairs, and a chunk_id.

        Returns:
            Dict mapping chunk_id to the list of specialist encoding results.
        """
        futures = []

        for chunk_info in routing_results:
            chunk_text = chunk_info['text']
            specialists = chunk_info['specialists']
            chunk_id = chunk_info['chunk_id']

            # Create an encoding task for each relevant specialist
            for specialist_id, confidence in specialists:
                if specialist_id in self.specialists:
                    future = self.executor.submit(
                        self._encode_chunk,
                        chunk_text,
                        specialist_id,
                        confidence,
                        chunk_id
                    )
                    futures.append(future)

        # Collect results, skipping failed encodings
        encoded_results = []
        for future in as_completed(futures):
            try:
                result = future.result()
                if result is not None:
                    encoded_results.append(result)
            except Exception as e:
                print(f"Error in specialist encoding: {e}")

        # Group results by chunk_id
        grouped_results = {}
        for result in encoded_results:
            chunk_id = result['chunk_id']
            if chunk_id not in grouped_results:
                grouped_results[chunk_id] = []
            grouped_results[chunk_id].append(result)

        return grouped_results
    def _encode_chunk(self, text: str, specialist_id: int, confidence: float,
                      chunk_id: int) -> Optional[Dict]:
        """Encode a single chunk with a specific specialist."""
        try:
            specialist = self.specialists[specialist_id]

            # Tokenize text (placeholder: random ids -- integrate with the actual tokenizer)
            input_ids = torch.randint(0, 1000, (1, 100)).to(specialist.device)

            # Encode with the specialist
            encoding = specialist.encode(input_ids)

            return {
                'chunk_id': chunk_id,
                'specialist_id': specialist_id,
                'confidence': confidence,
                'encoding': encoding,
                'domain': specialist.domain_info['name']
            }
        except Exception as e:
            print(f"Error encoding chunk {chunk_id} with specialist {specialist_id}: {e}")
            return None
    def get_active_specialists(self) -> List[int]:
        """Get the list of currently active specialist IDs."""
        return list(self.specialists.keys())

    def get_specialist_info(self, specialist_id: int) -> Optional[Dict]:
        """Get information about a specific specialist, or None if it does not exist."""
        if specialist_id in self.specialists:
            specialist = self.specialists[specialist_id]
            return {
                'id': specialist_id,
                'domain': specialist.domain_info,
                'params': specialist.model.get_num_params(),
                'memory': specialist.get_memory_usage()
            }
        return None

    def get_total_parameters(self) -> int:
        """Get the total parameter count across all specialists."""
        return sum(s.model.get_num_params() for s in self.specialists.values())
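
# -----------------------------------------------------------------------------
# Example wiring (illustrative sketch only). MambaConfig defaults and the exact
# routing-result shape produced by the router are assumed here from how
# encode_parallel() consumes them; adjust to the real core/config.py and router.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    config = MambaConfig()  # assumed to provide num_specialists, vocab_size, d_model, device, ...
    manager = TLMManager(config)

    # One routing result per chunk: text, (specialist_id, confidence) pairs, chunk_id
    routing_results = [
        {
            "text": "Compute the eigenvalues of a 2x2 matrix.",
            "specialists": [(0, 0.82), (3, 0.11)],
            "chunk_id": 0,
        },
    ]

    grouped = manager.encode_parallel(routing_results)
    for chunk_id, results in grouped.items():
        for r in results:
            print(chunk_id, r["specialist_id"], r["domain"], tuple(r["encoding"].shape))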