Tiago Caldeira committed
Commit 3d582fc · 1 Parent(s): 9f37a6e

different approach using unsloth model

Files changed (1)
1. app.py +24 -29
app.py CHANGED
@@ -1,28 +1,27 @@
 import torch
 import gradio as gr
-from unsloth import FastModel
-from transformers import TextStreamer, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 import textwrap
 
-# Load model (4-bit quantized)
-model, tokenizer = FastModel.from_pretrained(
-    model_name = "unsloth/gemma-3n-E4B-it",
-    dtype = None,  # Auto-detect FP16/32
-    max_seq_length = 1024,
-    load_in_4bit = True,
-    full_finetuning = False,
-    # token = "hf_..."  # Uncomment if model is gated
+model_id = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"
+
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Load model in full precision on CPU — no bitsandbytes
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="cpu",           # Force CPU
+    torch_dtype=torch.float32,  # Use FP32 to ensure CPU compatibility
 )
 
 model.eval()
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
 
-# 🛠️ Format output
+# Helper to format response nicely
 def print_response(text: str) -> str:
     return "\n".join(textwrap.fill(line, 100) for line in text.split("\n"))
 
-# 🔍 Inference function for Gradio
+# Inference function for Gradio
 def predict_text(system_prompt: str, user_prompt: str) -> str:
     messages = [
         {"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]},
@@ -34,23 +33,24 @@ def predict_text(system_prompt: str, user_prompt: str) -> str:
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
-        return_tensors="pt",
-    ).to(device)
+        return_tensors="pt"
+    ).to("cpu")
+
+    input_len = inputs["input_ids"].shape[-1]
 
     with torch.inference_mode():
-        outputs = model.generate(
+        output = model.generate(
             **inputs,
-            max_new_tokens=256,
-            temperature=1.0,
-            top_p=0.95,
-            top_k=64,
+            max_new_tokens=300,
+            do_sample=False,
+            use_cache=False  # Important for CPU compatibility
         )
 
-    generated = outputs[0][inputs["input_ids"].shape[-1]:]
+    generated = output[0][input_len:]
     decoded = tokenizer.decode(generated, skip_special_tokens=True)
     return print_response(decoded)
 
-# 🎛️ Gradio UI
+# Gradio UI
 demo = gr.Interface(
     fn=predict_text,
     inputs=[
@@ -58,10 +58,5 @@ demo = gr.Interface(
         gr.Textbox(lines=4, label="User Prompt", placeholder="Ask something..."),
     ],
     outputs=gr.Textbox(label="Gemma 3n Response"),
-    title="Gemma 3n Text-Only Chat",
-    description="Interact with the Gemma 3n language model using plain text. 4-bit quantized for efficiency.",
-)
-
-if __name__ == "__main__":
-    demo.launch()
+    title="Gemma 3n Chat (CPU-friendly
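The change above drops the unsloth FastModel loader and instead loads the unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit checkpoint through plain transformers, pinned to CPU in FP32, with greedy decoding and the KV cache disabled. Below is a minimal, self-contained sketch of that new path. The example prompts, the apply_chat_template call (inferred from the keyword arguments visible in the diff), and the final print are illustrative assumptions, not lines from the commit.

import torch
import textwrap
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",           # force CPU, as in the new app.py
    torch_dtype=torch.float32,  # FP32 for CPU compatibility
)
model.eval()

# Hypothetical prompts; the real app takes these from two Gradio textboxes.
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a concise assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "What does 4-bit quantization trade off?"}]},
]

# The apply_chat_template call is inferred from the kwargs shown in the diff.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to("cpu")

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    output = model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=False,   # greedy decoding, matching the diff
        use_cache=False,   # kept from the diff for CPU compatibility
    )

# Decode only the newly generated tokens and wrap lines at 100 characters.
decoded = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
print("\n".join(textwrap.fill(line, 100) for line in decoded.split("\n")))

Wrapping the generate call in predict_text and a gr.Interface, as the diff does, gives the CPU-friendly Space.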