mgbam commited on
Commit
659fc7b
·
verified ·
1 Parent(s): 6c5e8d4

Update api_clients.py

Browse files
Files changed (1) hide show
  1. api_clients.py +18 -7
api_clients.py CHANGED
@@ -21,15 +21,26 @@ from web_extraction import extract_website_content, enhance_query_with_search
21
 
22
  # HF Inference Client
23
  HF_TOKEN = os.getenv('HF_TOKEN')
 
 
24
 
25
  def get_inference_client(model_id):
26
- """Return an InferenceClient with provider based on model_id."""
27
- provider = "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else "auto"
28
- return InferenceClient(
29
- provider=provider,
30
- api_key=HF_TOKEN,
31
- bill_to="huggingface"
32
- )
 
 
 
 
 
 
 
 
 
33
 
34
  # Tavily Search Client
35
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
 
21
 
22
  # HF Inference Client
23
  HF_TOKEN = os.getenv('HF_TOKEN')
24
+ GROQ_API_KEY = os.getenv('GROQ_API_KEY')
25
+ FIREWORKS_API_KEY = os.getenv('FIREWORKS_API_KEY')
26
 
27
  def get_inference_client(model_id):
28
+ """Return an InferenceClient configured for Hugging Face, Groq, or Fireworks AI."""
29
+ if model_id == "moonshotai/Kimi-K2-Instruct":
30
+ return InferenceClient(
31
+ base_url="https://api.groq.com/openai/v1",
32
+ api_key=GROQ_API_KEY
33
+ )
34
+ elif model_id.startswith("fireworks/"):
35
+ return InferenceClient(
36
+ base_url="https://api.fireworks.ai/inference/v1",
37
+ api_key=FIREWORKS_API_KEY
38
+ )
39
+ else:
40
+ return InferenceClient(
41
+ model=model_id,
42
+ api_key=HF_TOKEN
43
+ )
44
 
45
  # Tavily Search Client
46
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')