Update app.py
```diff
--- a/app.py
+++ b/app.py
@@ -91,19 +91,20 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
 
 ## main
 torch.set_grad_enabled(False)
-model_name = '
+model_name = 'LLAMA2-7B'
 
 # extract model info
 model_args = deepcopy(model_info[model_name])
+model_path = model_args.pop('model_path')
 original_prompt_template = model_args.pop('original_prompt_template')
 interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
-
+tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
 use_ctransformers = model_args.pop('ctransformers', False)
 AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
 
 # get model
-model = AutoModelClass.from_pretrained(
-tokenizer = AutoTokenizer.from_pretrained(
+model = AutoModelClass.from_pretrained(model_path, **model_args)
+tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
 
 # demo
 json_output = gr.JSON()
```
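For context, the pops in this diff imply that `model_info` is a registry keyed by display name, where each entry bundles a checkpoint path, the two prompt templates, an optional separate tokenizer repo, an optional `ctransformers` flag, and any leftover keys that get forwarded to `from_pretrained` as `**model_args`. Below is a minimal sketch of one such entry and the extraction logic; the `'LLAMA2-7B'` key and the pop order come from the diff, while the concrete values (repo id, template strings, `device_map`) are illustrative assumptions, not the Space's real config.

```python
from copy import deepcopy

# Hypothetical model_info entry; key names match the pops in app.py,
# values are placeholders for illustration only.
model_info = {
    'LLAMA2-7B': {
        'model_path': 'meta-llama/Llama-2-7b-chat-hf',   # assumed HF repo id
        'original_prompt_template': '<s>[INST] {prompt} [/INST]',       # assumed
        'interpretation_prompt_template': '<s>[INST] [X] [/INST] {prompt}',  # assumed
        # 'tokenizer': 'some-org/other-tokenizer',  # optional; defaults to model_path
        # 'ctransformers': True,                    # optional; switches to CAutoModelForCausalLM
        'device_map': 'auto',  # example of a leftover kwarg forwarded to from_pretrained
    },
}

# Same extraction sequence as the diff: each pop removes a config key,
# and whatever remains is passed through as **model_args.
model_args = deepcopy(model_info['LLAMA2-7B'])
model_path = model_args.pop('model_path')
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
use_ctransformers = model_args.pop('ctransformers', False)

print(model_args)  # -> {'device_map': 'auto'}: the kwargs left for from_pretrained
```

One operational note: the new tokenizer call reads `os.environ['hf_token']`, so the Space must define an `hf_token` secret (Llama 2 checkpoints are gated on the Hub); if the variable is unset, the indexing raises a `KeyError` before any download starts.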