root
commited on
Commit
·
56ce0a4
1
Parent(s):
5473fae
remove some comments in sr_tp_modeling.py
Browse files- sr_tp_modeling.py +3 -5
sr_tp_modeling.py
CHANGED
@@ -831,12 +831,10 @@ class SRV1ForCausalLM(SRV1PreTrainedModel):
|
|
831 |
return reordered_past
|
832 |
|
833 |
class SRV1ForCausalLMParallel(SRV1ForCausalLM):
|
834 |
-
# def __init__(self, model_id:str, revision: Optional[str] = None,
|
835 |
-
# quantize: Optional[str] = None,
|
836 |
-
# dtype: Optional[torch.dtype] = None,
|
837 |
-
# trust_remote_code: bool = False):
|
838 |
def __init__(self, config, **kwargs):
|
839 |
-
model_id = kwargs.get("
|
|
|
|
|
840 |
revision = kwargs.get("revision", None)
|
841 |
trust_remote_code = kwargs.get("trust_remote_code", False)
|
842 |
quantize = kwargs.get("quantize", None)
|
|
|
831 |
return reordered_past
|
832 |
|
833 |
class SRV1ForCausalLMParallel(SRV1ForCausalLM):
|
|
|
|
|
|
|
|
|
834 |
def __init__(self, config, **kwargs):
|
835 |
+
model_id = kwargs.get("local_path", None)
|
836 |
+
if model_id is None:
|
837 |
+
model_id = kwargs.get("pretrained_model_name_or_path", None)
|
838 |
revision = kwargs.get("revision", None)
|
839 |
trust_remote_code = kwargs.get("trust_remote_code", False)
|
840 |
quantize = kwargs.get("quantize", None)
|