Aekanun commited on
Commit
43d09a2
·
1 Parent(s): b31bef1
Files changed (2) hide show
  1. app.py +100 -4
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,7 +1,103 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
4
+ from PIL import Image
5
 
6
+ # Global variables for model and processor
7
+ model = None
8
+ processor = None
9
 
10
+ def load_model_and_processor():
11
+ global model, processor
12
+
13
+ model_path = "Aekanun/thai-handwriting-llm"
14
+ base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
15
+
16
+ # BitsAndBytes config for 4-bit quantization
17
+ bnb_config = BitsAndBytesConfig(
18
+ load_in_4bit=True,
19
+ bnb_4bit_use_double_quant=True,
20
+ bnb_4bit_quant_type="nf4",
21
+ bnb_4bit_compute_dtype=torch.bfloat16
22
+ )
23
+
24
+ try:
25
+ # Load processor from base model
26
+ processor = AutoProcessor.from_pretrained(base_model_path)
27
+
28
+ # Load fine-tuned model
29
+ model = AutoModelForVision2Seq.from_pretrained(
30
+ model_path,
31
+ device_map="auto",
32
+ torch_dtype=torch.bfloat16,
33
+ quantization_config=bnb_config
34
+ )
35
+ return True
36
+ except Exception as e:
37
+ print(f"Error loading model: {str(e)}")
38
+ return False
39
+
40
+ def process_handwriting(image):
41
+ global model, processor
42
+
43
+ if image is None:
44
+ return "กรุณาอัพโหลดรูปภาพ"
45
+
46
+ try:
47
+ # Ensure image is in PIL format
48
+ if not isinstance(image, Image.Image):
49
+ image = Image.fromarray(image)
50
+
51
+ # Prepare prompt and messages
52
+ prompt = """Transcribe the Thai handwritten text from the provided image.
53
+ Only return the transcription in Thai language."""
54
+ messages = [
55
+ {
56
+ "role": "user",
57
+ "content": [
58
+ {"type": "text", "text": prompt},
59
+ {"type": "image", "image": image}
60
+ ],
61
+ }
62
+ ]
63
+
64
+ # Process input
65
+ text = processor.apply_chat_template(messages, tokenize=False)
66
+ inputs = processor(text=text, images=image, return_tensors="pt")
67
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
68
+
69
+ # Generate output
70
+ with torch.no_grad():
71
+ outputs = model.generate(
72
+ **inputs,
73
+ max_new_tokens=256,
74
+ do_sample=False,
75
+ pad_token_id=processor.tokenizer.pad_token_id
76
+ )
77
+
78
+ # Decode output
79
+ transcription = processor.decode(outputs[0], skip_special_tokens=True)
80
+ return transcription
81
+
82
+ except Exception as e:
83
+ return f"เกิดข้อผิดพลาด: {str(e)}"
84
+
85
+ # Load model when starting
86
+ print("กำลังโหลดโมเดล...")
87
+ model_loaded = load_model_and_processor()
88
+
89
+ if model_loaded:
90
+ # Create Gradio interface
91
+ demo = gr.Interface(
92
+ fn=process_handwriting,
93
+ inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
94
+ outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
95
+ title="Thai Handwriting to Text ด้วย LLaMA Vision",
96
+ description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ โดยใช้โมเดล LLaMA Vision ที่ fine-tune มาสำหรับภาษาไทย",
97
+ examples=[["example1.jpg"], ["example2.jpg"]]
98
+ )
99
+
100
+ if __name__ == "__main__":
101
+ demo.launch(share=True)
102
+ else:
103
+ print("ไม่สามารถโหลดโมเดลได้ กรุณาตรวจสอบ log")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.13.0
2
+ torch>=2.0.0
3
+ transformers>=4.34.0
4
+ Pillow>=9.0.0
5
+ bitsandbytes>=0.41.1
6
+ accelerate>=0.24.1