root commited on
Commit
56ce0a4
·
1 Parent(s): 5473fae

remove some comments in sr_tp_modeling.py

Browse files
Files changed (1) hide show
  1. sr_tp_modeling.py +3 -5
sr_tp_modeling.py CHANGED
@@ -831,12 +831,10 @@ class SRV1ForCausalLM(SRV1PreTrainedModel):
831
  return reordered_past
832
 
833
  class SRV1ForCausalLMParallel(SRV1ForCausalLM):
834
- # def __init__(self, model_id:str, revision: Optional[str] = None,
835
- # quantize: Optional[str] = None,
836
- # dtype: Optional[torch.dtype] = None,
837
- # trust_remote_code: bool = False):
838
  def __init__(self, config, **kwargs):
839
- model_id = kwargs.get("pretrained_model_name_or_path", None)
 
 
840
  revision = kwargs.get("revision", None)
841
  trust_remote_code = kwargs.get("trust_remote_code", False)
842
  quantize = kwargs.get("quantize", None)
 
831
  return reordered_past
832
 
833
  class SRV1ForCausalLMParallel(SRV1ForCausalLM):
 
 
 
 
834
  def __init__(self, config, **kwargs):
835
+ model_id = kwargs.get("local_path", None)
836
+ if model_id is None:
837
+ model_id = kwargs.get("pretrained_model_name_or_path", None)
838
  revision = kwargs.get("revision", None)
839
  trust_remote_code = kwargs.get("trust_remote_code", False)
840
  quantize = kwargs.get("quantize", None)