Commit 271e92e by sparkleman · 1 parent: 6706d54

FIX: cpu fallback

Files changed (1): app.py (+17, -14)
app.py CHANGED
@@ -6,11 +6,6 @@ from snowflake import SnowflakeGenerator
 
 CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
 
-from pynvml import *
-
-nvmlInit()
-gpu_h = nvmlDeviceGetHandleByIndex(0)
-
 from typing import List, Optional, Union
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
@@ -40,6 +35,17 @@ import numpy as np
 import torch
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+if device == "cpu" and CONFIG.STRATEGY != "cpu":
+    logger.info(f"Cuda not found, fall back to cpu")
+    CONFIG.STRATEGY = "cpu"
+
+if "cuda" in CONFIG.STRATEGY:
+    from pynvml import *
+
+    nvmlInit()
+    gpu_h = nvmlDeviceGetHandleByIndex(0)
+
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -520,19 +526,17 @@ async def chatResponseStream(
     yield "data: [DONE]\n\n"
 
 
-
-
-
 @app.post("/api/v1/chat/completions")
 async def chat_completions(request: ChatCompletionRequest):
     completionId = str(next(CompletionIdGenerator))
     logger.info(f"[REQ] {completionId} - {request.model_dump()}")
 
-    def chatResponseStreamDisconnect():
-        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-        logger.info(
-            f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
-        )
+    def chatResponseStreamDisconnect():
+        if "cuda" in CONFIG.STRATEGY:
+            gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+            logger.info(
+                f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
+            )
 
     model_state = None
 
@@ -545,7 +549,6 @@ async def chat_completions(request: ChatCompletionRequest):
     else:
         r = await chatResponse(request, model_state, completionId)
 
-
     return r
 
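A note on the new fallback check: device is a torch.device object, not a string, so the unambiguous way to branch on it is the .type attribute, which is a plain str ("cpu" or "cuda"). Below is a minimal sketch of the same fallback logic in isolation; the Config class and the "cuda fp16" strategy string are hypothetical stand-ins for the app's pydantic-settings CONFIG object, not code from this repo.

import logging

import torch

logger = logging.getLogger(__name__)


class Config:
    # Hypothetical stand-in for the app's pydantic-settings CONFIG object.
    STRATEGY: str = "cuda fp16"


CONFIG = Config()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# device.type is a plain str, so the comparison does not depend on
# torch.device-to-string equality semantics.
if device.type == "cpu" and CONFIG.STRATEGY != "cpu":
    logger.info("CUDA not found, falling back to CPU")
    CONFIG.STRATEGY = "cpu"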
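With this commit, every NVML touchpoint sits behind the same strategy-string guard: the import and nvmlInit() at startup, and the VRAM status log inside chatResponseStreamDisconnect(). A sketch of that guarded probe, assuming the pynvml package (nvidia-ml-py) is installed; STRATEGY stands in for CONFIG.STRATEGY and log_vram_status is an illustrative name, not a function from app.py.

import logging

logger = logging.getLogger(__name__)

STRATEGY = "cuda fp16"  # stand-in for CONFIG.STRATEGY

gpu_h = None
if "cuda" in STRATEGY:
    # Import and initialize NVML only when a GPU strategy is requested,
    # so a CPU-only host never touches the driver.
    from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit

    nvmlInit()
    gpu_h = nvmlDeviceGetHandleByIndex(0)


def log_vram_status() -> None:
    # Mirrors chatResponseStreamDisconnect(): query NVML only when a handle
    # was actually created under a CUDA strategy.
    if gpu_h is not None:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        logger.info(
            f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
        )

Keying the guard off the strategy string rather than torch.cuda.is_available() keeps the probe consistent with how the model was actually loaded after the fallback above.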