fossil_app / closest_sample.py
Yuxiang Wang
reference bucket img for closest sample
af9c1e6
raw
history blame
3.23 kB
from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests
# Load the fitted PCA models. `with` guarantees the file handles are closed.
# NOTE(review): pickle.load can execute arbitrary code; these .pkl files are
# assumed to be trusted artifacts shipped with the app, never user uploads.
with open('pca_fossils_170_finer.pkl', 'rb') as _pkl_file:
    pca_fossils = pk.load(_pkl_file)
with open('pca_leaves_170_finer.pkl', 'rb') as _pkl_file:
    pca_leaves = pk.load(_pkl_file)

# Fetch the dataset repo from the Hugging Face Hub once; later runs reuse the
# local 'dataset' directory.
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Only report whether the token is present -- never echo the secret itself,
    # since stdout typically ends up in shared logs.
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    else:
        print("Read token: set")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

# Precomputed fossil embeddings (shape assumed (n_samples, ..., dim) -- TODO
# confirm) and the CSV whose 'file_name' column lists the matching image paths.
embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
#embedding_leaves = np.load('embedding_leaves.npy')
fossils_pd = pd.read_csv('fossils_paths.csv')
def pca_distance(pca, sample, embedding, k=5):
    """Find the dataset entries closest to ``sample`` in PCA space.

    Args:
        pca: fitted PCA-like model exposing ``transform``.
        sample: embedding of the query; flattened to a single row before
            projection.
        embedding: embeddings of the dataset. ``embedding[:, -1]`` is what
            gets projected, so this presumably has shape
            (n_samples, n_slots, dim) with the last slot holding the vector
            to compare -- TODO confirm against the .npy file.
        k: number of closest entries to return (default 5).

    Returns:
        Array of the indices of the ``k`` closest dataset entries, ordered
        from nearest to farthest by Euclidean distance.
    """
    query = pca.transform(sample.reshape(1, -1))
    # NOTE: the whole dataset is re-projected on every call.
    projected = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected - query, axis=1)
    return np.argsort(distances)[:k]
def return_paths(argsorted, files):
    """Look up the entries of ``files`` at the positions listed in ``argsorted``.

    Args:
        argsorted: iterable of integer indices (e.g. the output of pca_distance).
        files: indexable sequence of file paths.

    Returns:
        A new list containing ``files[idx]`` for each ``idx`` in ``argsorted``,
        in the same order.
    """
    return [files[idx] for idx in argsorted]
def download_public_image(url, destination_path, timeout=30):
    """Download ``url`` and save the raw response body to ``destination_path``.

    Args:
        url: public HTTP(S) URL of the image.
        destination_path: local file the bytes are written to (overwritten).
        timeout: seconds to wait on the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.

    Returns:
        True if the response was HTTP 200 and the file was written, False on
        any other status code or on a network/URL error.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        # Report transport failures (bad URL, DNS, timeout, ...) the same way
        # as HTTP errors, so callers' fallback logic runs instead of crashing.
        print(f"Failed to download image from bucket. Error: {err}")
        return False
    if response.status_code != 200:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")
        return False
    with open(destination_path, 'wb') as f:
        f.write(response.content)
    print(f"Downloaded image to {destination_path}")
    return True
# Cluster-side prefixes under which the image paths in fossils_paths.csv were
# recorded. Each one is mirrored by a public bucket folder.
_LOCAL_FOSSIL_ROOT = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/'
_LOCAL_FLORISSANT = _LOCAL_FOSSIL_ROOT + 'Florissant_Fossil/512/full/jpg/'
_LOCAL_GENERAL = _LOCAL_FOSSIL_ROOT + 'General_Fossil/512/full/jpg/'


def _public_urls(path, folder_florissant, folder_general):
    """Map a recorded local path to bucket URLs to try, most likely first.

    The bucket whose local prefix matches ``path`` is tried first and the
    other bucket second (same relative file name). Returns an empty list when
    ``path`` matches neither known prefix.
    """
    if _LOCAL_FLORISSANT in path:
        local_prefix = _LOCAL_FLORISSANT
        folders = (folder_florissant, folder_general)
    elif _LOCAL_GENERAL in path:
        local_prefix = _LOCAL_GENERAL
        folders = (folder_general, folder_florissant)
    else:
        return []
    return [path.replace(local_prefix, folder) for folder in folders]


def get_images(embedding):
    """Download the images of the fossils closest to ``embedding``.

    The closest entries (by PCA distance over ``embedding_fossils``) are
    looked up in ``fossils_pd['file_name']``. Each one is downloaded from the
    public bucket to ``image_{i}.jpg`` in the working directory, with ``i`` as
    its rank.

    Args:
        embedding: embedding of the query image.

    Returns:
        The recorded (cluster-side) file paths of the closest fossils, nearest
        first.
    """
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)
    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        urls = _public_urls(path, folder_florissant, folder_general)
        if not urls:
            # Requesting the raw cluster path would fail (it is not a URL).
            # NOTE(review): image_{i}.jpg may then hold a stale image from an
            # earlier call.
            print(f"No public bucket mapping for {path}; skipping download.")
            continue
        # Try each candidate bucket until one download succeeds.
        for url in urls:
            if download_public_image(url, local_file_path):
                break
    return paths