mknolan committed on
Commit
920f22f
·
verified ·
1 Parent(s): 8325f5a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +182 -0
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import sys
4
+ import gradio as gr
5
+ from PIL import Image
6
+ import traceback
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from transformers.generation import GenerationConfig
9
+
10
# Startup banner.
_rule = "=" * 50
print(_rule)
print("InternVL2-8B IMAGE & TEXT ANALYSIS")
print(_rule)

# Report the interpreter / framework versions once, at import time.
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
_has_cuda = torch.cuda.is_available()
print(f"CUDA available: {_has_cuda}")

if not _has_cuda:
    print("CUDA is not available. This application requires GPU acceleration.")
else:
    _gpu_total = torch.cuda.device_count()
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU count: {_gpu_total}")
    for _gpu_index in range(_gpu_total):
        print(f"GPU {_gpu_index}: {torch.cuda.get_device_name(_gpu_index)}")

    # Memory figures are reported in decimal gigabytes (bytes / 1e9).
    # The total is read from device 0; allocated/reserved are queried
    # without a device argument.
    _total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    _allocated_gb = torch.cuda.memory_allocated() / 1e9
    _reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"Total GPU memory: {_total_gb:.2f} GB")
    print(f"Allocated GPU memory: {_allocated_gb:.2f} GB")
    print(f"Reserved GPU memory: {_reserved_gb:.2f} GB")
31
+
32
+ # Create a function to load the model
33
def _ensure_flash_attn_stub():
    """Register a placeholder ``flash_attn`` module only if the real one is absent.

    Returns:
        bool: True when a stub was inserted into ``sys.modules``, False when
        ``flash_attn`` was already imported or is importable.
    """
    import importlib.util
    import types

    if "flash_attn" in sys.modules:
        return False
    # A real installation that simply has not been imported yet must not be
    # shadowed by the stub, so probe the import system first.
    if importlib.util.find_spec("flash_attn") is not None:
        return False

    stub = types.ModuleType("flash_attn")
    stub.__version__ = "0.0.0-disabled"
    sys.modules["flash_attn"] = stub
    return True


def load_model():
    """Load the InternVL2-8B model, its tokenizer, and a sampling configuration.

    Any exception raised while loading is printed (with traceback) and
    reported to the caller through the return value instead of propagating.

    Returns:
        tuple: ``(model, tokenizer, generation_config)`` on success, or
        ``(None, None, None)`` if loading failed.
    """
    try:
        print("\nLoading InternVL2-8B model...")

        # Work around the optional flash_attn dependency without hiding a
        # genuine installation.
        if _ensure_flash_attn_stub():
            print("Created dummy flash_attn module to avoid dependency error")

        # Load the model and tokenizer.
        # NOTE: trust_remote_code=True runs Python code shipped in the model
        # repository; only use it with repositories you trust.
        model_path = "OpenGVLab/InternVL2-8B"
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Sampling settings used for every generation request.
        generation_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            repetition_penalty=1.0,
        )

        print("✓ Model and tokenizer loaded successfully!")
        return model, tokenizer, generation_config

    except Exception as e:
        # Top-level boundary: the UI shows an error page when it gets Nones.
        print(f"\n❌ ERROR loading model: {str(e)}")
        traceback.print_exc()
        return None, None, None
