Sushwetabm committed on
Commit ebe67a1 · 1 Parent(s): c16d4e7

updated model.py

Files changed (1):
  1. model.py +80 -145
model.py CHANGED
@@ -1,21 +1,21 @@
- # # model.py - Optimized version
- # from transformers import AutoTokenizer, AutoModelForCausalLM
- # import torch
- # from functools import lru_cache
- # import os
- # import asyncio
- # from concurrent.futures import ThreadPoolExecutor
- # import logging

- # logger = logging.getLogger(__name__)

- # # Global variables to store loaded model
- # _tokenizer = None
- # _model = None
- # _model_loading = False
- # _model_loaded = False

- # @lru_cache(maxsize=1)
  # def get_model_config():
  #     """Cache model configuration"""
  #     return {
@@ -27,167 +27,102 @@
  #         "low_cpu_mem_usage": True,
  #         "use_cache": True,
  #     }
-
- # def load_model_sync():
- #     """Synchronous model loading with optimizations"""
- #     global _tokenizer, _model, _model_loaded
-
- #     if _model_loaded:
- #         return _tokenizer, _model
-
- #     config = get_model_config()
- #     model_id = config["model_id"]
-
- #     logger.info(f"🔧 Loading model {model_id}...")
-
- #     try:
- #         # Set cache directory to avoid re-downloading
- #         cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
- #         os.makedirs(cache_dir, exist_ok=True)
-
- #         # Load tokenizer first (faster)
- #         logger.info("📝 Loading tokenizer...")
- #         _tokenizer = AutoTokenizer.from_pretrained(
- #             model_id,
- #             trust_remote_code=config["trust_remote_code"],
- #             cache_dir=cache_dir,
- #             use_fast=True,  # Use fast tokenizer if available
- #         )
-
- #         # Load model with optimizations
- #         logger.info("🧠 Loading model...")
- #         _model = AutoModelForCausalLM.from_pretrained(
- #             model_id,
- #             trust_remote_code=config["trust_remote_code"],
- #             torch_dtype=config["torch_dtype"],
- #             device_map=config["device_map"],
- #             low_cpu_mem_usage=config["low_cpu_mem_usage"],
- #             cache_dir=cache_dir,
- #             offload_folder="offload",
- #             offload_state_dict=True
- #         )
-
- #         # Set to evaluation mode
- #         _model.eval()
-
- #         _model_loaded = True
- #         logger.info("✅ Model loaded successfully!")
- #         return _tokenizer, _model
-
- #     except Exception as e:
- #         logger.error(f"❌ Failed to load model: {e}")
- #         raise
-
- # async def load_model_async():
- #     """Asynchronous model loading"""
- #     global _model_loading
-
- #     if _model_loaded:
- #         return _tokenizer, _model
-
- #     if _model_loading:
- #         # Wait for ongoing loading to complete
- #         while _model_loading and not _model_loaded:
- #             await asyncio.sleep(0.1)
- #         return _tokenizer, _model
-
- #     _model_loading = True
-
- #     try:
- #         # Run model loading in thread pool to avoid blocking
- #         loop = asyncio.get_event_loop()
- #         with ThreadPoolExecutor(max_workers=1) as executor:
- #             tokenizer, model = await loop.run_in_executor(
- #                 executor, load_model_sync
- #             )
- #             return tokenizer, model
- #     finally:
- #         _model_loading = False
-
- # def get_model():
- #     """Get the loaded model (for synchronous access)"""
- #     if not _model_loaded:
- #         return load_model_sync()
- #     return _tokenizer, _model
-
- # def is_model_loaded():
- #     """Check if model is loaded"""
- #     return _model_loaded
-
- # def get_model_info():
- #     """Get model information without loading"""
- #     config = get_model_config()
- #     return {
- #         "model_id": config["model_id"],
- #         "loaded": _model_loaded,
- #         "loading": _model_loading,
- #     }
-
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- from functools import lru_cache
- import logging
- import asyncio
- logger = logging.getLogger(__name__)
- _model_loaded = False
- _tokenizer = None
- _model = None
- @lru_cache(maxsize=1)
  def get_model_config():
      return {
          "model_id": "Salesforce/codet5p-220m",
          "trust_remote_code": True
      }
-
  def load_model_sync():
      global _tokenizer, _model, _model_loaded
-
      if _model_loaded:
          return _tokenizer, _model
-
      config = get_model_config()
      model_id = config["model_id"]
-
      try:
-         _tokenizer = AutoTokenizer.from_pretrained(model_id)
-         _model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
          _model.eval()
          _model_loaded = True
          return _tokenizer, _model
-
      except Exception as e:
          logger.error(f"❌ Failed to load model: {e}")
          raise

-
  async def load_model_async():
-     global _tokenizer, _model, _model_loaded
      if _model_loaded:
-         return
-
-     config = get_model_config()
-     model_id = config["model_id"]
-
      try:
-         _tokenizer = AutoTokenizer.from_pretrained(model_id)
-         _model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-         _model.eval()
-         _model_loaded = True
-         logger.info(f"✅ Model {model_id} loaded successfully.")
-     except Exception as e:
-         logger.error(f"❌ Failed to load model: {e}")
-         raise

  def get_model():
      if not _model_loaded:
-         raise ValueError("Model not loaded yet")
      return _tokenizer, _model

  def is_model_loaded():
      return _model_loaded

  def get_model_info():
      return {
-         "model_id": get_model_config()["model_id"],
          "loaded": _model_loaded,
-         "loading": not _model_loaded
      }
 
+ # model.py - Optimized version
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ from functools import lru_cache
+ import os
+ import asyncio
+ from concurrent.futures import ThreadPoolExecutor
+ import logging

+ logger = logging.getLogger(__name__)

+ # Global variables to store loaded model
+ _tokenizer = None
+ _model = None
+ _model_loading = False
+ _model_loaded = False

+ @lru_cache(maxsize=1)
  # def get_model_config():
  #     """Cache model configuration"""
  #     return {
  #         "low_cpu_mem_usage": True,
  #         "use_cache": True,
  #     }
  def get_model_config():
      return {
          "model_id": "Salesforce/codet5p-220m",
          "trust_remote_code": True
      }
  def load_model_sync():
+     """Synchronous model loading with optimizations"""
      global _tokenizer, _model, _model_loaded
+
      if _model_loaded:
          return _tokenizer, _model
+
      config = get_model_config()
      model_id = config["model_id"]
+
+     logger.info(f"🔧 Loading model {model_id}...")
+
      try:
+         # Set cache directory to avoid re-downloading
+         cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
+         os.makedirs(cache_dir, exist_ok=True)
+
+         # Load tokenizer first (faster)
+         logger.info("📝 Loading tokenizer...")
+         _tokenizer = AutoTokenizer.from_pretrained(
+             model_id,
+             trust_remote_code=config["trust_remote_code"],
+             cache_dir=cache_dir,
+             use_fast=True,  # Use fast tokenizer if available
+         )
+
+         # Load model with optimizations
+         logger.info("🧠 Loading model...")
+         _model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             trust_remote_code=config["trust_remote_code"],
+             torch_dtype=config["torch_dtype"],
+             device_map=config["device_map"],
+             low_cpu_mem_usage=config["low_cpu_mem_usage"],
+             cache_dir=cache_dir,
+             offload_folder="offload",
+             offload_state_dict=True
+         )
+
+         # Set to evaluation mode
          _model.eval()
+
          _model_loaded = True
+         logger.info("✅ Model loaded successfully!")
          return _tokenizer, _model
+
      except Exception as e:
          logger.error(f"❌ Failed to load model: {e}")
          raise

  async def load_model_async():
+     """Asynchronous model loading"""
+     global _model_loading
+
      if _model_loaded:
+         return _tokenizer, _model
+
+     if _model_loading:
+         # Wait for ongoing loading to complete
+         while _model_loading and not _model_loaded:
+             await asyncio.sleep(0.1)
+         return _tokenizer, _model
+
+     _model_loading = True
+
      try:
+         # Run model loading in thread pool to avoid blocking
+         loop = asyncio.get_event_loop()
+         with ThreadPoolExecutor(max_workers=1) as executor:
+             tokenizer, model = await loop.run_in_executor(
+                 executor, load_model_sync
+             )
+             return tokenizer, model
+     finally:
+         _model_loading = False

  def get_model():
+     """Get the loaded model (for synchronous access)"""
      if not _model_loaded:
+         return load_model_sync()
      return _tokenizer, _model

  def is_model_loaded():
+     """Check if model is loaded"""
      return _model_loaded

  def get_model_info():
+     """Get model information without loading"""
+     config = get_model_config()
      return {
+         "model_id": config["model_id"],
          "loaded": _model_loaded,
+         "loading": _model_loading,
      }
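
For reference, a minimal sketch of how the helpers restored in this commit could be consumed from an async web service. The FastAPI wiring, route names, and example prompt below are illustrative assumptions and are not part of the commit; the sketch only calls the public helpers in model.py (load_model_async, get_model, get_model_info), and the /complete route assumes the configured checkpoint loads and supports generate() under the Auto class chosen above.

    # usage_sketch.py - illustrative only; the FastAPI app and routes are assumptions,
    # not part of this commit. Only the helpers exported by model.py are exercised.
    from fastapi import FastAPI

    import model  # the module updated in this commit

    app = FastAPI()


    @app.on_event("startup")
    async def warm_up():
        # Kick off checkpoint loading without blocking startup; load_model_async
        # runs load_model_sync in a worker thread via run_in_executor.
        await model.load_model_async()


    @app.get("/health")
    def health():
        # Cheap status probe: returns model_id plus the loaded/loading flags.
        return model.get_model_info()


    @app.get("/complete")
    def complete(prompt: str = "def add(a, b):"):
        # Falls back to a synchronous load if the startup hook has not finished yet.
        tokenizer, mdl = model.get_model()
        inputs = tokenizer(prompt, return_tensors="pt")
        output_ids = mdl.generate(**inputs, max_new_tokens=64)
        return {"completion": tokenizer.decode(output_ids[0], skip_special_tokens=True)}

Awaiting load_model_async at startup keeps the event loop responsive while the blocking from_pretrained calls run, which is the point of the ThreadPoolExecutor indirection introduced by this commit.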