pdarleyjr committed
Commit 9bad572 · 1 Parent(s): 2e8b75e

Fix device placement and memory handling

Files changed (1)
app.py  +10 -5
app.py CHANGED
@@ -83,17 +83,19 @@ class ModelManager:
 
         # Load the fine-tuned model
         logger.info("Loading fine-tuned model (this may take a few minutes)...")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device}")
+
         self.model = T5ForConditionalGeneration.from_pretrained(
             "pdarleyjr/iplc-t5-model",
             config=config,
-            device_map="auto",
-            torch_dtype=torch.float16,
+            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
             low_cpu_mem_usage=True
-        )
+        ).to(device)
         logger.success("Model loaded successfully")
 
         # Prepare model with accelerator
-        self.model = self.accelerator.prepare(self.model)
+        self.model = self.accelerator.prepare_model(self.model)
         logger.success("Model prepared with accelerator")
 
         # Log final memory usage
@@ -173,7 +175,10 @@ async def predict(request: PredictRequest) -> JSONResponse:
 
     # Generate summary with error handling
     try:
-        with model_manager.accelerator.autocast():
+        device = next(model_manager.model.parameters()).device
+        input_ids = input_ids.to(device)
+
+        with torch.no_grad(), model_manager.accelerator.autocast():
             outputs = model_manager.model.generate(
                 input_ids,
                 max_length=256,
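
For reference, below is a minimal standalone sketch of the device-handling pattern this commit applies: pick the device up front, use fp16 only on CUDA, move the model with .to(device), and send inputs to wherever the model's parameters actually live before generating. It is not the app's code; the AutoTokenizer choice and the example prompt are assumptions for illustration, and the real app also wraps the model with an Accelerator.

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 halves GPU memory use; CPU inference generally needs fp32
dtype = torch.float16 if device == "cuda" else torch.float32

# AutoTokenizer is an assumption here; the commit only touches model loading.
tokenizer = AutoTokenizer.from_pretrained("pdarleyjr/iplc-t5-model")
model = T5ForConditionalGeneration.from_pretrained(
    "pdarleyjr/iplc-t5-model",
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
).to(device)
model.eval()

# Illustrative input; the real app builds its prompt from the request payload.
input_ids = tokenizer("summarize: example clinical note", return_tensors="pt").input_ids
# Move inputs to wherever the model's weights actually ended up.
input_ids = input_ids.to(next(model.parameters()).device)

with torch.no_grad():
    outputs = model.generate(input_ids, max_length=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Reading the device from next(model.parameters()).device rather than re-deriving it guards against a mismatch if the Accelerator re-places the model after loading, which is why the predict hunk does the same.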