72
+
73
+ # Helper function to load and process an image
74
def load_image(image_path, processor=None):
    """Return an RGB image for a filesystem path, or pass a loaded image through.

    Args:
        image_path: A ``str`` or ``os.PathLike`` path to an image file, or an
            already-loaded image object.
        processor: Unused; kept so existing callers that pass it keep working.

    Returns:
        The image opened from the path and converted to RGB, or
        ``image_path`` itself, unchanged, when it is not a path.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # The context manager closes the underlying file; convert() returns
        # a separate image object, so it is safe to use after the file closes.
        with Image.open(image_path) as source:
            return source.convert('RGB')

    # Not a path: assume the caller already holds a loaded image.
    return image_path
83
+
84
+ # Function to analyze an image with text
85
# Preprocessing constants: ImageNet channel statistics and the square tile
# size used by the InternVL2 model-card example.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
_TILE_SIZE = 448

# Generation settings forwarded to the model when given a config object.
_GENERATION_KEYS = (
    "max_new_tokens",
    "do_sample",
    "temperature",
    "top_p",
    "repetition_penalty",
)


def _to_pixel_values(image, model):
    """Turn a PIL image into a normalised single-tile tensor on the model's device."""
    rgb = image.convert("RGB").resize((_TILE_SIZE, _TILE_SIZE), Image.BICUBIC)
    # bytearray gives torch a writable buffer; .float() below makes a copy.
    raw = torch.frombuffer(bytearray(rgb.tobytes()), dtype=torch.uint8)
    pixels = raw.view(_TILE_SIZE, _TILE_SIZE, 3).permute(2, 0, 1).float() / 255.0
    mean = torch.tensor(_IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(_IMAGENET_STD).view(3, 1, 1)
    batch = ((pixels - mean) / std).unsqueeze(0)  # shape (1, 3, H, W)
    return batch.to(device=model.device, dtype=model.dtype)


def _generation_kwargs(generation_config):
    """Return generation settings as the plain dict InternVL2's chat() expects."""
    if generation_config is None:
        return {}
    if isinstance(generation_config, dict):
        # Copy so the caller's dict is not mutated by the model.
        return dict(generation_config)
    return {
        key: getattr(generation_config, key)
        for key in _GENERATION_KEYS
        if getattr(generation_config, key, None) is not None
    }


def analyze_image(model, tokenizer, image, prompt, generation_config):
    """Ask the model a question about an image and return its text answer.

    Per the model card, InternVL2's ``chat`` takes
    ``(tokenizer, pixel_values, question, generation_config)`` with a
    preprocessed pixel tensor and a dict of generation settings; it has no
    ``messages`` parameter. The image is therefore preprocessed here and the
    generation settings are converted to a dict before the call.

    Args:
        model: Loaded InternVL2 chat model.
        tokenizer: Tokenizer matching ``model``.
        image: PIL image to analyse.
        prompt: Question or instruction about the image.
        generation_config: ``GenerationConfig``-like object, dict, or None.

    Returns:
        str: The model's response, or a message starting with
        ``"Error analyzing image: "`` if anything fails.
    """
    try:
        pixel_values = _to_pixel_values(image, model)
        gen_kwargs = _generation_kwargs(generation_config)

        # Generate a response
        response = model.chat(tokenizer, pixel_values, f"{prompt}", gen_kwargs)
        return response

    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        error_msg = f"Error analyzing image: {str(e)}"
        traceback.print_exc()
        return error_msg
100
+
101
+ # Create the Gradio interface
102
def create_interface():
    """Build the Gradio UI for image analysis.

    Loads the model once at startup. If loading fails, a minimal error page
    is returned instead of the full interface.

    Returns:
        gr.Blocks: The full analysis interface, or an error-only interface
        when the model could not be loaded.
    """
    # Load model at startup
    model, tokenizer, generation_config = load_model()

    if model is None:
        # If model loading failed, create a simple error interface
        with gr.Blocks(title="InternVL2 Chat - Error") as demo:
            gr.Markdown("# ❌ Error: Failed to load models")
            gr.Markdown("Please check the console for error details.")
        return demo

    # Predefined prompts for analysis
    prompts = [
        "Describe this image in detail.",
        "What text appears in this image? Please read and transcribe it accurately.",
        "Analyze the content of this image, including any text, pictures, and their relationships.",
        "What is the main subject of this image?",
        "Is there any text in this image? If so, what does it say?",
        "Describe the layout and visual elements of this document.",
        "Summarize the key information presented in this image.",
    ]

    # Create the full interface
    with gr.Blocks(title="InternVL2 Image Analysis") as demo:
        gr.Markdown("# 🖼️ InternVL2-8B Image & Text Analyzer")
        gr.Markdown("### Upload an image to analyze its visual content and text")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                prompt_input = gr.Dropdown(
                    choices=prompts,
                    value=prompts[0],
                    label="Select a prompt or enter your own below",
                    allow_custom_value=True,
                )
                custom_prompt = gr.Textbox(label="Custom prompt", placeholder="Enter your custom prompt here...")
                analyze_btn = gr.Button("Analyze Image", variant="primary")

            with gr.Column(scale=1):
                output = gr.Textbox(label="Analysis Results", lines=15)

        # Example images
        gr.Examples(
            examples=[
                ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/blip-image-demo.png", "What's in this image?"],
                ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/assets/130_vision_language_pretraining/fig_vision_language.jpg", "Describe this diagram in detail."],
            ],
            inputs=[input_image, custom_prompt],
        )

        # When prompt dropdown changes, update custom prompt
        prompt_input.change(fn=lambda x: x, inputs=prompt_input, outputs=custom_prompt)

        # Set up the click event for analysis
        def on_analyze_click(image, prompt_text, selected_prompt):
            """Run the analysis using the custom prompt, else the dropdown value."""
            if image is None:
                return "Please upload an image first."

            # Component values arrive as arguments; either may be None/empty.
            # Referencing `prompt_input` here would yield the component
            # object, not the selected text, so the dropdown is an event input.
            if prompt_text and prompt_text.strip():
                final_prompt = prompt_text
            elif selected_prompt and str(selected_prompt).strip():
                final_prompt = selected_prompt
            else:
                final_prompt = prompts[0]

            return analyze_image(model, tokenizer, image, final_prompt, generation_config)

        analyze_btn.click(
            fn=on_analyze_click,
            inputs=[input_image, custom_prompt, prompt_input],
            outputs=output,
        )

    return demo
174
+
175
+ # Main function
176
def main():
    """Configure the CUDA allocator, build the Gradio app, and start serving."""
    # Allocator setting intended to improve GPU memory management.
    # setdefault keeps any value already exported in the environment
    # instead of silently overwriting it.
    # NOTE(review): the module-level CUDA queries run before this point;
    # confirm the allocator still picks this setting up.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

    # Create and launch the interface
    demo = create_interface()
    # 0.0.0.0 binds to all network interfaces; no public share link is created.
    demo.launch(share=False, server_name="0.0.0.0")


# Main function
if __name__ == "__main__":
    main()