cpg716 committed
Commit b604130 · verified · 1 Parent(s): e85540a

Create test_llama4.py

Files changed (1): test_llama4.py +158 -0
test_llama4.py ADDED
@@ -0,0 +1,158 @@
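+"""Quick smoke tests for Llama 4 through Hugging Face transformers.
+
+Reads a Hugging Face access token from the HUGGINGFACE_TOKEN environment
+variable, runs a text-only generation test, then an image+text test with
+Llama 4 Scout. Example invocation (the token value is a placeholder):
+
+    HUGGINGFACE_TOKEN=hf_xxx python test_llama4.py
+"""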
+import torch
+from transformers import AutoProcessor, Llama4ForConditionalGeneration
+import time
+import os
+from huggingface_hub import login
+import requests
+from PIL import Image
+from io import BytesIO
+
+# Print versions for debugging
+import sys
+print(f"Python version: {sys.version}")
+print(f"PyTorch version: {torch.__version__}")
+import transformers
+print(f"Transformers version: {transformers.__version__}")
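+# Note: the Llama 4 model classes need a recent transformers release (4.51.0 or newer),
+# so check the version printed above if the imports fail.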
+
+# Get token from environment
+token = os.environ.get("HUGGINGFACE_TOKEN", "")
+if token:
+    print(f"Token found: {token[:5]}...")
+else:
+    print("No token found in environment variables!")
+
+# Login to Hugging Face
+try:
+    login(token=token)
+    print("Successfully logged in to Hugging Face Hub")
+except Exception as e:
+    print(f"Error logging in: {e}")
+
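+# Note: device_map="auto" shards the weights across available GPUs (spilling to
+# CPU RAM if needed), and bfloat16 roughly halves memory use versus float32.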
+# Test 1: Simple text generation with Llama 4
+def test_text_generation():
+    print("\n=== Testing Text Generation ===")
+    try:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # Scout is the smallest Llama 4 checkpoint
+
+        print(f"Loading tokenizer from {model_id}...")
+        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
+
+        print(f"Loading model from {model_id}...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            token=token,
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+        )
+
+        print("Model and tokenizer loaded successfully!")
+
+        # Simple prompt
+        prompt = "Write a short poem about artificial intelligence."
+
+        print(f"Generating text for prompt: '{prompt}'")
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        start_time = time.time()
+        outputs = model.generate(**inputs, max_new_tokens=100)
+        end_time = time.time()
+
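+        # skip_special_tokens drops BOS/EOS markers from the decoded text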
+        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        print(f"Generation completed in {end_time - start_time:.2f} seconds")
+        print(f"Result: {result}")
+        return True
+    except Exception as e:
+        print(f"Error in text generation test: {e}")
+        import traceback
+        print(traceback.format_exc())
+        return False
+
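+# Note: Llama 4 is natively multimodal; Llama4ForConditionalGeneration takes
+# interleaved image and text inputs prepared by AutoProcessor.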
+# Test 2: Image-text generation with Llama 4 Scout
+def test_image_text_generation():
+    print("\n=== Testing Image-Text Generation ===")
+    try:
+        model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # Scout is the smallest Llama 4 checkpoint
+
+        print(f"Loading processor from {model_id}...")
+        processor = AutoProcessor.from_pretrained(model_id, token=token)
+
+        print(f"Loading model from {model_id}...")
+        model = Llama4ForConditionalGeneration.from_pretrained(
+            model_id,
+            token=token,
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+        )
+
+        print("Model and processor loaded successfully!")
+
+        # Load a test image
+        print("Loading test image...")
+        img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
+        response = requests.get(img_url)
+        img = Image.open(BytesIO(response.content))
+        print(f"Image loaded: {img.size}")
+
+        # Simple prompt
+        prompt = "Describe this image in two sentences."
+
+        print(f"Creating messages with prompt: '{prompt}'")
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    # Pass the image URL directly; the processor's chat template
+                    # fetches and preprocesses the image itself
+                    {"type": "image", "url": img_url},
+                    {"type": "text", "text": prompt},
+                ]
+            },
+        ]
+
+        print("Applying chat template...")
+        inputs = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device)
+
+        print("Generating response...")
+        start_time = time.time()
+        outputs = model.generate(**inputs, max_new_tokens=100)
+        end_time = time.time()
+
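+        # Slice off the prompt tokens so only the newly generated text is decoded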
+        result = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
+
+        print(f"Generation completed in {end_time - start_time:.2f} seconds")
+        print(f"Result: {result}")
+        return True
+    except Exception as e:
+        print(f"Error in image-text generation test: {e}")
+        import traceback
+        print(traceback.format_exc())
+        return False
+
+if __name__ == "__main__":
+    print("Starting Llama 4 tests...")
+
+    # Run text generation test
+    text_success = test_text_generation()
+
+    # Run image-text generation test if text test succeeds
+    if text_success:
+        image_text_success = test_image_text_generation()
+    else:
+        print("Skipping image-text test due to text test failure")
+        image_text_success = False
+
+    # Summary
+    print("\n=== Test Summary ===")
+    print(f"Text Generation Test: {'SUCCESS' if text_success else 'FAILED'}")
+    print(f"Image-Text Generation Test: {'SUCCESS' if image_text_success else 'FAILED'}")
+
+    if text_success and image_text_success:
+        print("\nAll tests passed! Your Llama 4 Scout setup is working correctly.")
+    else:
+        print("\nSome tests failed. Please check the error messages above.")