bobpopboom committed on
Commit
fd1d420
·
verified ·
1 Parent(s): 3f3da62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,16 +1,18 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
  model_id = "thrishala/mental_health_chatbot"
6
 
7
  try:
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_id)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_id,
11
- load_in_8bit=True, # Load in 8-bit quantization
12
  device_map="auto", #Use GPU if available
13
- torch_dtype=torch.float16 #Use float 16 for additional memory reduction
14
  )
15
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
16
 
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
  import torch
4
 
5
  model_id = "thrishala/mental_health_chatbot"
6
 
7
  try:
8
+ quantization_config = BitsAndBytesConfig(
9
+ load_in_4bit=True, # switch to 4-bit quantization for lower memory use
10
+ )
11
  tokenizer = AutoTokenizer.from_pretrained(model_id)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_id,
14
+ quantization_config=quantization_config,
15
  device_map="auto", #Use GPU if available
 
16
  )
17
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
18