diff --git a/.gitignore b/.gitignore index f5b2f05..d5e502e 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ package-lock.json # MacOS stuff .DS_Store .vscode +/venv \ No newline at end of file diff --git a/mindeye/scripts/mindeye.ipynb b/mindeye/scripts/mindeye.ipynb index 0845925..56191d4 100644 --- a/mindeye/scripts/mindeye.ipynb +++ b/mindeye/scripts/mindeye.ipynb @@ -2,26 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 40, "id": "b6053a83-2259-475e-9e21-201e44217e88", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/conf/.venv/lib/python3.11/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", - " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "line 6: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/scripts\n", - "line 6: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/scripts\n", - "line 14: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/scripts\n", - "line 14: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/scripts\n" + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" ] } ], @@ -102,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 41, "id": "699a3162", "metadata": {}, "outputs": [], @@ -117,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 42, "id": "4516e788-85cc-42ab-b05a-11bd7207f6ba", "metadata": {}, "outputs": [], @@ -140,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 43, "id": "12be1838-f387-4cdd-b7cb-217a74501359", "metadata": {}, "outputs": [ @@ -171,7 +161,7 @@ "722060552" ] }, - "execution_count": 4, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -314,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 44, "id": "8a627d35-3cd5-4cd1-9bb3-c02c0c97f7a1", "metadata": {}, "outputs": [], @@ -341,10 +331,23 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 45, "id": "05bd11f3-6d4d-4ee4-a23b-443afeb5c3fe", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[45]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 14\u001b[39m sampler_config[\u001b[33m'\u001b[39m\u001b[33mparams\u001b[39m\u001b[33m'\u001b[39m][\u001b[33m'\u001b[39m\u001b[33mnum_steps\u001b[39m\u001b[33m'\u001b[39m] = \u001b[32m38\u001b[39m\n\u001b[32m 15\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mstorage_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/diffusion_engine\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mrb\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m input_file:\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m diffusion_engine = 
\u001b[43mpickle\u001b[49m\u001b[43m.\u001b[49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_file\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[38;5;66;03m# set to inference\u001b[39;00m\n\u001b[32m 18\u001b[39m diffusion_engine.eval().requires_grad_(\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torch/storage.py:336\u001b[39m, in \u001b[36m_load_from_bytes\u001b[39m\u001b[34m(b)\u001b[39m\n\u001b[32m 332\u001b[39m \u001b[38;5;129m@_share_memory_lock_protected\u001b[39m\n\u001b[32m 333\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_share_filename_cpu_\u001b[39m(\u001b[38;5;28mself\u001b[39m, *args, **kwargs):\n\u001b[32m 334\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()._share_filename_cpu_(*args, **kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m336\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_load_from_bytes\u001b[39m(b):\n\u001b[32m 337\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m torch.load(io.BytesIO(b))\n\u001b[32m 340\u001b[39m _StorageBase.type = _type \u001b[38;5;66;03m# type: ignore[assignment]\u001b[39;00m\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], "source": [ "# prep unCLIP\n", "config = OmegaConf.load(f\"{project_path}/models/generative_models/configs/unclip6.yaml\")\n", @@ -379,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 46, "id": "58cb6183", "metadata": {}, "outputs": [], @@ -399,7 +402,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 47, "id": "48586a97", "metadata": {}, "outputs": [ @@ -408,7 +411,7 @@ "output_type": "stream", "text": [ "Data shape: (780, 109)\n", - "Using design file: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/3t/data/events/csv/sub-005_ses-06.csv\n", + "Using design file: /home/amaarc/rtcloud-projects/mindeye/3t/data/events/csv/sub-005_ses-06.csv\n", "Total number of images: 770\n", "Number of unique images: 126\n", "n_runs 11\n", @@ -512,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 48, "id": "51d16a03", "metadata": {}, "outputs": [ @@ -520,9 +523,9 @@ "name": "stderr", "output_type": "stream", "text": [ - " 6%|▌ | 42/693 [00:00<00:01, 412.55it/s]/home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/conf/.venv/lib/python3.11/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n", + " 8%|▊ | 57/693 [00:00<00:05, 118.14it/s]/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. 
To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n", " warnings.warn(\n", - "100%|██████████| 693/693 [00:06<00:00, 104.06it/s]" + "100%|██████████| 693/693 [00:20<00:00, 33.45it/s]" ] }, { @@ -571,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 49, "id": "f692b9cb", "metadata": {}, "outputs": [], @@ -590,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 50, "id": "fbffed78", "metadata": {}, "outputs": [ @@ -599,7 +602,7 @@ "output_type": "stream", "text": [ "Data shape: (780, 109)\n", - "Using design file: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/3t/data/events/csv/sub-005_ses-06.csv\n", + "Using design file: /home/amaarc/rtcloud-projects/mindeye/3t/data/events/csv/sub-005_ses-06.csv\n", "Total number of images: 770\n", "Number of unique images: 126\n" ] @@ -624,7 +627,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 51, "id": "31b2474d", "metadata": {}, "outputs": [ @@ -647,7 +650,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 52, "id": "e05736bc-c816-49ae-8718-b6c31b412781", "metadata": {}, "outputs": [], @@ -676,7 +679,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 53, "id": "b21c9550", "metadata": {}, "outputs": [], @@ -688,7 +691,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 54, "id": "4768e42a", "metadata": {}, "outputs": [], @@ -719,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 55, "id": "057b3dd3", "metadata": {}, "outputs": [ @@ -739,7 +742,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 56, "id": "895d9228-46e0-4ec0-9fe4-00f802f9708f", "metadata": {}, "outputs": [], @@ -834,7 +837,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 57, "id": "4e595bc9", "metadata": {}, "outputs": [], @@ -847,7 +850,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 58, "id": "ff1807bc", "metadata": {}, "outputs": [], @@ -859,7 +862,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 59, "id": "211a4d40-643d-493b-874b-2030490b9bf4", "metadata": { "tags": [] @@ -1313,10 +1316,502 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 66, "id": "d6675825", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "calculating retrieval subset 0 (first set of repeats)\n", + "Loading clip_img_embedder\n", + "The total pool of images and clip voxels to do retrieval on is: 62\n", + "loading cached embeddings from /home/amaarc/rtcloud-projects/mindeye/3t/derivatives/image_embeddings_cache/image_embeddings_62.pt\n", + "calculating retrieval metrics\n", + "overall fwd percent_correct: 0.2258\n", + "overall bwd percent_correct: 0.2903\n", + "\n", + "pair 1:\n", + " image 1: pair_46_w_pool2.jpg\n", + " image 2: pair_46_w_pool1.jpg\n", + " img 1 brain embedding chose: pair_46_w_pool2.jpg Correct\n", + " img 2 brain embedding chose: pair_46_w_pool1.jpg Correct\n", + "\n", + "pair 2:\n", + " image 1: pair_14_18_78.png\n", + " image 2: pair_14_18_22.png\n", + " 
img 1 brain embedding chose: pair_14_18_22.png Incorrect\n", + " img 2 brain embedding chose: pair_14_18_22.png Correct\n", + "\n", + "pair 3:\n", + " image 1: pair_39_w_corridor1.jpg\n", + " image 2: pair_39_w_corridor2.jpg\n", + " img 1 brain embedding chose: pair_39_w_corridor1.jpg Correct\n", + " img 2 brain embedding chose: pair_39_w_corridor2.jpg Correct\n", + "\n", + "pair 4:\n", + " image 1: pair_5_8_0.png\n", + " image 2: pair_5_8_78.png\n", + " img 1 brain embedding chose: pair_5_8_0.png Correct\n", + " img 2 brain embedding chose: pair_5_8_78.png Correct\n", + "\n", + "pair 5:\n", + " image 1: pair_37_w_airplane_interior2.jpg\n", + " image 2: pair_37_w_airplane_interior1.jpg\n", + " img 1 brain embedding chose: pair_37_w_airplane_interior1.jpg Incorrect\n", + " img 2 brain embedding chose: pair_37_w_airplane_interior2.jpg Incorrect\n", + "\n", + "pair 6:\n", + " image 1: pair_17_21_33.png\n", + " image 2: pair_17_21_0.png\n", + " img 1 brain embedding chose: pair_17_21_33.png Correct\n", + " img 2 brain embedding chose: pair_17_21_0.png Correct\n", + "\n", + "pair 7:\n", + " image 1: pair_12_16_22.png\n", + " image 2: pair_12_16_78.png\n", + " img 1 brain embedding chose: pair_12_16_22.png Correct\n", + " img 2 brain embedding chose: pair_12_16_78.png Correct\n", + "\n", + "pair 8:\n", + " image 1: pair_49_w_yoga_studio1.jpg\n", + " image 2: pair_49_w_yoga_studio2.jpg\n", + " img 1 brain embedding chose: pair_49_w_yoga_studio1.jpg Correct\n", + " img 2 brain embedding chose: pair_49_w_yoga_studio1.jpg Incorrect\n", + "\n", + "pair 9:\n", + " image 1: pair_3_5_100.png\n", + " image 2: pair_3_5_33.png\n", + " img 1 brain embedding chose: pair_3_5_100.png Correct\n", + " img 2 brain embedding chose: pair_3_5_33.png Correct\n", + "\n", + "pair 10:\n", + " image 1: pair_10_14_89.png\n", + " image 2: pair_10_14_33.png\n", + " img 1 brain embedding chose: pair_10_14_89.png Correct\n", + " img 2 brain embedding chose: pair_10_14_33.png Correct\n", + "\n", + "pair 11:\n", + " image 1: pair_7_10_56.png\n", + " image 2: pair_7_10_89.png\n", + " img 1 brain embedding chose: pair_7_10_89.png Incorrect\n", + " img 2 brain embedding chose: pair_7_10_89.png Correct\n", + "\n", + "pair 12:\n", + " image 1: pair_1_3_56.png\n", + " image 2: pair_1_3_89.png\n", + " img 1 brain embedding chose: pair_1_3_89.png Incorrect\n", + " img 2 brain embedding chose: pair_1_3_89.png Correct\n", + "\n", + "pair 13:\n", + " image 1: pair_6_9_56.png\n", + " image 2: pair_6_9_11.png\n", + " img 1 brain embedding chose: pair_6_9_56.png Correct\n", + " img 2 brain embedding chose: pair_6_9_56.png Incorrect\n", + "\n", + "pair 14:\n", + " image 1: pair_42_w_icerink1.jpg\n", + " image 2: pair_42_w_icerink2.jpg\n", + " img 1 brain embedding chose: pair_42_w_icerink1.jpg Correct\n", + " img 2 brain embedding chose: pair_42_w_icerink2.jpg Correct\n", + "\n", + "pair 15:\n", + " image 1: pair_40_w_escalator1.jpg\n", + " image 2: pair_40_w_escalator2.jpg\n", + " img 1 brain embedding chose: pair_40_w_escalator1.jpg Correct\n", + " img 2 brain embedding chose: pair_40_w_escalator2.jpg Correct\n", + "\n", + "pair 16:\n", + " image 1: pair_8_12_56.png\n", + " image 2: pair_8_12_0.png\n", + " img 1 brain embedding chose: pair_8_12_56.png Correct\n", + " img 2 brain embedding chose: pair_8_12_0.png Correct\n", + "\n", + "pair 17:\n", + " image 1: pair_45_w_pagoda2.jpg\n", + " image 2: pair_45_w_pagoda1.jpg\n", + " img 1 brain embedding chose: pair_45_w_pagoda1.jpg Incorrect\n", + " img 2 brain embedding chose: 
pair_45_w_pagoda1.jpg Correct\n", + "\n", + "pair 18:\n", + " image 1: pair_38_w_arch2.jpg\n", + " image 2: pair_38_w_arch1.jpg\n", + " img 1 brain embedding chose: pair_38_w_arch2.jpg Correct\n", + " img 2 brain embedding chose: pair_38_w_arch1.jpg Correct\n", + "\n", + "pair 19:\n", + " image 1: pair_47_w_roller_coaster1.jpg\n", + " image 2: pair_47_w_roller_coaster2.jpg\n", + " img 1 brain embedding chose: pair_47_w_roller_coaster1.jpg Correct\n", + " img 2 brain embedding chose: pair_47_w_roller_coaster2.jpg Correct\n", + "\n", + "pair 20:\n", + " image 1: pair_0_1_78.png\n", + " image 2: pair_0_1_22.png\n", + " img 1 brain embedding chose: pair_0_1_78.png Correct\n", + " img 2 brain embedding chose: pair_0_1_22.png Correct\n", + "\n", + "pair 21:\n", + " image 1: pair_13_17_89.png\n", + " image 2: pair_13_17_56.png\n", + " img 1 brain embedding chose: pair_13_17_89.png Correct\n", + " img 2 brain embedding chose: pair_13_17_56.png Correct\n", + "\n", + "pair 22:\n", + " image 1: pair_43_w_lighthouse2.jpg\n", + " image 2: pair_43_w_lighthouse1.jpg\n", + " img 1 brain embedding chose: pair_43_w_lighthouse2.jpg Correct\n", + " img 2 brain embedding chose: pair_43_w_lighthouse1.jpg Correct\n", + "\n", + "pair 23:\n", + " image 1: pair_16_20_56.png\n", + " image 2: pair_16_20_0.png\n", + " img 1 brain embedding chose: pair_16_20_56.png Correct\n", + " img 2 brain embedding chose: pair_16_20_56.png Incorrect\n", + "\n", + "pair 24:\n", + " image 1: pair_11_15_89.png\n", + " image 2: pair_11_15_44.png\n", + " img 1 brain embedding chose: pair_11_15_44.png Incorrect\n", + " img 2 brain embedding chose: pair_11_15_44.png Correct\n", + "\n", + "pair 25:\n", + " image 1: pair_4_6_89.png\n", + " image 2: pair_4_6_44.png\n", + " img 1 brain embedding chose: pair_4_6_44.png Incorrect\n", + " img 2 brain embedding chose: pair_4_6_44.png Correct\n", + "\n", + "pair 26:\n", + " image 1: pair_2_4_89.png\n", + " image 2: pair_2_4_22.png\n", + " img 1 brain embedding chose: pair_2_4_89.png Correct\n", + " img 2 brain embedding chose: pair_2_4_22.png Correct\n", + "\n", + "pair 27:\n", + " image 1: pair_44_w_log_cabin2.jpg\n", + " image 2: pair_44_w_log_cabin1.jpg\n", + " img 1 brain embedding chose: pair_44_w_log_cabin2.jpg Correct\n", + " img 2 brain embedding chose: pair_44_w_log_cabin1.jpg Correct\n", + "\n", + "pair 28:\n", + " image 1: pair_48_w_runway1.jpg\n", + " image 2: pair_48_w_runway2.jpg\n", + " img 1 brain embedding chose: pair_48_w_runway1.jpg Correct\n", + " img 2 brain embedding chose: pair_48_w_runway2.jpg Correct\n", + "\n", + "pair 29:\n", + " image 1: pair_9_13_44.png\n", + " image 2: pair_9_13_89.png\n", + " img 1 brain embedding chose: pair_9_13_44.png Correct\n", + " img 2 brain embedding chose: pair_9_13_44.png Incorrect\n", + "\n", + "pair 30:\n", + " image 1: pair_15_19_56.png\n", + " image 2: pair_15_19_0.png\n", + " img 1 brain embedding chose: pair_15_19_56.png Correct\n", + " img 2 brain embedding chose: pair_15_19_0.png Correct\n", + "\n", + "pair 31:\n", + " image 1: pair_41_w_gym2.jpg\n", + " image 2: pair_41_w_gym1.jpg\n", + " img 1 brain embedding chose: pair_41_w_gym2.jpg Correct\n", + " img 2 brain embedding chose: pair_41_w_gym2.jpg Incorrect\n", + "\n", + "correct choices: 49/62\n", + "mst 2afc score: 0.7903 (79.03%)\n", + "\n", + "\n", + "calculating retrieval subset 1 (second set of repeats)\n", + "Loading clip_img_embedder\n", + "The total pool of images and clip voxels to do retrieval on is: 62\n", + "loading cached embeddings from 
/home/amaarc/rtcloud-projects/mindeye/3t/derivatives/image_embeddings_cache/image_embeddings_62.pt\n", + "calculating retrieval metrics\n", + "overall fwd percent_correct: 0.3065\n", + "overall bwd percent_correct: 0.3710\n", + "\n", + "pair 1:\n", + " image 1: pair_46_w_pool2.jpg\n", + " image 2: pair_46_w_pool1.jpg\n", + " img 1 brain embedding chose: pair_46_w_pool1.jpg Incorrect\n", + " img 2 brain embedding chose: pair_46_w_pool1.jpg Correct\n", + "\n", + "pair 2:\n", + " image 1: pair_14_18_78.png\n", + " image 2: pair_14_18_22.png\n", + " img 1 brain embedding chose: pair_14_18_78.png Correct\n", + " img 2 brain embedding chose: pair_14_18_78.png Incorrect\n", + "\n", + "pair 3:\n", + " image 1: pair_39_w_corridor1.jpg\n", + " image 2: pair_39_w_corridor2.jpg\n", + " img 1 brain embedding chose: pair_39_w_corridor1.jpg Correct\n", + " img 2 brain embedding chose: pair_39_w_corridor2.jpg Correct\n", + "\n", + "pair 4:\n", + " image 1: pair_5_8_0.png\n", + " image 2: pair_5_8_78.png\n", + " img 1 brain embedding chose: pair_5_8_0.png Correct\n", + " img 2 brain embedding chose: pair_5_8_0.png Incorrect\n", + "\n", + "pair 5:\n", + " image 1: pair_37_w_airplane_interior2.jpg\n", + " image 2: pair_37_w_airplane_interior1.jpg\n", + " img 1 brain embedding chose: pair_37_w_airplane_interior2.jpg Correct\n", + " img 2 brain embedding chose: pair_37_w_airplane_interior1.jpg Correct\n", + "\n", + "pair 6:\n", + " image 1: pair_17_21_33.png\n", + " image 2: pair_17_21_0.png\n", + " img 1 brain embedding chose: pair_17_21_33.png Correct\n", + " img 2 brain embedding chose: pair_17_21_33.png Incorrect\n", + "\n", + "pair 7:\n", + " image 1: pair_12_16_22.png\n", + " image 2: pair_12_16_78.png\n", + " img 1 brain embedding chose: pair_12_16_22.png Correct\n", + " img 2 brain embedding chose: pair_12_16_78.png Correct\n", + "\n", + "pair 8:\n", + " image 1: pair_49_w_yoga_studio1.jpg\n", + " image 2: pair_49_w_yoga_studio2.jpg\n", + " img 1 brain embedding chose: pair_49_w_yoga_studio1.jpg Correct\n", + " img 2 brain embedding chose: pair_49_w_yoga_studio2.jpg Correct\n", + "\n", + "pair 9:\n", + " image 1: pair_3_5_100.png\n", + " image 2: pair_3_5_33.png\n", + " img 1 brain embedding chose: pair_3_5_100.png Correct\n", + " img 2 brain embedding chose: pair_3_5_33.png Correct\n", + "\n", + "pair 10:\n", + " image 1: pair_10_14_89.png\n", + " image 2: pair_10_14_33.png\n", + " img 1 brain embedding chose: pair_10_14_89.png Correct\n", + " img 2 brain embedding chose: pair_10_14_33.png Correct\n", + "\n", + "pair 11:\n", + " image 1: pair_7_10_56.png\n", + " image 2: pair_7_10_89.png\n", + " img 1 brain embedding chose: pair_7_10_89.png Incorrect\n", + " img 2 brain embedding chose: pair_7_10_89.png Correct\n", + "\n", + "pair 12:\n", + " image 1: pair_1_3_56.png\n", + " image 2: pair_1_3_89.png\n", + " img 1 brain embedding chose: pair_1_3_56.png Correct\n", + " img 2 brain embedding chose: pair_1_3_89.png Correct\n", + "\n", + "pair 13:\n", + " image 1: pair_6_9_56.png\n", + " image 2: pair_6_9_11.png\n", + " img 1 brain embedding chose: pair_6_9_56.png Correct\n", + " img 2 brain embedding chose: pair_6_9_56.png Incorrect\n", + "\n", + "pair 14:\n", + " image 1: pair_42_w_icerink1.jpg\n", + " image 2: pair_42_w_icerink2.jpg\n", + " img 1 brain embedding chose: pair_42_w_icerink2.jpg Incorrect\n", + " img 2 brain embedding chose: pair_42_w_icerink2.jpg Correct\n", + "\n", + "pair 15:\n", + " image 1: pair_40_w_escalator1.jpg\n", + " image 2: pair_40_w_escalator2.jpg\n", + " img 1 brain embedding 
chose: pair_40_w_escalator2.jpg Incorrect\n", + " img 2 brain embedding chose: pair_40_w_escalator1.jpg Incorrect\n", + "\n", + "pair 16:\n", + " image 1: pair_8_12_56.png\n", + " image 2: pair_8_12_0.png\n", + " img 1 brain embedding chose: pair_8_12_56.png Correct\n", + " img 2 brain embedding chose: pair_8_12_0.png Correct\n", + "\n", + "pair 17:\n", + " image 1: pair_45_w_pagoda2.jpg\n", + " image 2: pair_45_w_pagoda1.jpg\n", + " img 1 brain embedding chose: pair_45_w_pagoda2.jpg Correct\n", + " img 2 brain embedding chose: pair_45_w_pagoda1.jpg Correct\n", + "\n", + "pair 18:\n", + " image 1: pair_38_w_arch2.jpg\n", + " image 2: pair_38_w_arch1.jpg\n", + " img 1 brain embedding chose: pair_38_w_arch1.jpg Incorrect\n", + " img 2 brain embedding chose: pair_38_w_arch1.jpg Correct\n", + "\n", + "pair 19:\n", + " image 1: pair_47_w_roller_coaster1.jpg\n", + " image 2: pair_47_w_roller_coaster2.jpg\n", + " img 1 brain embedding chose: pair_47_w_roller_coaster1.jpg Correct\n", + " img 2 brain embedding chose: pair_47_w_roller_coaster2.jpg Correct\n", + "\n", + "pair 20:\n", + " image 1: pair_0_1_78.png\n", + " image 2: pair_0_1_22.png\n", + " img 1 brain embedding chose: pair_0_1_78.png Correct\n", + " img 2 brain embedding chose: pair_0_1_22.png Correct\n", + "\n", + "pair 21:\n", + " image 1: pair_13_17_89.png\n", + " image 2: pair_13_17_56.png\n", + " img 1 brain embedding chose: pair_13_17_89.png Correct\n", + " img 2 brain embedding chose: pair_13_17_56.png Correct\n", + "\n", + "pair 22:\n", + " image 1: pair_43_w_lighthouse2.jpg\n", + " image 2: pair_43_w_lighthouse1.jpg\n", + " img 1 brain embedding chose: pair_43_w_lighthouse2.jpg Correct\n", + " img 2 brain embedding chose: pair_43_w_lighthouse1.jpg Correct\n", + "\n", + "pair 23:\n", + " image 1: pair_16_20_56.png\n", + " image 2: pair_16_20_0.png\n", + " img 1 brain embedding chose: pair_16_20_56.png Correct\n", + " img 2 brain embedding chose: pair_16_20_0.png Correct\n", + "\n", + "pair 24:\n", + " image 1: pair_11_15_89.png\n", + " image 2: pair_11_15_44.png\n", + " img 1 brain embedding chose: pair_11_15_89.png Correct\n", + " img 2 brain embedding chose: pair_11_15_44.png Correct\n", + "\n", + "pair 25:\n", + " image 1: pair_4_6_89.png\n", + " image 2: pair_4_6_44.png\n", + " img 1 brain embedding chose: pair_4_6_44.png Incorrect\n", + " img 2 brain embedding chose: pair_4_6_44.png Correct\n", + "\n", + "pair 26:\n", + " image 1: pair_2_4_89.png\n", + " image 2: pair_2_4_22.png\n", + " img 1 brain embedding chose: pair_2_4_89.png Correct\n", + " img 2 brain embedding chose: pair_2_4_22.png Correct\n", + "\n", + "pair 27:\n", + " image 1: pair_44_w_log_cabin2.jpg\n", + " image 2: pair_44_w_log_cabin1.jpg\n", + " img 1 brain embedding chose: pair_44_w_log_cabin2.jpg Correct\n", + " img 2 brain embedding chose: pair_44_w_log_cabin1.jpg Correct\n", + "\n", + "pair 28:\n", + " image 1: pair_48_w_runway1.jpg\n", + " image 2: pair_48_w_runway2.jpg\n", + " img 1 brain embedding chose: pair_48_w_runway1.jpg Correct\n", + " img 2 brain embedding chose: pair_48_w_runway2.jpg Correct\n", + "\n", + "pair 29:\n", + " image 1: pair_9_13_44.png\n", + " image 2: pair_9_13_89.png\n", + " img 1 brain embedding chose: pair_9_13_44.png Correct\n", + " img 2 brain embedding chose: pair_9_13_89.png Correct\n", + "\n", + "pair 30:\n", + " image 1: pair_15_19_56.png\n", + " image 2: pair_15_19_0.png\n", + " img 1 brain embedding chose: pair_15_19_0.png Incorrect\n", + " img 2 brain embedding chose: pair_15_19_0.png Correct\n", + "\n", + "pair 
31:\n", + " image 1: pair_41_w_gym2.jpg\n", + " image 2: pair_41_w_gym1.jpg\n", + " img 1 brain embedding chose: pair_41_w_gym2.jpg Correct\n", + " img 2 brain embedding chose: pair_41_w_gym2.jpg Incorrect\n", + "\n", + "correct choices: 49/62\n", + "mst 2afc score: 0.7903 (79.03%)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([124, 541875])\n", + "torch.Size([124, 541875])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 124/124 [00:00<00:00, 185.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pixel Correlation: 0.061674056456206466\n", + "converted, now calculating ssim...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 124/124 [00:01<00:00, 121.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SSIM: 0.3286471108485005\n", + "Loading AlexNet\n", + "\n", + "---early, AlexNet(2)---\n", + "2-way Percent Correct (early AlexNet): 0.6423\n", + "\n", + "---mid, AlexNet(5)---\n", + "2-way Percent Correct (mid AlexNet): 0.6901\n", + "Loading Inception V3\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/models/feature_extraction.py:174: UserWarning: NOTE: The nodes obtained by tracing the model in eval mode are a subsequence of those obtained in train mode. When choosing nodes for feature extraction, you may need to specify output nodes for train and eval mode separately.\n", + " warnings.warn(msg + suggestion_msg)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2-way Percent Correct (Inception V3): 0.5619\n", + "Loading CLIP\n", + "2-way Percent Correct (CLIP): 0.5442\n", + "Loading EfficientNet B1\n", + "Distance EfficientNet B1: 0.9346731287052515\n", + "Loading SwAV\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using cache found in /root/.cache/torch/hub/facebookresearch_swav_main\n", + "/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + " warnings.warn(\n", + "/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. 
The current behavior is equivalent to passing `weights=None`.\n", + " warnings.warn(msg)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Distance SwAV: 0.6017282310215799\n" + ] + } + ], "source": [ "# Run evaluation metrics\n", "from utils_mindeye import calculate_retrieval_metrics, calculate_alexnet, calculate_clip, calculate_swav, calculate_efficientnet_b1, calculate_inception_v3, calculate_pixcorr, calculate_ssim, deduplicate_tensors\n", @@ -1325,13 +1820,19 @@ "all_ground_truth_save_tensor = []\n", "all_retrieved_save_tensor = []\n", "\n", + "\n", + "eval_session = \"ses-03\"\n", + "eval_all_vox_image_names = []\n", + "for run_num in range(1, n_runs + 1):\n", + " tr_labels_df = pd.read_csv(f\"{data_path}/events/sub-005_{eval_session}_task-{func_task_name}_run-{run_num:02d}_tr_labels.csv\")\n", + " for label in tr_labels_df['tr_label_hrf']:\n", + " if pd.notna(label) and label not in ('blank', 'blank.jpg') and 'MST_pairs' in label:\n", + " eval_all_vox_image_names.append(label)\n", + "\n", "for run_num in range(n_runs):\n", - " save_path = f\"{derivatives_path}/{sub}_{session}_task-{func_task_name}_run-{run_num+1:02d}_recons\"\n", + " save_path = f\"{derivatives_path}/{sub}_{eval_session}_task-{func_task_name}_run-{run_num+1:02d}_recons\"\n", "\n", " try:\n", - " # recons = torch.load(os.path.join(save_path, \"all_recons.pt\")).to(torch.float16)\n", - " # clipvoxels = torch.load(os.path.join(save_path, \"all_clipvoxels.pt\")).to(torch.float16)\n", - " # ground_truth = torch.load(os.path.join(save_path, \"all_ground_truth.pt\")).to(torch.float16)\n", " recons = torch.load(os.path.join(save_path, \"all_recons.pt\")).to(torch.float16).to(device)\n", " clipvoxels = torch.load(os.path.join(save_path, \"all_clipvoxels.pt\")).to(torch.float16).to(device)\n", " ground_truth = torch.load(os.path.join(save_path, \"all_ground_truth.pt\")).to(torch.float16).to(device)\n", @@ -1342,7 +1843,6 @@ " except FileNotFoundError:\n", " print(\"Error: Tensors not found. 
Please check the save path.\")\n", "\n", - "# Concatenate tensors along the first dimension\n", "try:\n", " all_recons_save_tensor = torch.cat(all_recons_save_tensor, dim=0)\n", " all_clipvoxels_save_tensor = torch.cat(all_clipvoxels_save_tensor, dim=0)\n", @@ -1350,18 +1850,23 @@ "except RuntimeError:\n", " print('Error: Couldn\\'t concatenate tensors')\n", "\n", + "cache_dir = f\"{derivatives_path}/image_embeddings_cache\"\n", + "\n", "with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n", " unique_clip_voxels, unique_ground_truth, duplicated = deduplicate_tensors(all_clipvoxels_save_tensor, all_ground_truth_save_tensor)\n", " \n", " print('calculating retrieval subset 0 (first set of repeats)')\n", " unique_clip_voxels_subset0 = all_clipvoxels_save_tensor[np.array(duplicated)[:,0]]\n", " unique_ground_truth_subset0 = all_ground_truth_save_tensor[np.array(duplicated)[:,0]]\n", - " all_fwd_acc_subset0, all_bwd_acc_subset0 = calculate_retrieval_metrics(unique_clip_voxels_subset0, unique_ground_truth_subset0)\n", + " mst_names_subset0 = [eval_all_vox_image_names[i] for i in np.array(duplicated)[:,0]]\n", + " all_fwd_acc_subset0, all_bwd_acc_subset0, mst_2afc_subset0 = calculate_retrieval_metrics(unique_clip_voxels_subset0, unique_ground_truth_subset0, mst_image_names=mst_names_subset0, cache_dir=cache_dir)\n", "\n", - " print('calculating retrieval subset 1 (second set of repeats)')\n", + " print('\\n\\ncalculating retrieval subset 1 (second set of repeats)')\n", " unique_clip_voxels_subset1 = all_clipvoxels_save_tensor[np.array(duplicated)[:,1]]\n", " unique_ground_truth_subset1 = all_ground_truth_save_tensor[np.array(duplicated)[:,1]]\n", - " all_fwd_acc_subset1, all_bwd_acc_subset1 = calculate_retrieval_metrics(unique_clip_voxels_subset1, unique_ground_truth_subset1)\n", + " mst_names_subset1 = [eval_all_vox_image_names[i] for i in np.array(duplicated)[:,1]]\n", + " all_fwd_acc_subset1, all_bwd_acc_subset1, mst_2afc_subset1 = calculate_retrieval_metrics(unique_clip_voxels_subset1, unique_ground_truth_subset1, mst_image_names=mst_names_subset1, cache_dir=cache_dir)\n", + " \n", " pixcorr = calculate_pixcorr(all_recons_save_tensor, all_ground_truth_save_tensor)\n", " ssim_ = calculate_ssim(all_recons_save_tensor, all_ground_truth_save_tensor)\n", " alexnet2, alexnet5 = calculate_alexnet(all_recons_save_tensor, all_ground_truth_save_tensor)\n", @@ -1370,8 +1875,6 @@ " efficientnet = calculate_efficientnet_b1(all_recons_save_tensor, all_ground_truth_save_tensor)\n", " swav = calculate_swav(all_recons_save_tensor, all_ground_truth_save_tensor)\n", "\n", - "\n", - "# save the results to a csv file\n", "df_metrics = pd.DataFrame({\n", " \"Metric\": [\n", " \"alexnet2\",\n", @@ -1385,7 +1888,9 @@ " \"all_fwd_acc_subset0\",\n", " \"all_bwd_acc_subset0\",\n", " \"all_fwd_acc_subset1\",\n", - " \"all_bwd_acc_subset1\"\n", + " \"all_bwd_acc_subset1\",\n", + " \"mst_2afc_subset0\",\n", + " \"mst_2afc_subset1\"\n", " ],\n", " \"Value\": [\n", " alexnet2,\n", @@ -1393,25 +1898,139 @@ " inception,\n", " clip_,\n", " efficientnet,\n", - " swav,\n", + " swav, \n", " pixcorr,\n", " ssim_,\n", " all_fwd_acc_subset0,\n", " all_bwd_acc_subset0,\n", " all_fwd_acc_subset1,\n", - " all_bwd_acc_subset1\n", + " all_bwd_acc_subset1,\n", + " mst_2afc_subset0,\n", + " mst_2afc_subset1\n", " ]\n", "})" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 65, "id": "52e71597", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
(rendered HTML output of the styled `df_metrics` table)

| Metric | Formatted |
|---|---|
| alexnet2 | 64.23% ↑ |
| alexnet5 | 69.01% ↑ |
| inception | 56.19% ↑ |
| clip_ | 54.42% ↑ |
| efficientnet | 0.93 ↓ |
| swav | 0.60 ↓ |
| pixcorr | 0.06 ↑ |
| ssim | 0.33 ↑ |
| all_fwd_acc_subset0 | 22.58% ↑ |
| all_bwd_acc_subset0 | 29.03% ↑ |
| all_fwd_acc_subset1 | 30.65% ↑ |
| all_bwd_acc_subset1 | 37.10% ↑ |
| mst_2afc_subset0 | 79.03% ↑ |
| mst_2afc_subset1 | 79.03% ↑ |