import time
import torch
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config,
                          OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                          XLNetLMHeadModel, XLNetTokenizer,
                          TransfoXLLMHeadModel, TransfoXLTokenizer,
                          CTRLLMHeadModel, CTRLTokenizer)
# Metadata for each serveable model: tokenizer/model classes, approximate GPU
# memory footprint in MB, and the checkpoint to load from the Hub.
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small"
    }, "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-community/openai-gpt",
        "identifier": "gpt"
    }, "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet"
    }, "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp"
    }, "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "gpt2/medium"
    }, "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "gpt2/large"
    }, "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small"
    }, "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "Salesforce/ctrl",
        "identifier": "ctrl"
    }, "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "openai-community/gpt2-xl",
        "identifier": "gpt2/xl"
    }, "pplm": {
        # PPLM reuses GPT-2 medium and needs hidden states exposed.
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "pplm",
        "configuration_options": {
            "config": GPT2Config,
            "options": {
                "output_hidden_states": True
            }
        }
    }
}
# Fixed per-GPU memory overhead (in MB) reserved on top of the model sizes.
memory_overhead = 500
class GPU:
    # Tracks the models assigned to a single CUDA device.
    def __init__(self, id):
        self.id = id
        self.models = []
        # Usable device memory in MB, minus a 1 GB safety margin.
        self.total_memory = torch.cuda.get_device_properties(
            "cuda:{}".format(id)).total_memory / 1_000_000 - 1_000
        print("INIT GPU WITH DEVICE", "cuda:{}".format(id))

    def register_model(self, model, cached_path=None):
        # Accept the model only if it fits in the remaining memory budget.
        if self.total_memory_used() + model["size"] < self.total_memory:
            model["device"] = "cuda:{}".format(self.id)
            if cached_path:
                model["cached_path"] = cached_path
            self.models.append(model)
            return True
        else:
            return False

    def total_memory_used(self):
        # Memory already committed on this device, including the fixed overhead.
        return sum([model["size"] for model in self.models]) + memory_overhead

    def __repr__(self):
        return str(
            [(model["checkpoint"], model["size"]) for model in self.models] +
            [str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"] +
            ["cuda:{}".format(self.id)]
        )
class GPUHandler:
    # Distributes the requested models across the available GPUs.
    def __init__(self, ids, model_list, gpu_ids, cached_models=None):
        # `ids` is not used here; devices are taken from `gpu_ids`.
        if cached_models is None:
            cached_models = {}
        self.gpus = [GPU(id) for id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))
        # Verify the requested models can all fit before assigning them for real.
        self.sanity_check([model_metadata[model] for model in model_list])
        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        # First-fit placement: use the first GPU with enough free memory.
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model, "in GPU", gpu)
                break
        else:
            raise ValueError("Could not load model", model["checkpoint"])

    def sanity_check(self, model_list):
        temp_gpus = [GPU(gpu.id) for gpu in self.gpus]
        for model in model_list:
            current_gpu_index = 0
            while current_gpu_index < len(temp_gpus):
                if not temp_gpus[current_gpu_index].register_model(model):
                    current_gpu_index += 1
                else:
                    break
            if current_gpu_index >= len(temp_gpus):
                raise RuntimeError("SANITY CHECK FAILED")
        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"