jefalod committed (verified)
Commit 90f142b · Parent: 582d7ae

Update app.py

Files changed (1): app.py (+84 -4)
app.py CHANGED
@@ -1,11 +1,91 @@
  # app.py

  import gradio as gr
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ import torch
+ from datasets import load_dataset
+ from transformers import (
+     AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, BitsAndBytesConfig, pipeline
+ )
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

- model = AutoModelForCausalLM.from_pretrained("trained_model", device_map="auto")
- tokenizer = AutoTokenizer.from_pretrained("trained_model")
+ # Load dataset
+ dataset = load_dataset(
+     "json",
+     data_files="https://huggingface.co/datasets/bitext/Bitext-customer-support-llm-chatbot-training-dataset/resolve/main/bitext_customer_support.jsonl",
+     split="train[:100]"  # Keep it small to avoid timeouts
+ )

+ def format(example):
+     return {
+         "text": f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}"
+     }
+
+ dataset = dataset.map(format)
+
+ # Tokenizer
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def tokenize(example):
+     tokens = tokenizer(example["text"], truncation=True, padding="max_length", max_length=512)
+     tokens["labels"] = tokens["input_ids"].copy()
+     return tokens
+
+ tokenized_dataset = dataset.map(tokenize, batched=True)
+
+ # QLoRA setup
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.float16
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     device_map="auto",
+     quantization_config=bnb_config
+ )
+
+ model.gradient_checkpointing_enable()
+ model = prepare_model_for_kbit_training(model)
+
+ lora_config = LoraConfig(
+     r=8,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     bias="none",
+     target_modules=["q_proj", "v_proj"],
+     task_type="CAUSAL_LM"
+ )
+
+ model = get_peft_model(model, lora_config)
+
+ # Training
+ training_args = TrainingArguments(
+     output_dir="trained_model",
+     per_device_train_batch_size=2,
+     gradient_accumulation_steps=4,
+     learning_rate=2e-4,
+     num_train_epochs=1,
+     logging_dir="./logs",
+     save_strategy="no",
+     bf16=True,
+     report_to="none",
+     optim="paged_adamw_8bit"
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset,
+     tokenizer=tokenizer
+ )
+
+ trainer.train()
+
+ # Inference pipeline
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

  def chatbot(instruction):
@@ -17,5 +97,5 @@ gr.Interface(
      fn=chatbot,
      inputs="text",
      outputs="text",
-     title="TinyLlama QLoRA Support Bot"
+     title="Fine-Tuned TinyLlama Bitext Chatbot"
  ).launch()
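
The unchanged lines between the two hunks (new-file lines 92-96, the body of chatbot) are hidden by the diff. A minimal sketch of what that context presumably contains, inferred from the visible pipe and gr.Interface usage: the prompt reuses the training template from format, while max_new_tokens=128 and the response-stripping step are assumptions for illustration, not part of the commit.

def chatbot(instruction):
    # Assumed: mirror the "### Instruction / ### Response" template used
    # during fine-tuning so the adapter sees the same framing.
    prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
    # Assumed generation settings; a text-generation pipeline echoes the
    # prompt, so everything before the response marker is stripped.
    output = pipe(prompt, max_new_tokens=128)[0]["generated_text"]
    return output.split("### Response:\n")[-1].strip()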