Elron committed
Commit b66f73e · verified · 1 Parent(s): e013bd8

Upload folder using huggingface_hub

Files changed (2):
  1. inference.py +18 -6
  2. version.py +1 -1
inference.py CHANGED
@@ -79,7 +79,7 @@ class StandardAPIParamsMixin(Artifact):
     n: Optional[int] = None
     parallel_tool_calls: Optional[bool] = None
     service_tier: Optional[Literal["auto", "default"]] = None
-    credentials: Optional[Dict[str, str]] = {}
+    credentials: Optional[Dict[str, str]] = None
     extra_headers: Optional[Dict[str, str]] = None
 
 
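Note: declaring `{}` as a class-level default is the classic shared-mutable-default pitfall in plain Python, which the new `None` default avoids; `prepare_engine()` further down this diff normalizes `None` back to an empty dict. A minimal sketch of the difference, with illustrative class names that are not part of the library:

```python
# Illustrative only: why a mutable class-level default is risky.
class SharedDefault:
    credentials: dict = {}  # a single dict object shared by every instance

class SafeDefault:
    credentials: dict = None  # replaced with a fresh dict when first needed

a, b = SharedDefault(), SharedDefault()
a.credentials["api_key"] = "secret"
print(b.credentials)  # {'api_key': 'secret'} - state leaked across instances

c = SafeDefault()
c.credentials = c.credentials or {}  # the normalization prepare_engine() now applies
c.credentials["api_key"] = "secret"
print(SafeDefault.credentials)  # None - nothing shared
```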
 
@@ -468,7 +468,7 @@ class LazyLoadMixin(Artifact):
 
 
 class HFGenerationParamsMixin(Artifact):
-    max_new_tokens: int
+    max_new_tokens: Optional[int] = None
     do_sample: bool = False
     temperature: Optional[float] = None
     top_p: Optional[float] = None
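With `max_new_tokens` now optional, callers can leave it unset and rely on the model's default generation length; code that assembles generation kwargs may want to drop unset fields rather than forward `None` values. A hedged sketch of that filtering pattern (the helper name is illustrative, not from the diff):

```python
from typing import Any, Dict, Optional

def build_generation_kwargs(
    max_new_tokens: Optional[int] = None,
    do_sample: bool = False,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> Dict[str, Any]:
    """Collect only the generation parameters that were actually set."""
    params = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
    }
    return {k: v for k, v in params.items() if v is not None}

print(build_generation_kwargs(do_sample=True, temperature=0.7))
# {'do_sample': True, 'temperature': 0.7}
```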
@@ -488,6 +488,7 @@ class HFInferenceEngineBase(
     TorchDeviceMixin,
 ):
     model_name: str
+    tokenizer_name: Optional[str] = None
     label: str
 
     n_top_tokens: int = 5
@@ -710,8 +711,9 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
     def _init_processor(self):
         from transformers import AutoTokenizer
 
+        tokenizer_name = self.tokenizer_name or self.model_name
         self.processor = AutoTokenizer.from_pretrained(
-            pretrained_model_name_or_path=self.model_name,
+            pretrained_model_name_or_path=tokenizer_name,
             use_fast=self.use_fast_tokenizer,
         )
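The new `tokenizer_name` field lets the tokenizer come from a different checkpoint than the model weights (for example a fine-tune that reuses its base model's tokenizer); when left unset, `_init_processor` falls back to `model_name` exactly as before. A usage sketch; the import path and checkpoint names are assumptions, only the field names come from this diff:

```python
from unitxt.inference import HFAutoModelInferenceEngine  # import path is an assumption

engine = HFAutoModelInferenceEngine(
    model_name="my-org/finetuned-checkpoint",              # hypothetical weights repo
    tokenizer_name="ibm-granite/granite-3.3-8b-instruct",  # tokenizer repo to reuse
    max_new_tokens=256,
)
# Omitting tokenizer_name keeps the previous behavior: the tokenizer is
# loaded from model_name.
```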
 
@@ -1120,6 +1122,7 @@ class HFPipelineBasedInferenceEngine(
     TorchDeviceMixin,
 ):
     model_name: str
+    tokenizer_name: Optional[str] = None
     label: str = "hf_pipeline_inference_engine"
 
     use_fast_tokenizer: bool = True
@@ -1217,8 +1220,8 @@ class HFPipelineBasedInferenceEngine(
         path = self.model_name
         if settings.hf_offline_models_path is not None:
             path = os.path.join(settings.hf_offline_models_path, path)
-
-        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        tokenizer_name = self.tokenizer_name or self.model_name
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         self.model = pipeline(
             model=path,
             task=self.task,
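The pipeline-based engine gets the same override, so the `AutoTokenizer` it constructs (presumably handed to the `pipeline(...)` call) can differ from the model path, which may be redirected to `hf_offline_models_path`. Another hedged usage sketch; the import path, checkpoint names, and the availability of `max_new_tokens` on this class are assumptions:

```python
from unitxt.inference import HFPipelineBasedInferenceEngine  # import path is an assumption

engine = HFPipelineBasedInferenceEngine(
    model_name="my-org/finetuned-checkpoint",              # hypothetical weights repo
    tokenizer_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",  # tokenizer repo to reuse
    max_new_tokens=128,                                    # assumed generation param
)
```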
@@ -3359,6 +3362,8 @@ class LiteLLMInferenceEngine(
         return get_model_and_label_id(self.model, self.label)
 
     def prepare_engine(self):
+        if self.credentials is None:
+            self.credentials = {}
         # Initialize the token bucket rate limiter
         self._rate_limiter = AsyncTokenBucket(
             rate=self.max_requests_per_second,
@@ -3474,7 +3479,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
     user requests.
 
     Current _supported_apis = ["watsonx", "together-ai", "open-ai", "aws", "ollama",
-    "bam", "watsonx-sdk", "rits", "vertex-ai"]
+    "bam", "watsonx-sdk", "rits", "vertex-ai","hf-local"]
 
     Args:
         provider (Optional):
@@ -3681,6 +3686,11 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
             "mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1",
             "gpt-4-1": "replicate/openai/gpt-4.1",
         },
+        "hf-local": {
+            "granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
+            "llama-3-3-8b-instruct": "meta-llama/Llama-3.3-8B-Instruct",
+            "SmolLM2-1.7B-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+        },
     }
     provider_model_map["watsonx"] = {
         k: f"watsonx/{v}" for k, v in provider_model_map["watsonx-sdk"].items()
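The new "hf-local" entry maps the same short aliases used for the remote providers onto Hugging Face Hub checkpoints that run locally through HFAutoModelInferenceEngine (see the engine map in the next hunk). A usage sketch; the import path is an assumption, while the provider name and model alias come from the map above:

```python
from unitxt.inference import CrossProviderInferenceEngine  # import path is an assumption

engine = CrossProviderInferenceEngine(
    provider="hf-local",              # routed to HFAutoModelInferenceEngine
    model="granite-3-3-8b-instruct",  # resolved to ibm-granite/granite-3.3-8b-instruct
)
```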
@@ -3698,12 +3708,14 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "azure": LiteLLMInferenceEngine,
         "vertex-ai": LiteLLMInferenceEngine,
         "replicate": LiteLLMInferenceEngine,
+        "hf-local": HFAutoModelInferenceEngine,
     }
 
     _provider_param_renaming = {
         "bam": {"max_tokens": "max_new_tokens", "model": "model_name"},
         "watsonx-sdk": {"model": "model_name"},
         "rits": {"model": "model_name"},
+        "hf-local": {"model": "model_name"},
     }
 
     def get_return_object(self, **kwargs):
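`_provider_param_renaming` is how the OpenAI-style argument names of `StandardAPIParamsMixin` get translated into each engine's constructor arguments; for "hf-local", `model` is renamed to the `model_name` that `HFAutoModelInferenceEngine` expects. A small illustrative sketch of that renaming step (not the library's exact code):

```python
def rename_params(provider: str, args: dict, renaming: dict) -> dict:
    """Rewrite standard-API argument names into provider-specific ones."""
    mapping = renaming.get(provider, {})
    return {mapping.get(key, key): value for key, value in args.items()}

renaming = {"hf-local": {"model": "model_name"}}
print(rename_params("hf-local", {"model": "ibm-granite/granite-3.3-8b-instruct"}, renaming))
# {'model_name': 'ibm-granite/granite-3.3-8b-instruct'}
```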
 
version.py CHANGED
@@ -1 +1 @@
-version = "1.26.1"
+version = "1.26.2"