Coool2 committed
Commit b2c7d9d · 1 parent: 3b3d057

Update agent.py

Files changed (1)
  1. agent.py +7 -10
agent.py CHANGED
@@ -66,15 +66,9 @@ import weave
 
 weave.init("gaia-llamaindex-agents")
 
-def get_max_memory_config(max_memory_per_gpu):
-    """Generate max_memory config for available GPUs"""
-    if torch.cuda.is_available():
-        num_gpus = torch.cuda.device_count()
-        max_memory = {}
-        for i in range(num_gpus):
-            max_memory[i] = max_memory_per_gpu
-        return max_memory
-    return None
+from transformers import get_max_memory
+
+max_mem = get_max_memory(0.9)  # 90% of each device's memory
 
 # Initialize models based on API availability
 def initialize_models(use_api_mode=False):
@@ -123,7 +117,10 @@ def initialize_models(use_api_mode=False):
             tokenizer_name="google/gemma-3-12b-it",
             device_map="auto",
             max_new_tokens=16000,
-            model_kwargs={"torch_dtype": "auto"},
+            model_kwargs={
+                "torch_dtype": "auto",
+                "max_memory": max_mem,  # Add this line
+            },
             generate_kwargs={
                 "temperature": 0.6,
                 "top_p": 0.95,