wt002 committed · verified
Commit 5a5c64e · 1 Parent(s): fddf6b5

Update agent.py

Files changed (1): agent.py (+30 -1)
agent.py CHANGED
@@ -669,12 +669,41 @@ def process_question(question: str):
 
 
 # Build graph function
-def build_graph(provider: str, model_config: dict):
+provider = "huggingface"
+
+model_config = {
+    "repo_id": "HuggingFaceH4/zephyr-7b-beta",
+    "task": "text-generation",
+    "temperature": 0.7,
+    "max_new_tokens": 512,
+    "huggingfacehub_api_token": os.getenv("HF_TOKEN")
+}
+
+def build_graph(provider, model_config):
     from langgraph.prebuilt.tool_node import ToolNode
 
+    def get_llm(provider: str, config: dict):
+        if provider == "huggingface":
+            from langchain_huggingface import HuggingFaceEndpoint
+            return HuggingFaceEndpoint(
+                repo_id=config["repo_id"],
+                task=config["task"],
+                huggingfacehub_api_token=config["huggingfacehub_api_token"],
+                temperature=config["temperature"],
+                max_new_tokens=config["max_new_tokens"]
+            )
+        else:
+            raise ValueError(f"Unsupported provider: {provider}")
+
+
     llm = get_llm(provider, model_config)
     llm_with_tools = llm.bind_tools(tools)
 
+    # Continue building graph logic here...
+    # builder = StateGraph(...)
+    # return builder.compile()
+
+
     sys_msg = SystemMessage(content="You are a helpful assistant.")
 
     def assistant(state: MessagesState):
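
The rewritten build_graph still ends in placeholder comments ("builder = StateGraph(...)", "return builder.compile()") rather than a compiled graph. For orientation only, below is a minimal sketch of the standard LangGraph tool-calling loop those comments point at. Nothing in it comes from this commit: wire_agent_graph is a hypothetical name, tools is assumed to be the tool list agent.py defines elsewhere, llm_with_tools is the tool-bound model built above, and the assistant/system-message pieces already visible in the diff are repeated so the sketch stands on its own.

from langchain_core.messages import SystemMessage
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition


def wire_agent_graph(llm_with_tools, tools):
    # Hypothetical helper sketching the wiring hinted at by the placeholder
    # comments: an assistant node that calls the tool-bound model, a ToolNode
    # that executes any requested tool calls, and a loop between the two.
    sys_msg = SystemMessage(content="You are a helpful assistant.")

    def assistant(state: MessagesState):
        # Prepend the system prompt and let the tool-bound model respond.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Route to "tools" when the last message contains tool calls, otherwise end.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()


# e.g., the end of build_graph could then reduce to:
#     return wire_agent_graph(llm_with_tools, tools)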