yunusajib committed (verified)
Commit 932d067 · Parent: 971be40

update app and model

Files changed (2):
  1. app.py +23 -4
  2. llava_inference.py +86 -60
app.py CHANGED
@@ -1,13 +1,31 @@
 import gradio as gr
 from PIL import Image
+import os
+import sys
 from llava_inference import LLaVAHelper
 
-model = LLaVAHelper()
+# Add error handling for module imports
+try:
+    model = LLaVAHelper()
+except Exception as e:
+    print(f"Failed to initialize LLaVA model: {e}")
+    # Continue execution to show error in the UI
+    model = None
 
 def answer_question(image, question):
+    if model is None:
+        return "Model initialization failed. Please check server logs."
+
     if image is None or question.strip() == "":
         return "Please upload an image and enter a question."
-    return model.generate_answer(image, question)
+
+    try:
+        return model.generate_answer(image, question)
+    except Exception as e:
+        return f"Error processing request: {str(e)}"
+
+# Create examples directory if it doesn't exist
+os.makedirs("assets", exist_ok=True)
 
 demo = gr.Interface(
     fn=answer_question,
@@ -19,9 +37,10 @@ demo = gr.Interface(
     title="UK Public Transport Assistant",
     description="Upload an image of UK public transport signage (like train timetables or metro maps), and ask a question related to it. Powered by LLaVA-1.5.",
     examples=[
-        ["assets/example.jpg", "Where is platform 3?"],
+        # Only use examples if the example file exists
+        ["assets/example.jpg", "Where is platform 3?"] if os.path.exists("assets/example.jpg") else None
     ]
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)  # Added share=True to make it accessible on a public URL
llava_inference.py CHANGED
@@ -1,35 +1,58 @@
 from llava.model.builder import load_pretrained_model
 from llava.mm_utils import process_images, tokenizer_image_token
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig
 import torch
 import requests
 from PIL import Image
 from io import BytesIO
+import os
 
 class LLaVAHelper:
     def __init__(self, model_name="llava-hf/llava-1.5-7b-hf"):
-        # Use cache_dir to avoid issues with the default cache location
-        # and disable force_download to use cached versions when available
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            cache_dir="./model_cache",
-            force_download=False,
-            trust_remote_code=True
-        )
+        # Create cache directory if it doesn't exist
+        os.makedirs("./model_cache", exist_ok=True)
 
-        # Load model with same cache directory
-        self.model, self.image_processor, _ = load_pretrained_model(
-            model_name,
-            None,
-            cache_dir="./model_cache"
-        )
-        self.model.eval()
-
-        # Move model to appropriate device
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model.to(self.device)
-        print(f"Model loaded on {self.device}")
+        # First, try loading just the config to ensure the model is valid
+        try:
+            AutoConfig.from_pretrained(model_name)
+        except Exception as e:
+            print(f"Error loading model config: {e}")
+            # Try a different model version as fallback
+            model_name = "llava-hf/llava-1.5-13b-hf"
+            print(f"Trying alternative model: {model_name}")
 
+        try:
+            # Use specific tokenizer class to avoid issues
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                cache_dir="./model_cache",
+                use_fast=False,  # Use the Python implementation instead of the Rust one
+                legacy=True
+            )
+
+            # Load model with same cache directory and more explicit parameters
+            self.model, self.image_processor, _ = load_pretrained_model(
+                model_name,
+                None,
+                cache_dir="./model_cache",
+                load_8bit=False,
+                load_4bit=False,
+                device_map="auto"
+            )
+
+            self.model.eval()
+
+            # Move model to appropriate device
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            if self.device == "cpu":
+                # If using CPU, make sure model is in the right place
+                self.model = self.model.to(self.device)
+
+            print(f"Model loaded on {self.device}")
+        except Exception as e:
+            print(f"Detailed initialization error: {e}")
+            raise
+
     def generate_answer(self, image, question):
         """
         Generate a response to a question about an image
@@ -41,47 +64,50 @@ class LLaVAHelper:
         Returns:
             String response from the model
         """
-        # Handle image input (either PIL Image or path/URL)
-        if isinstance(image, str):
-            if image.startswith(('http://', 'https://')):
-                response = requests.get(image)
-                image = Image.open(BytesIO(response.content))
-            else:
-                image = Image.open(image)
-
-        # Preprocess image
-        image_tensor = process_images(
-            [image],
-            self.image_processor,
-            self.model.config
-        )[0].unsqueeze(0).to(self.device)
-
-        # Format prompt with question
-        prompt = f"###Human: <image>\n{question}\n###Assistant:"
-
-        # Tokenize prompt
-        input_ids = tokenizer_image_token(
-            prompt,
-            self.tokenizer,
-            return_tensors="pt"
-        ).to(self.device)
-
-        # Generate response
-        with torch.no_grad():
-            output_ids = self.model.generate(
-                input_ids=input_ids.input_ids,
-                images=image_tensor,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
-            )
+        try:
+            # Handle image input (either PIL Image or path/URL)
+            if isinstance(image, str):
+                if image.startswith(('http://', 'https://')):
+                    response = requests.get(image)
+                    image = Image.open(BytesIO(response.content))
+                else:
+                    image = Image.open(image)
+
+            # Preprocess image
+            image_tensor = process_images(
+                [image],
+                self.image_processor,
+                self.model.config
+            )[0].unsqueeze(0).to(self.device)
+
+            # Format prompt with question
+            prompt = f"###Human: <image>\n{question}\n###Assistant:"
+
+            # Tokenize prompt
+            input_ids = tokenizer_image_token(
+                prompt,
+                self.tokenizer,
+                return_tensors="pt"
+            ).to(self.device)
+
+            # Generate response
+            with torch.no_grad():
+                output_ids = self.model.generate(
+                    input_ids=input_ids.input_ids,
+                    images=image_tensor,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.9,
+                )
 
-        # Decode and extract response
-        output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        return output.split("###Assistant:")[-1].strip()
+            # Decode and extract response
+            output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+            return output.split("###Assistant:")[-1].strip()
+        except Exception as e:
+            return f"Error generating answer: {str(e)}"
 
-# Example usage
+# Example usage if __name__ == "__main__":
 if __name__ == "__main__":
     try:
         # Initialize model