ash-171 committed on
Commit
7bf0815
·
verified ·
1 Parent(s): 5f0fa37

Update src/app/main_agent.py

Browse files
Files changed (1) hide show
  1. src/app/main_agent.py +13 -5
src/app/main_agent.py CHANGED
@@ -56,11 +56,11 @@ import re
56
  import torch
57
  from transformers import pipeline
58
  import os
 
 
59
 
60
- model_id = "google/gemma-3-4b-it"
61
-
62
- # Load the Gemma 3 model pipeline once
63
- pipe = pipeline("text-generation", model=model_id, use_auth_token=os.getenv("HF_TOKEN"))
64
 
65
  def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
66
  accent_tool = Tool(
@@ -101,7 +101,15 @@ def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
101
  },
102
  ],
103
  ]
104
- outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
 
 
 
 
 
 
 
 
105
  response_text = outputs[0]['generated_text']
106
 
107
  return AIMessage(content=response_text)
 
56
  import torch
57
  from transformers import pipeline
58
  import os
59
+ # Load model directly
60
+ from transformers import AutoTokenizer, AutoModelForCausalLM
61
 
62
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
63
+ model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
 
 
64
 
65
  def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
66
  accent_tool = Tool(
 
101
  },
102
  ],
103
  ]
104
+ inputs = tokenizer.apply_chat_template(
105
+ messages,
106
+ add_generation_prompt=True,
107
+ tokenize=True,
108
+ return_dict=True,
109
+ return_tensors="pt",
110
+ )
111
+ outputs = model.generate(**inputs, max_new_tokens=64)
112
+ outputs = tokenizer.batch_decode(outputs)
113
  response_text = outputs[0]['generated_text']
114
 
115
  return AIMessage(content=response_text)