diff --git a/mindeye/scripts/utils_mindeye.py b/mindeye/scripts/utils_mindeye.py
index 0f5e1b9..88b892c 100644
--- a/mindeye/scripts/utils_mindeye.py
+++ b/mindeye/scripts/utils_mindeye.py
@@ -1,19 +1,19 @@
+import json
+import math
+import os
+import random
+import time
+
+import matplotlib.pyplot as plt
 import numpy as np
-from torchvision import transforms
+import PIL
+import requests
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import PIL
-import random
-import os
-import matplotlib.pyplot as plt
-import math
 import webdataset as wds
-
-import json
 from PIL import Image
-import requests
-import time
+from torchvision import transforms
 
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -251,6 +251,8 @@ def select_annotations(annots, random=True):
     return txt
 
 from generative_models.sgm.util import append_dims
+
+
 def unclip_recon(x, diffusion_engine, vector_suffix, num_samples=1, offset_noise_level=0.04):
     assert x.ndim==3
 
@@ -506,6 +508,8 @@ def create_design_matrix(images, starts, is_new_run, unique_images, n_runs, n_tr
 from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder
 from scipy import stats
 from tqdm import tqdm
+
+
 def calculate_retrieval_metrics(all_clip_voxels, all_images):
     print("Loading clip_img_embedder")
     try:
@@ -570,7 +574,11 @@ def calculate_retrieval_metrics(all_clip_voxels, all_images):
     return all_fwd_acc[0], all_bwd_acc[0]
 
 
-from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
+from torchvision.models.feature_extraction import (
+    create_feature_extractor,
+    get_graph_node_names,
+)
+
 
 @torch.no_grad()
 def two_way_identification(all_recons, all_images, model, preprocess, feature_layer=None, return_avg=True):
@@ -627,6 +635,7 @@ def calculate_pixcorr(all_recons, all_images):
 from skimage.color import rgb2gray
 from skimage.metrics import structural_similarity as ssim
+
 
 def calculate_ssim(all_recons, all_images):
     preprocess = transforms.Compose([
         transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -645,7 +654,9 @@ def calculate_ssim(all_recons, all_images):
     print(f"SSIM: {ssim_}")
     return ssim_
 
-from torchvision.models import alexnet, AlexNet_Weights
+from torchvision.models import AlexNet_Weights, alexnet
+
+
 def calculate_alexnet(all_recons, all_images, layers = [2, 5]):
     print("Loading AlexNet")
     alex_weights = AlexNet_Weights.DEFAULT
@@ -680,7 +691,9 @@ def calculate_alexnet(all_recons, all_images, layers = [2, 5]):
     return alexnet2, alexnet5
 
 
-from torchvision.models import inception_v3, Inception_V3_Weights
+from torchvision.models import Inception_V3_Weights, inception_v3
+
+
 def calculate_inception_v3(all_recons, all_images):
     print("Loading Inception V3")
     weights = Inception_V3_Weights.DEFAULT
@@ -705,6 +718,8 @@ def calculate_inception_v3(all_recons, all_images):
 
 
 import clip as clip_torch
+
+
 def calculate_clip(all_recons, all_images):
     print("Loading CLIP")
     clip_model, preprocess = clip_torch.load("ViT-L/14", device=device)
@@ -723,7 +738,9 @@ def calculate_clip(all_recons, all_images):
     return clip_
 
 import scipy as sp
-from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
+from torchvision.models import EfficientNet_B1_Weights, efficientnet_b1
+
+
 def calculate_efficientnet_b1(all_recons, all_images):
     print("Loading EfficientNet B1")
     weights = EfficientNet_B1_Weights.DEFAULT
@@ -965,4 +982,302 @@ def vectorized_pearsonr(X, Y):
     # Handle cases where r is NaN due to division by zero
     p = np.where(np.isnan(p),
                  1.0, p)
-    return r, p
\ No newline at end of file
+    return r, p
+
+# inference helpers
+
+def save_figures(gt_images, reconsTR, retrievals, n_retrievals=5, save_path=None):
+    """Plots ground truth, reconstructed, and retrieved images in a grid.
+
+    Args:
+    - gt_images (list or np.ndarray): ground truth images, shape (N, 3, 224, 224)
+    - reconsTR (list or np.ndarray): reconstructions, shape (N, 3, 224, 224)
+    - retrievals (list[dict]): list of dicts, one per row, holding the top retrieved
+      images (as arrays) keyed 'attempt1', 'attempt2', ...
+    - n_retrievals (int): number of retrievals to display per row
+    - save_path (str or None): if provided, saves the figure to this path
+    """
+    N = len(reconsTR)
+    ncols = 2 + n_retrievals  # GT, Recon, Retrievals
+    nrows = N
+    fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 2.5*nrows))
+    if nrows == 1:
+        axes = np.expand_dims(axes, 0)
+    for i in range(N):
+        # Ground truth
+        ax = axes[i, 0]
+        img = gt_images[i]
+        if isinstance(img, torch.Tensor):
+            img = img.detach().cpu().numpy()
+        if img.shape[0] == 3:
+            img = np.transpose(img, (1, 2, 0))
+        ax.imshow(np.clip(img, 0, 1))
+        ax.set_title('GT')
+        ax.axis('off')
+        # Reconstruction
+        ax = axes[i, 1]
+        rec = reconsTR[i]
+        if isinstance(rec, torch.Tensor):
+            rec = rec.detach().cpu().numpy()
+        if rec.shape[0] == 3:
+            rec = np.transpose(rec, (1, 2, 0))
+        ax.imshow(np.clip(rec, 0, 1))
+        ax.set_title('Recon')
+        ax.axis('off')
+        # Retrievals
+        retrieved_imgs = [retrievals[i][f'attempt{idx}'] for idx in range(1, n_retrievals + 1)]
+        for j, ret_img in enumerate(retrieved_imgs):
+            ax = axes[i, 2 + j]
+            if isinstance(ret_img, torch.Tensor):
+                ret_img = ret_img.detach().cpu().numpy()
+            if ret_img.shape[0] == 3:
+                ret_img = np.transpose(ret_img, (1, 2, 0))
+            ax.imshow(np.clip(ret_img, 0, 1))
+            ax.set_title(f'Retr {j+1}')
+            ax.axis('off')
+    plt.tight_layout()
+    if save_path:
+        plt.savefig(save_path, bbox_inches='tight')
+    plt.show()
+
+def do_reconstructions(
+    betas_tt: torch.Tensor,
+    model: nn.Module,
+    diffusion_engine: nn.Module,
+    vector_suffix: torch.Tensor,
+    device: torch.device = device,
+    num_samples_per_image: int = 1,
+    imsize: int = 224,
+    timesteps: int = 20,
+):
+    """Takes the beta maps for stimulus trials as a torch tensor (tt) and returns
+    reconstructions plus the CLIP-voxel embeddings used for retrieval.
+
+    Args:
+        betas_tt: torch tensor of shape (num_images, 1, num_voxels)
+        model: the trained MindEye model
+        diffusion_engine: the diffusion engine used for reconstruction
+        vector_suffix: the vector suffix used in unclip_recon
+        device: the device to run the model on (default: module-level `device`)
+        num_samples_per_image: number of samples drawn per image in unclip_recon
+        imsize: output size of the reconstructed images
+        timesteps: number of diffusion-prior sampling timesteps
+    """
+    model.to(device)
+    model.eval()
+
+    clipvoxelsTR = None
+    reconsTR = []
+
+    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
+        voxel = betas_tt
+        voxel = voxel.to(device)
+
+        # `voxel[:, [0]]` is of size (num_images, num_voxels)
+        voxel_ridge = model.ridge(voxel[:, [0]], 0)  # 0th index of subj_list
+
+        backbone, clipvoxelsTR, _ = model.backbone(voxel_ridge)
+
+        prior_out = model.diffusion_prior.p_sample_loop(
+            backbone.shape,
+            text_cond=dict(text_embed=backbone),
+            cond_scale=1., timesteps=timesteps,
+        )
+        resize_fn = transforms.Resize((imsize, imsize))
+        reconsTR = [
+            resize_fn(
+                unclip_recon(
+                    prior_out[[i]],
+                    diffusion_engine,
+                    vector_suffix,
+                    num_samples=num_samples_per_image,
+                ).float().cpu()
+            )
+            for i in tqdm(range(len(voxel)))
+        ]
+        reconsTR = torch.vstack(reconsTR).cpu()
+
+    return reconsTR, clipvoxelsTR.cpu()
+
+
+def get_top_retrievals(
+    clipvoxel: torch.Tensor,
+    all_images: torch.Tensor,
+    clip_img_embedder: nn.Module,
+    total_retrievals: int = 1,
+    imsize: int = 224,
+    device: torch.device = device,
+):
+    """Given a clipvoxel embedding, retrieve the top matching images from all_images.
+
+    Args:
+        clipvoxel: torch tensor of shape (num_images, clip_dim)
+        all_images: torch tensor of shape (num_database_images, 3, H, W)
+        clip_img_embedder: the CLIP image embedder model
+        total_retrievals: number of top retrievals to return per clipvoxel
+        imsize: size to which images are resized before returning
+        device: the device to run the model on (default: module-level `device`)
+
+    Returns:
+        List of dictionaries, one per clipvoxel, keyed 'attempt1' ... 'attemptN' with
+        the top retrieved images.
+    """
+    clip_img_embedder.to(device)
+    clip_img_embedder.eval()
+
+    retrieved_results = []
+    with torch.no_grad():
+        all_imgs_emb = clip_img_embedder(
+            torch.reshape(
+                all_images,
+                (all_images.shape[0], 3, imsize, imsize)
+            ).to(device)
+        ).cpu()  # CLIP img embeddings
+        all_imgs_emb = all_imgs_emb.reshape(len(all_imgs_emb), -1).float()
+        clipvoxels_emb = clipvoxel.reshape(clipvoxel.shape[0], -1).float()
+
+        fwd_sim = batchwise_cosine_similarity(all_imgs_emb, clipvoxels_emb)  # brain, clip
+        print("Given Brain embedding, find correct Image embedding")
+        fwd_sim = np.array(fwd_sim.cpu())
+        which = np.flip(np.argsort(fwd_sim, axis=1), axis=1)
+
+        for img_idx in range(fwd_sim.shape[0]):
+            values_dict = {}
+            for attempt in range(total_retrievals):
+                retrieved_image = all_images[which[img_idx][attempt]]
+                values_dict[f"attempt{attempt + 1}"] = transforms.Resize((imsize, imsize))(retrieved_image).float().numpy()
+            retrieved_results.append(values_dict)
+    return retrieved_results
+
+
+def do_reconstruction_and_retrievals(
+    betas: torch.Tensor,
+    all_images: torch.Tensor,
+    model: nn.Module,
+    diffusion_engine: nn.Module,
+    vector_suffix: torch.Tensor,
+    clip_img_embedder: nn.Module,
+    MST_idx: list | None = None,
+    mst_only: bool = False,
+    device: torch.device = device,
+    normalize_betas: bool = False,
+    total_retrievals: int = 5,
+    save_fig_path: str | None = None,
+) -> dict:
+    """Given betas, performs reconstruction and retrievals.
+
+    Args:
+        betas: torch tensor (or numpy array) of shape (num_images, num_voxels);
+            a singleton trial axis is added internally before reconstruction
+        all_images: torch tensor of shape (num_database_images, 3, H, W)
+        model: the trained MindEye model
+        diffusion_engine: the diffusion engine used for reconstruction
+        vector_suffix: the vector suffix used in unclip_recon
+        clip_img_embedder: the CLIP image embedder model for retrieval
+        MST_idx: indices of the MST images within all_images (required when mst_only)
+        mst_only: if True, restrict the retrieval database to the MST images
+        device: the device to run the model on (default: module-level `device`)
+        normalize_betas: whether to z-score normalize betas before reconstruction
+        total_retrievals: number of top retrievals to return per reconstruction
+        save_fig_path: if provided, saves a figure of results to this path
+
+    Returns:
+        Dictionary containing reconstructions, clipvoxels, and retrievals.
+    """
+    if isinstance(betas, np.ndarray):
+        betas = torch.FloatTensor(betas)
+    betas = betas.to(device)
+    if normalize_betas:
+        betas_mean = torch.mean(betas, dim=0)
+        betas_std = torch.std(betas, dim=0)
+        betas = (betas - betas_mean) / (betas_std + 1e-6)
+    betas = betas.unsqueeze(1)
+
+    if mst_only:
+        if MST_idx is None:
+            raise ValueError("MST_idx must be provided if mst_only is True")
+
+        all_images = all_images[MST_idx]
+        # TODO: betas = betas[MST_idx]?
+
+    # remove duplicate images from the retrieval database
+    seen = set()
+    unique_indices = []
+    for i, img in enumerate(all_images):
+        img_tuple = tuple(img.flatten().tolist())
+        if img_tuple not in seen:
+            seen.add(img_tuple)
+            unique_indices.append(i)
+    all_images = all_images[unique_indices]
+
+    print("Running reconstruction...")
+    reconsTR, clipvoxelsTR = do_reconstructions(
+        betas_tt=betas,
+        model=model,
+        diffusion_engine=diffusion_engine,
+        vector_suffix=vector_suffix,
+        device=device,
+    )
+    print("Done with reconstruction.")
+
+    print("Running retrieval...")
+    retrieval_value_dict = get_top_retrievals(
+        clipvoxel=clipvoxelsTR,
+        all_images=all_images,
+        clip_img_embedder=clip_img_embedder,
+        device=device,
+        total_retrievals=total_retrievals,
+    )
+    print("Done with retrieval.")
+
+    if save_fig_path is not None:
+        save_figures(
+            # assumes the first len(reconsTR) database images are the ground-truth
+            # stimuli, in trial order
+            gt_images=all_images[:len(reconsTR)],
+            reconsTR=reconsTR,
+            retrievals=retrieval_value_dict,
+            n_retrievals=total_retrievals,
+            save_path=save_fig_path,
+        )
+    output = {
+        "recons": reconsTR,
+        "clipvoxels": clipvoxelsTR,
+        "retrievals": retrieval_value_dict,
+    }
+    return output
+
+def evaluate_results(recons, gt_images, clipvoxels, metrics_to_run=None):
+    """
+    Evaluate reconstructions and retrievals using various metrics.
+
+    Args:
+        recons (torch.Tensor): reconstructed images of shape (n_cond, 3, 224, 224)
+        gt_images (torch.Tensor): ground truth images of shape (n_cond, 3, 224, 224)
+        clipvoxels: CLIP-voxel embeddings used for the retrieval metrics
+        metrics_to_run (list): list of metric names to compute
+
+    Returns:
+        dict: dictionary mapping each requested metric name to its value
+    """
+    if metrics_to_run is None:
+        metrics_to_run = ['alexnet', 'clip', 'inception', 'pixcorr', 'ssim', 'retrieval']
+
+    metrics = {}
+
+    if 'alexnet' in metrics_to_run:
+        alex2, alex5 = calculate_alexnet(recons, gt_images)
+        metrics['alexnet_layer2'] = alex2
+        metrics['alexnet_layer5'] = alex5
+
+    if 'clip' in metrics_to_run:
+        metrics['clip'] = calculate_clip(recons, gt_images)
+
+    if 'inception' in metrics_to_run:
+        metrics['inception'] = calculate_inception_v3(recons, gt_images)
+
+    if 'pixcorr' in metrics_to_run:
+        metrics['pixcorr'] = calculate_pixcorr(recons, gt_images)
+
+    if 'ssim' in metrics_to_run:
+        metrics['ssim'] = calculate_ssim(recons, gt_images)
+
+    if 'retrieval' in metrics_to_run and clipvoxels is not None:
+        fwd_acc, bwd_acc = calculate_retrieval_metrics(clipvoxels, gt_images)
+        metrics['retrieval_forward'] = fwd_acc
+        metrics['retrieval_backward'] = bwd_acc
+
+    return metrics
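
Usage sketch for the new inference helpers (illustrative only, not part of the patch). It assumes `model`, `diffusion_engine`, `vector_suffix`, and `clip_img_embedder` have already been built by the existing MindEye setup code, and the `.npy`/`.pt` file names below are hypothetical placeholders:

```python
# Hypothetical end-to-end usage of the new helpers; the data file names are
# placeholders and the loaded MindEye objects come from the existing setup code.
import numpy as np
import torch

from utils_mindeye import do_reconstruction_and_retrievals, evaluate_results

betas = np.load("betas_session1.npy")        # (num_images, num_voxels), hypothetical file
all_images = torch.load("images_224.pt")     # (num_database_images, 3, 224, 224), hypothetical file

results = do_reconstruction_and_retrievals(
    betas=betas,
    all_images=all_images,
    model=model,                          # assumed loaded elsewhere
    diffusion_engine=diffusion_engine,    # assumed loaded elsewhere
    vector_suffix=vector_suffix,          # assumed loaded elsewhere
    clip_img_embedder=clip_img_embedder,  # assumed loaded elsewhere
    normalize_betas=True,
    total_retrievals=5,
    save_fig_path="recon_grid.png",
)

metrics = evaluate_results(
    recons=results["recons"],
    gt_images=all_images[: len(results["recons"])],
    clipvoxels=results["clipvoxels"],
    metrics_to_run=["pixcorr", "ssim", "alexnet", "clip", "retrieval"],
)
print(metrics)
```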