# Apollo_GenAI / app.py
# Hugging Face file-view header (commented out so the module parses):
# last commit 852d6ee "Update app.py" by VishalD1234 (verified); file size 5.08 kB.
import logging

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
# Model configuration: Hugging Face checkpoint, target device, and compute dtype.
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
if torch.cuda.is_available():
    DEVICE = "cuda"
    # bfloat16 is only selected on GPUs whose compute-capability major version is >= 8;
    # older GPUs fall back to float16.
    TORCH_TYPE = (
        torch.bfloat16
        if torch.cuda.get_device_capability()[0] >= 8
        else torch.float16
    )
else:
    DEVICE = "cpu"
    TORCH_TYPE = torch.float16


# Load model and tokenizer once, at import time.
def load_model():
    """Load the 4-bit (NF4, double-quantized) CogVLM2 caption model and its tokenizer.

    Returns:
        tuple: ``(model, tokenizer)``; the model has been switched to eval mode.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=TORCH_TYPE,
    )
    # trust_remote_code lets the checkpoint supply its own modeling code
    # (predict() relies on the custom build_conversation_input_ids method).
    tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    llm = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
        torch_dtype=TORCH_TYPE,
    )
    llm.eval()
    return llm, tok


model, tokenizer = load_model()


# Candidate delay causes for each tire-building step, keyed by the same
# "Step N" labels offered in the UI dropdown.
DELAY_REASONS = {
    "Step 1": [
        "Delay in Bead Insertion",
        "Lack of raw material",
    ],
    "Step 2": [
        "Inner Liner Adjustment by Technician",
        "Person rebuilding defective Tire Sections",
    ],
    "Step 3": [
        "Manual Adjustment in Ply1 apply",
        "Technician repairing defective Tire Sections",
    ],
    "Step 4": [
        "Delay in Bead set",
        "Lack of raw material",
    ],
    "Step 5": [
        "Delay in Turnup",
        "Lack of raw material",
    ],
    "Step 6": [
        "Person Repairing sidewall",
        "Person rebuilding defective Tire Sections",
    ],
    "Step 7": [
        "Delay in sidewall stitching",
        "Lack of raw material",
    ],
    "Step 8": [
        "No person available to load Carcass",
        "No person available to collect tire",
    ],
}


def load_image(image_data):
    """Open an image and preprocess it into a normalized tensor batch.

    Args:
        image_data: A file path or binary file-like object accepted by
            ``PIL.Image.open``, an already-decoded ``PIL.Image.Image``, or an
            array exposing ``__array_interface__`` (e.g. a NumPy H x W x C array).

    Returns:
        torch.Tensor: Shape ``(1, 3, 224, 224)``, scaled to [-1, 1] by the
        0.5/0.5 normalization, placed on ``DEVICE`` with dtype ``TORCH_TYPE``.
    """
    # NOTE(review): predict() forwards this tensor to the model's
    # build_conversation_input_ids, which may apply its own preprocessing and
    # expect a different input format -- confirm against the model's remote code.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    if isinstance(image_data, Image.Image):
        image = image_data.convert("RGB")
    elif hasattr(image_data, "__array_interface__"):
        image = Image.fromarray(image_data).convert("RGB")
    else:
        # Image.open is lazy and keeps the source file open until closed; the
        # context manager releases it, and convert() returns a separate,
        # fully loaded image that stays valid afterwards.
        with Image.open(image_data) as opened:
            image = opened.convert("RGB")
    return transform(image).unsqueeze(0).to(DEVICE, dtype=TORCH_TYPE)


def get_analysis_prompt(step_number, reasons_by_step=None):
    """Build the text prompt asking the model to explain a delay at one step.

    Args:
        step_number: Step label as shown in the UI (``"Step 3"``) or a bare
            step number (``3`` / ``"3"``); matching of the ``"Step"`` prefix
            is case-insensitive.
        reasons_by_step: Optional mapping from ``"Step N"`` label to a list of
            candidate delay reasons. Defaults to the module-level
            ``DELAY_REASONS``.

    Returns:
        str: Multi-line prompt naming the step and its candidate reasons.
    """
    if reasons_by_step is None:
        reasons_by_step = DELAY_REASONS
    # The UI dropdown already yields "Step N"; the previous template prefixed
    # another "Step", producing "This is Step Step N". Normalize to "Step N".
    text = str(step_number).strip()
    if text.lower().startswith("step"):
        text = text[len("step"):].strip()
    label = f"Step {text}"

    delay_reasons = reasons_by_step.get(label, [])
    reasons_text = (
        ", ".join(delay_reasons) if delay_reasons else "none predefined for this step"
    )
    return (
        "You are an AI expert analyzing tire manufacturing steps.\n"
        f"This is {label}. Identify the most likely cause of delay based on visual evidence.\n"
        f"Possible Delay Reasons: {reasons_text}\n"
        "Provide the reason and supporting evidence.\n"
    )


def predict(prompt, image_tensor, temperature=0.3, *, max_new_tokens=1024, do_sample=False):
    """Run the vision-language model on one image and prompt and return its reply.

    Args:
        prompt: Text query sent to the model.
        image_tensor: Image input forwarded unchanged as the single element of
            ``images`` to the model's ``build_conversation_input_ids``.
        temperature: Sampling temperature. Only applied when ``do_sample`` is
            True; the default decoding is greedy (as before), so it has no
            effect unless sampling is enabled.
        max_new_tokens: Maximum number of tokens to generate.
        do_sample: Enable stochastic sampling instead of greedy decoding.

    Returns:
        str: The generated text only, without the echoed prompt.
    """
    conv = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[image_tensor],
        history=[],
        template_version='chat',
    )
    inputs = {
        'input_ids': conv['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': conv['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': conv['attention_mask'].unsqueeze(0).to(DEVICE),
        # Make sure the vision input matches the model's device/compute dtype;
        # a no-op if it already does.
        'images': [[conv['images'][0].to(DEVICE, dtype=TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        # NOTE(review): 128002 kept from the original; presumably a Llama-3
        # reserved special token used as padding -- confirm against the tokenizer.
        "pad_token_id": 128002,
        "do_sample": do_sample,
    }
    # Sampling knobs are ignored (and warned about) under greedy decoding,
    # so only pass them when sampling is actually requested.
    if do_sample:
        gen_kwargs["temperature"] = temperature
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    # generate() returns prompt + continuation for causal LMs; drop the prompt
    # so the caller receives only the model's answer.
    prompt_len = inputs['input_ids'].shape[1]
    return tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)


def inference(image, step_number):
    """Run the full analysis pipeline for one uploaded image.

    Args:
        image: Value from the Gradio image input (a file path when the component
            uses ``type="filepath"``); ``None`` or empty when nothing was uploaded.
        step_number: Dropdown label such as ``"Step 3"``; ``None``/empty when
            no step has been selected.

    Returns:
        str: The model's analysis, or a user-facing message when input is
        missing or the analysis fails.
    """
    # Explicit checks: `not image` raises for array-valued inputs.
    if image is None or (isinstance(image, (str, bytes)) and not image):
        return "Please upload an image."
    if not step_number:
        return "Please select a manufacturing step."
    try:
        image_tensor = load_image(image)
        prompt = get_analysis_prompt(step_number)
        return predict(prompt, image_tensor)
    except Exception as e:
        # UI boundary: keep the app alive and show a short message, but log the
        # full traceback so failures are diagnosable.
        logging.getLogger(__name__).exception("Analysis failed for %s", step_number)
        return f"An error occurred during analysis: {str(e)}"


# Gradio interface
def create_interface():
    """Build the Gradio UI: image upload + step selector -> analysis text.

    Returns:
        gr.Blocks: The assembled demo (not yet queued or launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("""
# Manufacturing Step Analysis System (Image Input)
Upload an image and select the manufacturing step to analyze potential delays.
""")
        with gr.Row():
            with gr.Column():
                # "filepath" hands inference() a path string that load_image can
                # open; "file" is not an accepted gr.Image type in Gradio 3+.
                image_input = gr.Image(label="Upload Manufacturing Image", type="filepath")
                # Derive choices from the reasons table so the two cannot drift apart.
                step_number = gr.Dropdown(
                    choices=list(DELAY_REASONS),
                    label="Manufacturing Step",
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Analysis Result", lines=10)
        analyze_btn.click(
            fn=inference,
            inputs=[image_input, step_number],
            outputs=[output],
        )
    return demo


if __name__ == "__main__":
    # Build the UI, route requests through Gradio's queue, and launch with a
    # public share link.
    demo = create_interface()
    demo.queue()
    demo.launch(share=True)