bobpopboom committed on
Commit 8c068ee · verified · 1 Parent(s): b84cd4b

ok ill do it myself then

Files changed (1): app.py (+21, -17)
app.py CHANGED
@@ -5,14 +5,29 @@ import torch
 model_id = "thrishala/mental_health_chatbot"
 
 try:
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        load_in_8bit=True,
-        device_map="auto",
-        torch_dtype=torch.float16
+        device_map="cpu",
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True,
+        max_memory={"cpu": "15GB"},
+        offload_folder="offload",
     )
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.model_max_length = 512  # Set maximum length
+
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        torch_dtype=torch.float16,
+        num_return_sequences=1,
+        do_sample=False,
+        truncation=True,
+        max_new_tokens=128
+    )
+
 except Exception as e:
     print(f"Error loading model: {e}")
     exit()
@@ -22,8 +37,6 @@ def respond(
     history,
     system_message,
     max_tokens,
-    temperature,
-    top_p,
 ):
     # Construct the prompt with clear separation
     prompt = f"{system_message}\n"
@@ -35,8 +48,7 @@ def respond(
     response = pipe(
         prompt,
         max_new_tokens=max_tokens,
-        temperature=temperature,
-        top_p=top_p,
+        do_sample=False,
         eos_token_id=tokenizer.eos_token_id,  # Use EOS token to stop generation
     )[0]["generated_text"]
 
@@ -55,14 +67,6 @@ demo = gr.ChatInterface(
             label="System message",
         ),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
     ],
 )
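
A minimal sketch of how the generation path introduced by this commit could be exercised, assuming `pipe` and `tokenizer` were built exactly as in the try-block above; the example prompt and the prompt-stripping step are illustrative assumptions, not code from this Space:

# Sketch only: assumes `pipe` and `tokenizer` exist as loaded in this commit.
prompt = (
    "You are a supportive mental health assistant.\n"  # illustrative system message
    "User: I have trouble sleeping lately.\nAssistant:"
)
result = pipe(
    prompt,
    max_new_tokens=128,
    do_sample=False,                      # greedy decoding, matching the commit
    eos_token_id=tokenizer.eos_token_id,  # stop at the end-of-sequence token
)[0]["generated_text"]
print(result[len(prompt):].strip())       # keep only the newly generated reply

Because do_sample=False selects greedy decoding, the removed temperature and top-p sliders would have had no effect, which is consistent with dropping them from the Gradio interface.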