Spaces:
Runtime error
Runtime error
add multi-class
Browse files- app.py +28 -5
- output.png +0 -0
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -17,10 +17,31 @@ nltk.download('averaged_perceptron_tagger')
|
|
| 17 |
from nltk.tokenize import word_tokenize
|
| 18 |
import torchvision
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
args = default_argument_parser().parse_args()
|
| 21 |
cfg = setup(args)
|
| 22 |
|
| 23 |
-
multi_classes =
|
| 24 |
|
| 25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
|
|
@@ -42,10 +63,12 @@ def run(sketch, caption, threshold, seed):
|
|
| 42 |
|
| 43 |
# set the condidate classes here
|
| 44 |
caption = caption.replace('\n',' ')
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
| 49 |
if len(classes) ==0 or multi_classes == False:
|
| 50 |
classes = [caption]
|
| 51 |
|
|
|
|
| 17 |
from nltk.tokenize import word_tokenize
|
| 18 |
import torchvision
|
| 19 |
|
| 20 |
+
import spacy
|
| 21 |
+
|
| 22 |
+
# download the model
|
| 23 |
+
spacy.cli.download("en_core_web_sm")
|
| 24 |
+
|
| 25 |
+
# Load spaCy model
|
| 26 |
+
nlp = spacy.load("en_core_web_sm")
|
| 27 |
+
|
| 28 |
+
def extract_objects(prompt):
|
| 29 |
+
doc = nlp(prompt)
|
| 30 |
+
# Extract object nouns (including proper nouns and compound nouns)
|
| 31 |
+
objects = set()
|
| 32 |
+
for token in doc:
|
| 33 |
+
# Check if the token is a noun or part of a named entity
|
| 34 |
+
if token.pos_ in {"NOUN", "PROPN"} or token.ent_type_:
|
| 35 |
+
objects.add(token.text)
|
| 36 |
+
# Check if the token is part of a compound noun
|
| 37 |
+
if token.dep_ in {"compound"}:
|
| 38 |
+
objects.add(token.head.text)
|
| 39 |
+
return list(objects)
|
| 40 |
+
|
| 41 |
args = default_argument_parser().parse_args()
|
| 42 |
cfg = setup(args)
|
| 43 |
|
| 44 |
+
multi_classes = True
|
| 45 |
|
| 46 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 47 |
Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
|
|
|
|
| 63 |
|
| 64 |
# set the condidate classes here
|
| 65 |
caption = caption.replace('\n',' ')
|
| 66 |
+
classes = extract_objects(caption)
|
| 67 |
+
# translator = str.maketrans('', '', string.punctuation)
|
| 68 |
+
# caption = caption.translate(translator).lower()
|
| 69 |
+
# words = word_tokenize(caption)
|
| 70 |
+
# classes = get_noun_phrase(words)
|
| 71 |
+
# print(classes)
|
| 72 |
if len(classes) ==0 or multi_classes == False:
|
| 73 |
classes = [caption]
|
| 74 |
|
output.png
CHANGED
|
|
requirements.txt
CHANGED
|
@@ -10,4 +10,5 @@ iopath
|
|
| 10 |
ftfy
|
| 11 |
fvcore
|
| 12 |
regex
|
| 13 |
-
nltk
|
|
|
|
|
|
| 10 |
ftfy
|
| 11 |
fvcore
|
| 12 |
regex
|
| 13 |
+
nltk
|
| 14 |
+
spacy
|