Moditha24 commited on
Commit
538982d
Β·
verified Β·
1 Parent(s): 9b3963d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -62
app.py CHANGED
@@ -1,69 +1,19 @@
1
  import gradio as gr
2
- import torch
3
- import torchvision.transforms as transforms
4
- from PIL import Image
5
- from resnet import SupCEResNet # Ensure the correct import path
6
 
7
- # βœ… Define class labels (from Clothing1M)
8
- class_labels = [
9
- "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
10
- "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
11
- "Vest", "Underwear"
12
- ]
13
 
14
- # βœ… Function to load the model
15
- def create_model_selfsup(net='resnet50', num_class=14, checkpoint_path='/content/ckpt_clothing_resnet50.pth'):
16
- """Loads a self-supervised pretrained model for Clothing1M classification"""
17
- print(f"πŸ”„ Loading model from: {checkpoint_path}")
18
 
19
- # Load the checkpoint safely
20
- checkpoint = torch.load(checkpoint_path, map_location="cuda" if torch.cuda.is_available() else "cpu", weights_only=False)
 
21
 
22
- # Remove 'module.' prefix if using DataParallel
23
- state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}
24
 
25
- # Initialize and load model
26
- model = SupCEResNet(net, num_classes=num_class, pool=True)
27
- model.load_state_dict(state_dict, strict=False)
28
 
29
- # Move model to GPU if available
30
- model = model.to("cuda" if torch.cuda.is_available() else "cpu")
31
- model.eval() # Set model to evaluation mode
32
-
33
- print("βœ… Model loaded successfully!")
34
- return model
35
-
36
- # βœ… Load the model once
37
- model = create_model_selfsup()
38
-
39
- # βœ… Define image preprocessing function
40
- def preprocess_image(image):
41
- """Transforms input image for the model"""
42
- transform = transforms.Compose([
43
- transforms.Resize(256),
44
- transforms.CenterCrop(224),
45
- transforms.ToTensor(),
46
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
47
- ])
48
- return transform(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
49
-
50
- # βœ… Define inference function
51
- def predict_clothing(image):
52
- """Runs inference on an uploaded image"""
53
- image = Image.fromarray(image) # Convert numpy array to PIL Image
54
- image = preprocess_image(image) # Preprocess image
55
-
56
- with torch.no_grad():
57
- output = model(image)
58
- predicted_class = torch.argmax(output, dim=1).item() # Get class index
59
-
60
- return class_labels[predicted_class] # Return class name
61
-
62
- # βœ… Create Gradio Interface
63
- gr.Interface(
64
- fn=predict_clothing,
65
- inputs=gr.Image(type="numpy"),
66
- outputs=gr.Textbox(label="Predicted Clothing Type"),
67
- title="Clothing1M Classification",
68
- description="Upload an image to classify clothing into one of 14 categories."
69
- ).launch()
 
1
  import gradio as gr
2
+ from fastapi import FastAPI
3
+ from starlette.staticfiles import StaticFiles
 
 
4
 
5
+ # Create a FastAPI app
6
+ app = FastAPI()
 
 
 
 
7
 
8
+ # Serve the React build folder
9
+ app.mount("/", StaticFiles(directory="build", html=True), name="static")
 
 
10
 
11
+ # Define a dummy Gradio interface (optional)
12
+ def greet(name):
13
+ return f"Hello, {name}! Welcome to Flight Timings."
14
 
15
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
16
 
17
+ # Add Gradio to FastAPI
18
+ app = gr.mount_gradio_app(app, demo, path="/gradio")
 
19