todap committed
Commit 9cfa91c · verified · 1 Parent(s): 3a909aa

Upload 5 files

models/.DS_Store ADDED
Binary file (6.15 kB).
 
models/identification_model.py ADDED
@@ -0,0 +1,40 @@
+ from transformers import CLIPProcessor, CLIPModel
+ from PIL import Image
+ import torch
+
+ class IdentificationModel:
+     def __init__(self):
+         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model.to(self.device)
+
+     def identify_objects(self, image_path, text_descriptions):
+         # Load image
+         image = Image.open(image_path)
+
+         # Prepare inputs
+         inputs = self.processor(text=text_descriptions, images=image, return_tensors="pt", padding=True)
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         # Run inference
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         # Get logits and compute probabilities
+         logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
+         probs = logits_per_image.softmax(dim=1)  # convert logits to probabilities
+
+         # Find the detection with the maximum probability
+         max_prob, max_idx = torch.max(probs[0], dim=0)
+
+         # Prepare the result for the highest probability detection
+         detection = []
+         detection.append({
+             'description': text_descriptions[max_idx],
+             'probability': float(max_prob)
+         })
+
+         return detection
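
Usage sketch (not part of the commit): a minimal example of how IdentificationModel might be called, assuming the repository root is on the Python path; "sample.jpg" and the candidate descriptions are placeholder inputs.

from models.identification_model import IdentificationModel

# Hypothetical inputs: any local image plus a list of candidate text labels
model = IdentificationModel()
detections = model.identify_objects(
    "sample.jpg",
    ["a photo of a cat", "a photo of a dog", "a photo of a car"],
)
print(detections)  # [{'description': <best-matching label>, 'probability': <float>}]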
models/segmentation_model.py ADDED
@@ -0,0 +1,32 @@
+ from ultralytics import YOLO
+ import numpy as np
+ import torchvision.transforms as transforms
+
+ class SegmentationModel:
+     def __init__(self):
+         self.model = YOLO('yolov8m-seg.pt')
+         self.transform = transforms.Compose([
+             transforms.Resize((640, 640)),  # Resize to YOLOv8 input size
+             transforms.Lambda(lambda x: x.mul(255).byte()),  # Scale to 0-255 and convert to uint8
+             transforms.Lambda(lambda x: x.permute(1, 2, 0).numpy())  # Change from CHW to HWC
+         ])
+
+     def segment_image(self, image_path):
+         # Run YOLOv8 segmentation at a 0.25 confidence threshold
+         results = self.model(image_path, conf=0.25)
+         class_name = []
+         if results[0].masks is not None:
+             for counter, detection in enumerate(results[0].masks.data):
+                 cls_id = int(results[0].boxes[counter].cls.item())
+                 class_name.append(self.model.names[cls_id])
+             print(class_name)
+
+         # Extract masks, boxes, and labels
+         result = results[0]
+         masks = result.masks.data.cpu().numpy() if result.masks is not None else np.array([])
+         boxes = result.boxes.xyxy.cpu().numpy() if result.boxes is not None else np.array([])
+         labels = result.boxes.cls.cpu().numpy() if result.boxes is not None else np.array([])
+
+         return masks, boxes, labels, class_name
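
Usage sketch (not part of the commit), assuming the repository root is importable and "sample.jpg" is a placeholder image; segment_image returns per-instance masks, xyxy boxes, numeric class ids, and human-readable class names.

from models.segmentation_model import SegmentationModel

segmenter = SegmentationModel()
masks, boxes, labels, class_names = segmenter.segment_image("sample.jpg")
# When detections exist: masks (N, H, W), boxes (N, 4) in xyxy, labels (N,)
print(len(class_names), masks.shape, boxes.shape)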
models/summarization_model.py ADDED
@@ -0,0 +1,19 @@
+ from transformers import BartForConditionalGeneration, BartTokenizer
+
+ class SummarizationModel:
+     def __init__(self):
+         self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
+         self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
+
+     def summarize(self, text):
+         # Split the text into lines and remove empty lines
+         lines = [line.strip() for line in text.split('\n') if line.strip()]
+
+         # If there's only one line, return it as is
+         if len(lines) <= 1:
+             return text.strip()
+
+         # Otherwise, proceed with summarization
+         inputs = self.tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)
+         summary_ids = self.model.generate(inputs["input_ids"], num_beams=4, max_length=100, early_stopping=True)
+         return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
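
Usage sketch (not part of the commit); the multi-line input text is a placeholder. Single-line inputs are returned unchanged, while longer text is summarized with beam search (num_beams=4, max_length=100).

from models.summarization_model import SummarizationModel

summarizer = SummarizationModel()
text = (
    "First observation about the scene.\n"
    "Second observation adding more detail.\n"
    "Third observation wrapping things up."
)
print(summarizer.summarize(text))  # a short BART-generated summary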
models/text_extraction_model.py ADDED
@@ -0,0 +1,9 @@
+ import easyocr
+
+ class TextExtractionModel:
+     def __init__(self):
+         self.reader = easyocr.Reader(['en'])
+
+     def extract_text(self, image_path):
+         # readtext returns (bbox, text, confidence) tuples; keep only the recognized text
+         result = self.reader.readtext(image_path)
+         return ' '.join([detection[1] for detection in result])
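
Usage sketch (not part of the commit), with "sample.jpg" as a placeholder image containing English text.

from models.text_extraction_model import TextExtractionModel

extractor = TextExtractionModel()
print(extractor.extract_text("sample.jpg"))  # recognized strings joined with spaces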