ahmed-masry committed · verified
Commit b9672fb · 1 Parent(s): 4439944

Update app.py

Files changed (1):
  1. app.py +7 -5
app.py CHANGED
@@ -36,10 +36,9 @@ if not torch.cuda.is_available():
 
 
 if torch.cuda.is_available():
-    model_id = "meta-llama/Llama-2-13b-chat-hf"
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+    model_id = "ALLaM-AI/ALLaM-7B-Instruct-preview"
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
 
 
 @spaces.GPU
@@ -49,7 +48,7 @@ def generate(
     system_prompt: str = "",
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
-    top_p: float = 0.9,
+    top_p: float = 0.95,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
@@ -59,7 +58,10 @@ def generate(
     conversation += chat_history
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    inputs = tokenizer.apply_chat_template(conversation, tokenize=False)
+    inputs = tokenizer(inputs, return_tensors='pt', return_token_type_ids=False)
+
+    # input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
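
For context, here is a minimal sketch of how the updated loading and prompting path fits together after this commit. It is an assumption-based reconstruction, not the full app.py: MAX_INPUT_TOKEN_LENGTH, the generation call, and the input_ids = inputs["input_ids"] bridge are filled in for illustration, since the unchanged context lines above still read from input_ids while the new code builds inputs.

# Minimal sketch of the post-commit flow, under stated assumptions; not the
# author's full app.py. MAX_INPUT_TOKEN_LENGTH, the sampling call, and the
# `input_ids = inputs["input_ids"]` bridge are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value; defined elsewhere in app.py

model_id = "ALLaM-AI/ALLaM-7B-Instruct-preview"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

conversation = [{"role": "user", "content": "Hello!"}]

# New path from the + lines: render the chat template to a string first,
# then tokenize it in a separate call without token_type_ids.
prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
inputs = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False)

# The unchanged context lines still reference `input_ids`, so the rest of
# generate() presumably extracts the tensor from `inputs` (assumption):
input_ids = inputs["input_ids"]
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]

output_ids = model.generate(
    input_ids.to(model.device),
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.6,
    top_p=0.95,  # new default from this commit
    top_k=50,
    repetition_penalty=1.2,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Splitting apply_chat_template(tokenize=False) from the plain tokenizer call is presumably what allows return_token_type_ids=False to be passed, so a stray token_type_ids tensor is not forwarded to model.generate().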