Spaces:
Sleeping
Fix device placement and memory handling
Browse files
app.py
CHANGED
@@ -83,17 +83,19 @@ class ModelManager:
|
|
83 |
|
84 |
# Load the fine-tuned model
|
85 |
logger.info("Loading fine-tuned model (this may take a few minutes)...")
|
|
|
|
|
|
|
86 |
self.model = T5ForConditionalGeneration.from_pretrained(
|
87 |
"pdarleyjr/iplc-t5-model",
|
88 |
config=config,
|
89 |
-
|
90 |
-
torch_dtype=torch.float16,
|
91 |
low_cpu_mem_usage=True
|
92 |
-
)
|
93 |
logger.success("Model loaded successfully")
|
94 |
|
95 |
# Prepare model with accelerator
|
96 |
-
self.model = self.accelerator.prepare(self.model)
|
97 |
logger.success("Model prepared with accelerator")
|
98 |
|
99 |
# Log final memory usage
|
@@ -173,7 +175,10 @@ async def predict(request: PredictRequest) -> JSONResponse:
|
|
173 |
|
174 |
# Generate summary with error handling
|
175 |
try:
|
176 |
-
|
|
|
|
|
|
|
177 |
outputs = model_manager.model.generate(
|
178 |
input_ids,
|
179 |
max_length=256,
|
|
|
83 |
|
84 |
# Load the fine-tuned model
|
85 |
logger.info("Loading fine-tuned model (this may take a few minutes)...")
|
86 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
87 |
+
logger.info(f"Using device: {device}")
|
88 |
+
|
89 |
self.model = T5ForConditionalGeneration.from_pretrained(
|
90 |
"pdarleyjr/iplc-t5-model",
|
91 |
config=config,
|
92 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
|
|
93 |
low_cpu_mem_usage=True
|
94 |
+
).to(device)
|
95 |
logger.success("Model loaded successfully")
|
96 |
|
97 |
# Prepare model with accelerator
|
98 |
+
self.model = self.accelerator.prepare_model(self.model)
|
99 |
logger.success("Model prepared with accelerator")
|
100 |
|
101 |
# Log final memory usage
|
|
|
175 |
|
176 |
# Generate summary with error handling
|
177 |
try:
|
178 |
+
device = next(model_manager.model.parameters()).device
|
179 |
+
input_ids = input_ids.to(device)
|
180 |
+
|
181 |
+
with torch.no_grad(), model_manager.accelerator.autocast():
|
182 |
outputs = model_manager.model.generate(
|
183 |
input_ids,
|
184 |
max_length=256,
|