diff --git a/.gitignore b/.gitignore index f5b2f05..d5e502e 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ package-lock.json # MacOS stuff .DS_Store .vscode +/venv \ No newline at end of file diff --git a/mindeye/scripts/mindeye.ipynb b/mindeye/scripts/mindeye.ipynb index 0845925..56191d4 100644 --- a/mindeye/scripts/mindeye.ipynb +++ b/mindeye/scripts/mindeye.ipynb @@ -2,26 +2,16 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 40, "id": "b6053a83-2259-475e-9e21-201e44217e88", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/conf/.venv/lib/python3.11/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", - " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "line 6: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/scripts\n", - "line 6: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/scripts\n", - "line 14: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/scripts\n", - "line 14: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/scripts\n" + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" ] } ], @@ -102,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 41, "id": "699a3162", "metadata": {}, "outputs": [], @@ -117,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 42, "id": "4516e788-85cc-42ab-b05a-11bd7207f6ba", "metadata": {}, "outputs": [], @@ -140,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 43, "id": "12be1838-f387-4cdd-b7cb-217a74501359", "metadata": {}, "outputs": [ @@ -171,7 +161,7 @@ "722060552" ] }, - "execution_count": 4, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -314,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 44, "id": "8a627d35-3cd5-4cd1-9bb3-c02c0c97f7a1", "metadata": {}, "outputs": [], @@ -341,10 +331,23 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 45, "id": "05bd11f3-6d4d-4ee4-a23b-443afeb5c3fe", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[45]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 14\u001b[39m sampler_config[\u001b[33m'\u001b[39m\u001b[33mparams\u001b[39m\u001b[33m'\u001b[39m][\u001b[33m'\u001b[39m\u001b[33mnum_steps\u001b[39m\u001b[33m'\u001b[39m] = \u001b[32m38\u001b[39m\n\u001b[32m 15\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mstorage_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/diffusion_engine\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mrb\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m input_file:\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m diffusion_engine = 
\u001b[43mpickle\u001b[49m\u001b[43m.\u001b[49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_file\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[38;5;66;03m# set to inference\u001b[39;00m\n\u001b[32m 18\u001b[39m diffusion_engine.eval().requires_grad_(\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torch/storage.py:336\u001b[39m, in \u001b[36m_load_from_bytes\u001b[39m\u001b[34m(b)\u001b[39m\n\u001b[32m 332\u001b[39m \u001b[38;5;129m@_share_memory_lock_protected\u001b[39m\n\u001b[32m 333\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_share_filename_cpu_\u001b[39m(\u001b[38;5;28mself\u001b[39m, *args, **kwargs):\n\u001b[32m 334\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()._share_filename_cpu_(*args, **kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m336\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_load_from_bytes\u001b[39m(b):\n\u001b[32m 337\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m torch.load(io.BytesIO(b))\n\u001b[32m 340\u001b[39m _StorageBase.type = _type \u001b[38;5;66;03m# type: ignore[assignment]\u001b[39;00m\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], "source": [ "# prep unCLIP\n", "config = OmegaConf.load(f\"{project_path}/models/generative_models/configs/unclip6.yaml\")\n", @@ -379,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 46, "id": "58cb6183", "metadata": {}, "outputs": [], @@ -399,7 +402,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 47, "id": "48586a97", "metadata": {}, "outputs": [ @@ -408,7 +411,7 @@ "output_type": "stream", "text": [ "Data shape: (780, 109)\n", - "Using design file: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/3t/data/events/csv/sub-005_ses-06.csv\n", + "Using design file: /home/amaarc/rtcloud-projects/mindeye/3t/data/events/csv/sub-005_ses-06.csv\n", "Total number of images: 770\n", "Number of unique images: 126\n", "n_runs 11\n", @@ -512,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 48, "id": "51d16a03", "metadata": {}, "outputs": [ @@ -520,9 +523,9 @@ "name": "stderr", "output_type": "stream", "text": [ - " 6%|▌ | 42/693 [00:00<00:01, 412.55it/s]/home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/conf/.venv/lib/python3.11/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n", + " 8%|▊ | 57/693 [00:00<00:05, 118.14it/s]/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. 
To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n", " warnings.warn(\n", - "100%|██████████| 693/693 [00:06<00:00, 104.06it/s]" + "100%|██████████| 693/693 [00:20<00:00, 33.45it/s]" ] }, { @@ -571,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 49, "id": "f692b9cb", "metadata": {}, "outputs": [], @@ -590,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 50, "id": "fbffed78", "metadata": {}, "outputs": [ @@ -599,7 +602,7 @@ "output_type": "stream", "text": [ "Data shape: (780, 109)\n", - "Using design file: /home/ri4541@pu.win.princeton.edu/rtcloud-projects/mindeye/3t/data/events/csv/sub-005_ses-06.csv\n", + "Using design file: /home/amaarc/rtcloud-projects/mindeye/3t/data/events/csv/sub-005_ses-06.csv\n", "Total number of images: 770\n", "Number of unique images: 126\n" ] @@ -624,7 +627,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 51, "id": "31b2474d", "metadata": {}, "outputs": [ @@ -647,7 +650,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 52, "id": "e05736bc-c816-49ae-8718-b6c31b412781", "metadata": {}, "outputs": [], @@ -676,7 +679,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 53, "id": "b21c9550", "metadata": {}, "outputs": [], @@ -688,7 +691,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 54, "id": "4768e42a", "metadata": {}, "outputs": [], @@ -719,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 55, "id": "057b3dd3", "metadata": {}, "outputs": [ @@ -739,7 +742,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 56, "id": "895d9228-46e0-4ec0-9fe4-00f802f9708f", "metadata": {}, "outputs": [], @@ -834,7 +837,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 57, "id": "4e595bc9", "metadata": {}, "outputs": [], @@ -847,7 +850,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 58, "id": "ff1807bc", "metadata": {}, "outputs": [], @@ -859,7 +862,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 59, "id": "211a4d40-643d-493b-874b-2030490b9bf4", "metadata": { "tags": [] @@ -1313,10 +1316,502 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 66, "id": "d6675825", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "calculating retrieval subset 0 (first set of repeats)\n", + "Loading clip_img_embedder\n", + "The total pool of images and clip voxels to do retrieval on is: 62\n", + "loading cached embeddings from /home/amaarc/rtcloud-projects/mindeye/3t/derivatives/image_embeddings_cache/image_embeddings_62.pt\n", + "calculating retrieval metrics\n", + "overall fwd percent_correct: 0.2258\n", + "overall bwd percent_correct: 0.2903\n", + "\n", + "pair 1:\n", + " image 1: pair_46_w_pool2.jpg\n", + " image 2: pair_46_w_pool1.jpg\n", + " img 1 brain embedding chose: pair_46_w_pool2.jpg Correct\n", + " img 2 brain embedding chose: pair_46_w_pool1.jpg Correct\n", + "\n", + "pair 2:\n", + " image 1: pair_14_18_78.png\n", + " image 2: pair_14_18_22.png\n", + " 
img 1 brain embedding chose: pair_14_18_22.png Incorrect\n", + " img 2 brain embedding chose: pair_14_18_22.png Correct\n", + "\n", + "pair 3:\n", + " image 1: pair_39_w_corridor1.jpg\n", + " image 2: pair_39_w_corridor2.jpg\n", + " img 1 brain embedding chose: pair_39_w_corridor1.jpg Correct\n", + " img 2 brain embedding chose: pair_39_w_corridor2.jpg Correct\n", + "\n", + "pair 4:\n", + " image 1: pair_5_8_0.png\n", + " image 2: pair_5_8_78.png\n", + " img 1 brain embedding chose: pair_5_8_0.png Correct\n", + " img 2 brain embedding chose: pair_5_8_78.png Correct\n", + "\n", + "pair 5:\n", + " image 1: pair_37_w_airplane_interior2.jpg\n", + " image 2: pair_37_w_airplane_interior1.jpg\n", + " img 1 brain embedding chose: pair_37_w_airplane_interior1.jpg Incorrect\n", + " img 2 brain embedding chose: pair_37_w_airplane_interior2.jpg Incorrect\n", + "\n", + "pair 6:\n", + " image 1: pair_17_21_33.png\n", + " image 2: pair_17_21_0.png\n", + " img 1 brain embedding chose: pair_17_21_33.png Correct\n", + " img 2 brain embedding chose: pair_17_21_0.png Correct\n", + "\n", + "pair 7:\n", + " image 1: pair_12_16_22.png\n", + " image 2: pair_12_16_78.png\n", + " img 1 brain embedding chose: pair_12_16_22.png Correct\n", + " img 2 brain embedding chose: pair_12_16_78.png Correct\n", + "\n", + "pair 8:\n", + " image 1: pair_49_w_yoga_studio1.jpg\n", + " image 2: pair_49_w_yoga_studio2.jpg\n", + " img 1 brain embedding chose: pair_49_w_yoga_studio1.jpg Correct\n", + " img 2 brain embedding chose: pair_49_w_yoga_studio1.jpg Incorrect\n", + "\n", + "pair 9:\n", + " image 1: pair_3_5_100.png\n", + " image 2: pair_3_5_33.png\n", + " img 1 brain embedding chose: pair_3_5_100.png Correct\n", + " img 2 brain embedding chose: pair_3_5_33.png Correct\n", + "\n", + "pair 10:\n", + " image 1: pair_10_14_89.png\n", + " image 2: pair_10_14_33.png\n", + " img 1 brain embedding chose: pair_10_14_89.png Correct\n", + " img 2 brain embedding chose: pair_10_14_33.png Correct\n", + "\n", + "pair 11:\n", + " image 1: pair_7_10_56.png\n", + " image 2: pair_7_10_89.png\n", + " img 1 brain embedding chose: pair_7_10_89.png Incorrect\n", + " img 2 brain embedding chose: pair_7_10_89.png Correct\n", + "\n", + "pair 12:\n", + " image 1: pair_1_3_56.png\n", + " image 2: pair_1_3_89.png\n", + " img 1 brain embedding chose: pair_1_3_89.png Incorrect\n", + " img 2 brain embedding chose: pair_1_3_89.png Correct\n", + "\n", + "pair 13:\n", + " image 1: pair_6_9_56.png\n", + " image 2: pair_6_9_11.png\n", + " img 1 brain embedding chose: pair_6_9_56.png Correct\n", + " img 2 brain embedding chose: pair_6_9_56.png Incorrect\n", + "\n", + "pair 14:\n", + " image 1: pair_42_w_icerink1.jpg\n", + " image 2: pair_42_w_icerink2.jpg\n", + " img 1 brain embedding chose: pair_42_w_icerink1.jpg Correct\n", + " img 2 brain embedding chose: pair_42_w_icerink2.jpg Correct\n", + "\n", + "pair 15:\n", + " image 1: pair_40_w_escalator1.jpg\n", + " image 2: pair_40_w_escalator2.jpg\n", + " img 1 brain embedding chose: pair_40_w_escalator1.jpg Correct\n", + " img 2 brain embedding chose: pair_40_w_escalator2.jpg Correct\n", + "\n", + "pair 16:\n", + " image 1: pair_8_12_56.png\n", + " image 2: pair_8_12_0.png\n", + " img 1 brain embedding chose: pair_8_12_56.png Correct\n", + " img 2 brain embedding chose: pair_8_12_0.png Correct\n", + "\n", + "pair 17:\n", + " image 1: pair_45_w_pagoda2.jpg\n", + " image 2: pair_45_w_pagoda1.jpg\n", + " img 1 brain embedding chose: pair_45_w_pagoda1.jpg Incorrect\n", + " img 2 brain embedding chose: 
pair_45_w_pagoda1.jpg Correct\n", + "\n", + "pair 18:\n", + " image 1: pair_38_w_arch2.jpg\n", + " image 2: pair_38_w_arch1.jpg\n", + " img 1 brain embedding chose: pair_38_w_arch2.jpg Correct\n", + " img 2 brain embedding chose: pair_38_w_arch1.jpg Correct\n", + "\n", + "pair 19:\n", + " image 1: pair_47_w_roller_coaster1.jpg\n", + " image 2: pair_47_w_roller_coaster2.jpg\n", + " img 1 brain embedding chose: pair_47_w_roller_coaster1.jpg Correct\n", + " img 2 brain embedding chose: pair_47_w_roller_coaster2.jpg Correct\n", + "\n", + "pair 20:\n", + " image 1: pair_0_1_78.png\n", + " image 2: pair_0_1_22.png\n", + " img 1 brain embedding chose: pair_0_1_78.png Correct\n", + " img 2 brain embedding chose: pair_0_1_22.png Correct\n", + "\n", + "pair 21:\n", + " image 1: pair_13_17_89.png\n", + " image 2: pair_13_17_56.png\n", + " img 1 brain embedding chose: pair_13_17_89.png Correct\n", + " img 2 brain embedding chose: pair_13_17_56.png Correct\n", + "\n", + "pair 22:\n", + " image 1: pair_43_w_lighthouse2.jpg\n", + " image 2: pair_43_w_lighthouse1.jpg\n", + " img 1 brain embedding chose: pair_43_w_lighthouse2.jpg Correct\n", + " img 2 brain embedding chose: pair_43_w_lighthouse1.jpg Correct\n", + "\n", + "pair 23:\n", + " image 1: pair_16_20_56.png\n", + " image 2: pair_16_20_0.png\n", + " img 1 brain embedding chose: pair_16_20_56.png Correct\n", + " img 2 brain embedding chose: pair_16_20_56.png Incorrect\n", + "\n", + "pair 24:\n", + " image 1: pair_11_15_89.png\n", + " image 2: pair_11_15_44.png\n", + " img 1 brain embedding chose: pair_11_15_44.png Incorrect\n", + " img 2 brain embedding chose: pair_11_15_44.png Correct\n", + "\n", + "pair 25:\n", + " image 1: pair_4_6_89.png\n", + " image 2: pair_4_6_44.png\n", + " img 1 brain embedding chose: pair_4_6_44.png Incorrect\n", + " img 2 brain embedding chose: pair_4_6_44.png Correct\n", + "\n", + "pair 26:\n", + " image 1: pair_2_4_89.png\n", + " image 2: pair_2_4_22.png\n", + " img 1 brain embedding chose: pair_2_4_89.png Correct\n", + " img 2 brain embedding chose: pair_2_4_22.png Correct\n", + "\n", + "pair 27:\n", + " image 1: pair_44_w_log_cabin2.jpg\n", + " image 2: pair_44_w_log_cabin1.jpg\n", + " img 1 brain embedding chose: pair_44_w_log_cabin2.jpg Correct\n", + " img 2 brain embedding chose: pair_44_w_log_cabin1.jpg Correct\n", + "\n", + "pair 28:\n", + " image 1: pair_48_w_runway1.jpg\n", + " image 2: pair_48_w_runway2.jpg\n", + " img 1 brain embedding chose: pair_48_w_runway1.jpg Correct\n", + " img 2 brain embedding chose: pair_48_w_runway2.jpg Correct\n", + "\n", + "pair 29:\n", + " image 1: pair_9_13_44.png\n", + " image 2: pair_9_13_89.png\n", + " img 1 brain embedding chose: pair_9_13_44.png Correct\n", + " img 2 brain embedding chose: pair_9_13_44.png Incorrect\n", + "\n", + "pair 30:\n", + " image 1: pair_15_19_56.png\n", + " image 2: pair_15_19_0.png\n", + " img 1 brain embedding chose: pair_15_19_56.png Correct\n", + " img 2 brain embedding chose: pair_15_19_0.png Correct\n", + "\n", + "pair 31:\n", + " image 1: pair_41_w_gym2.jpg\n", + " image 2: pair_41_w_gym1.jpg\n", + " img 1 brain embedding chose: pair_41_w_gym2.jpg Correct\n", + " img 2 brain embedding chose: pair_41_w_gym2.jpg Incorrect\n", + "\n", + "correct choices: 49/62\n", + "mst 2afc score: 0.7903 (79.03%)\n", + "\n", + "\n", + "calculating retrieval subset 1 (second set of repeats)\n", + "Loading clip_img_embedder\n", + "The total pool of images and clip voxels to do retrieval on is: 62\n", + "loading cached embeddings from 
/home/amaarc/rtcloud-projects/mindeye/3t/derivatives/image_embeddings_cache/image_embeddings_62.pt\n", + "calculating retrieval metrics\n", + "overall fwd percent_correct: 0.3065\n", + "overall bwd percent_correct: 0.3710\n", + "\n", + "pair 1:\n", + " image 1: pair_46_w_pool2.jpg\n", + " image 2: pair_46_w_pool1.jpg\n", + " img 1 brain embedding chose: pair_46_w_pool1.jpg Incorrect\n", + " img 2 brain embedding chose: pair_46_w_pool1.jpg Correct\n", + "\n", + "pair 2:\n", + " image 1: pair_14_18_78.png\n", + " image 2: pair_14_18_22.png\n", + " img 1 brain embedding chose: pair_14_18_78.png Correct\n", + " img 2 brain embedding chose: pair_14_18_78.png Incorrect\n", + "\n", + "pair 3:\n", + " image 1: pair_39_w_corridor1.jpg\n", + " image 2: pair_39_w_corridor2.jpg\n", + " img 1 brain embedding chose: pair_39_w_corridor1.jpg Correct\n", + " img 2 brain embedding chose: pair_39_w_corridor2.jpg Correct\n", + "\n", + "pair 4:\n", + " image 1: pair_5_8_0.png\n", + " image 2: pair_5_8_78.png\n", + " img 1 brain embedding chose: pair_5_8_0.png Correct\n", + " img 2 brain embedding chose: pair_5_8_0.png Incorrect\n", + "\n", + "pair 5:\n", + " image 1: pair_37_w_airplane_interior2.jpg\n", + " image 2: pair_37_w_airplane_interior1.jpg\n", + " img 1 brain embedding chose: pair_37_w_airplane_interior2.jpg Correct\n", + " img 2 brain embedding chose: pair_37_w_airplane_interior1.jpg Correct\n", + "\n", + "pair 6:\n", + " image 1: pair_17_21_33.png\n", + " image 2: pair_17_21_0.png\n", + " img 1 brain embedding chose: pair_17_21_33.png Correct\n", + " img 2 brain embedding chose: pair_17_21_33.png Incorrect\n", + "\n", + "pair 7:\n", + " image 1: pair_12_16_22.png\n", + " image 2: pair_12_16_78.png\n", + " img 1 brain embedding chose: pair_12_16_22.png Correct\n", + " img 2 brain embedding chose: pair_12_16_78.png Correct\n", + "\n", + "pair 8:\n", + " image 1: pair_49_w_yoga_studio1.jpg\n", + " image 2: pair_49_w_yoga_studio2.jpg\n", + " img 1 brain embedding chose: pair_49_w_yoga_studio1.jpg Correct\n", + " img 2 brain embedding chose: pair_49_w_yoga_studio2.jpg Correct\n", + "\n", + "pair 9:\n", + " image 1: pair_3_5_100.png\n", + " image 2: pair_3_5_33.png\n", + " img 1 brain embedding chose: pair_3_5_100.png Correct\n", + " img 2 brain embedding chose: pair_3_5_33.png Correct\n", + "\n", + "pair 10:\n", + " image 1: pair_10_14_89.png\n", + " image 2: pair_10_14_33.png\n", + " img 1 brain embedding chose: pair_10_14_89.png Correct\n", + " img 2 brain embedding chose: pair_10_14_33.png Correct\n", + "\n", + "pair 11:\n", + " image 1: pair_7_10_56.png\n", + " image 2: pair_7_10_89.png\n", + " img 1 brain embedding chose: pair_7_10_89.png Incorrect\n", + " img 2 brain embedding chose: pair_7_10_89.png Correct\n", + "\n", + "pair 12:\n", + " image 1: pair_1_3_56.png\n", + " image 2: pair_1_3_89.png\n", + " img 1 brain embedding chose: pair_1_3_56.png Correct\n", + " img 2 brain embedding chose: pair_1_3_89.png Correct\n", + "\n", + "pair 13:\n", + " image 1: pair_6_9_56.png\n", + " image 2: pair_6_9_11.png\n", + " img 1 brain embedding chose: pair_6_9_56.png Correct\n", + " img 2 brain embedding chose: pair_6_9_56.png Incorrect\n", + "\n", + "pair 14:\n", + " image 1: pair_42_w_icerink1.jpg\n", + " image 2: pair_42_w_icerink2.jpg\n", + " img 1 brain embedding chose: pair_42_w_icerink2.jpg Incorrect\n", + " img 2 brain embedding chose: pair_42_w_icerink2.jpg Correct\n", + "\n", + "pair 15:\n", + " image 1: pair_40_w_escalator1.jpg\n", + " image 2: pair_40_w_escalator2.jpg\n", + " img 1 brain embedding 
chose: pair_40_w_escalator2.jpg Incorrect\n", + " img 2 brain embedding chose: pair_40_w_escalator1.jpg Incorrect\n", + "\n", + "pair 16:\n", + " image 1: pair_8_12_56.png\n", + " image 2: pair_8_12_0.png\n", + " img 1 brain embedding chose: pair_8_12_56.png Correct\n", + " img 2 brain embedding chose: pair_8_12_0.png Correct\n", + "\n", + "pair 17:\n", + " image 1: pair_45_w_pagoda2.jpg\n", + " image 2: pair_45_w_pagoda1.jpg\n", + " img 1 brain embedding chose: pair_45_w_pagoda2.jpg Correct\n", + " img 2 brain embedding chose: pair_45_w_pagoda1.jpg Correct\n", + "\n", + "pair 18:\n", + " image 1: pair_38_w_arch2.jpg\n", + " image 2: pair_38_w_arch1.jpg\n", + " img 1 brain embedding chose: pair_38_w_arch1.jpg Incorrect\n", + " img 2 brain embedding chose: pair_38_w_arch1.jpg Correct\n", + "\n", + "pair 19:\n", + " image 1: pair_47_w_roller_coaster1.jpg\n", + " image 2: pair_47_w_roller_coaster2.jpg\n", + " img 1 brain embedding chose: pair_47_w_roller_coaster1.jpg Correct\n", + " img 2 brain embedding chose: pair_47_w_roller_coaster2.jpg Correct\n", + "\n", + "pair 20:\n", + " image 1: pair_0_1_78.png\n", + " image 2: pair_0_1_22.png\n", + " img 1 brain embedding chose: pair_0_1_78.png Correct\n", + " img 2 brain embedding chose: pair_0_1_22.png Correct\n", + "\n", + "pair 21:\n", + " image 1: pair_13_17_89.png\n", + " image 2: pair_13_17_56.png\n", + " img 1 brain embedding chose: pair_13_17_89.png Correct\n", + " img 2 brain embedding chose: pair_13_17_56.png Correct\n", + "\n", + "pair 22:\n", + " image 1: pair_43_w_lighthouse2.jpg\n", + " image 2: pair_43_w_lighthouse1.jpg\n", + " img 1 brain embedding chose: pair_43_w_lighthouse2.jpg Correct\n", + " img 2 brain embedding chose: pair_43_w_lighthouse1.jpg Correct\n", + "\n", + "pair 23:\n", + " image 1: pair_16_20_56.png\n", + " image 2: pair_16_20_0.png\n", + " img 1 brain embedding chose: pair_16_20_56.png Correct\n", + " img 2 brain embedding chose: pair_16_20_0.png Correct\n", + "\n", + "pair 24:\n", + " image 1: pair_11_15_89.png\n", + " image 2: pair_11_15_44.png\n", + " img 1 brain embedding chose: pair_11_15_89.png Correct\n", + " img 2 brain embedding chose: pair_11_15_44.png Correct\n", + "\n", + "pair 25:\n", + " image 1: pair_4_6_89.png\n", + " image 2: pair_4_6_44.png\n", + " img 1 brain embedding chose: pair_4_6_44.png Incorrect\n", + " img 2 brain embedding chose: pair_4_6_44.png Correct\n", + "\n", + "pair 26:\n", + " image 1: pair_2_4_89.png\n", + " image 2: pair_2_4_22.png\n", + " img 1 brain embedding chose: pair_2_4_89.png Correct\n", + " img 2 brain embedding chose: pair_2_4_22.png Correct\n", + "\n", + "pair 27:\n", + " image 1: pair_44_w_log_cabin2.jpg\n", + " image 2: pair_44_w_log_cabin1.jpg\n", + " img 1 brain embedding chose: pair_44_w_log_cabin2.jpg Correct\n", + " img 2 brain embedding chose: pair_44_w_log_cabin1.jpg Correct\n", + "\n", + "pair 28:\n", + " image 1: pair_48_w_runway1.jpg\n", + " image 2: pair_48_w_runway2.jpg\n", + " img 1 brain embedding chose: pair_48_w_runway1.jpg Correct\n", + " img 2 brain embedding chose: pair_48_w_runway2.jpg Correct\n", + "\n", + "pair 29:\n", + " image 1: pair_9_13_44.png\n", + " image 2: pair_9_13_89.png\n", + " img 1 brain embedding chose: pair_9_13_44.png Correct\n", + " img 2 brain embedding chose: pair_9_13_89.png Correct\n", + "\n", + "pair 30:\n", + " image 1: pair_15_19_56.png\n", + " image 2: pair_15_19_0.png\n", + " img 1 brain embedding chose: pair_15_19_0.png Incorrect\n", + " img 2 brain embedding chose: pair_15_19_0.png Correct\n", + "\n", + "pair 
31:\n", + " image 1: pair_41_w_gym2.jpg\n", + " image 2: pair_41_w_gym1.jpg\n", + " img 1 brain embedding chose: pair_41_w_gym2.jpg Correct\n", + " img 2 brain embedding chose: pair_41_w_gym2.jpg Incorrect\n", + "\n", + "correct choices: 49/62\n", + "mst 2afc score: 0.7903 (79.03%)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([124, 541875])\n", + "torch.Size([124, 541875])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 124/124 [00:00<00:00, 185.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pixel Correlation: 0.061674056456206466\n", + "converted, now calculating ssim...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 124/124 [00:01<00:00, 121.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SSIM: 0.3286471108485005\n", + "Loading AlexNet\n", + "\n", + "---early, AlexNet(2)---\n", + "2-way Percent Correct (early AlexNet): 0.6423\n", + "\n", + "---mid, AlexNet(5)---\n", + "2-way Percent Correct (mid AlexNet): 0.6901\n", + "Loading Inception V3\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/models/feature_extraction.py:174: UserWarning: NOTE: The nodes obtained by tracing the model in eval mode are a subsequence of those obtained in train mode. When choosing nodes for feature extraction, you may need to specify output nodes for train and eval mode separately.\n", + " warnings.warn(msg + suggestion_msg)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2-way Percent Correct (Inception V3): 0.5619\n", + "Loading CLIP\n", + "2-way Percent Correct (CLIP): 0.5442\n", + "Loading EfficientNet B1\n", + "Distance EfficientNet B1: 0.9346731287052515\n", + "Loading SwAV\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using cache found in /root/.cache/torch/hub/facebookresearch_swav_main\n", + "/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + " warnings.warn(\n", + "/home/amaarc/rtcloud-projects/venv/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. 
The current behavior is equivalent to passing `weights=None`.\n", + " warnings.warn(msg)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Distance SwAV: 0.6017282310215799\n" + ] + } + ], "source": [ "# Run evaluation metrics\n", "from utils_mindeye import calculate_retrieval_metrics, calculate_alexnet, calculate_clip, calculate_swav, calculate_efficientnet_b1, calculate_inception_v3, calculate_pixcorr, calculate_ssim, deduplicate_tensors\n", @@ -1325,13 +1820,19 @@ "all_ground_truth_save_tensor = []\n", "all_retrieved_save_tensor = []\n", "\n", + "\n", + "eval_session = \"ses-03\"\n", + "eval_all_vox_image_names = []\n", + "for run_num in range(1, n_runs + 1):\n", + " tr_labels_df = pd.read_csv(f\"{data_path}/events/sub-005_{eval_session}_task-{func_task_name}_run-{run_num:02d}_tr_labels.csv\")\n", + " for label in tr_labels_df['tr_label_hrf']:\n", + " if pd.notna(label) and label not in ('blank', 'blank.jpg') and 'MST_pairs' in label:\n", + " eval_all_vox_image_names.append(label)\n", + "\n", "for run_num in range(n_runs):\n", - " save_path = f\"{derivatives_path}/{sub}_{session}_task-{func_task_name}_run-{run_num+1:02d}_recons\"\n", + " save_path = f\"{derivatives_path}/{sub}_{eval_session}_task-{func_task_name}_run-{run_num+1:02d}_recons\"\n", "\n", " try:\n", - " # recons = torch.load(os.path.join(save_path, \"all_recons.pt\")).to(torch.float16)\n", - " # clipvoxels = torch.load(os.path.join(save_path, \"all_clipvoxels.pt\")).to(torch.float16)\n", - " # ground_truth = torch.load(os.path.join(save_path, \"all_ground_truth.pt\")).to(torch.float16)\n", " recons = torch.load(os.path.join(save_path, \"all_recons.pt\")).to(torch.float16).to(device)\n", " clipvoxels = torch.load(os.path.join(save_path, \"all_clipvoxels.pt\")).to(torch.float16).to(device)\n", " ground_truth = torch.load(os.path.join(save_path, \"all_ground_truth.pt\")).to(torch.float16).to(device)\n", @@ -1342,7 +1843,6 @@ " except FileNotFoundError:\n", " print(\"Error: Tensors not found. 
Please check the save path.\")\n", "\n", - "# Concatenate tensors along the first dimension\n", "try:\n", " all_recons_save_tensor = torch.cat(all_recons_save_tensor, dim=0)\n", " all_clipvoxels_save_tensor = torch.cat(all_clipvoxels_save_tensor, dim=0)\n", @@ -1350,18 +1850,23 @@ "except RuntimeError:\n", " print('Error: Couldn\\'t concatenate tensors')\n", "\n", + "cache_dir = f\"{derivatives_path}/image_embeddings_cache\"\n", + "\n", "with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n", " unique_clip_voxels, unique_ground_truth, duplicated = deduplicate_tensors(all_clipvoxels_save_tensor, all_ground_truth_save_tensor)\n", " \n", " print('calculating retrieval subset 0 (first set of repeats)')\n", " unique_clip_voxels_subset0 = all_clipvoxels_save_tensor[np.array(duplicated)[:,0]]\n", " unique_ground_truth_subset0 = all_ground_truth_save_tensor[np.array(duplicated)[:,0]]\n", - " all_fwd_acc_subset0, all_bwd_acc_subset0 = calculate_retrieval_metrics(unique_clip_voxels_subset0, unique_ground_truth_subset0)\n", + " mst_names_subset0 = [eval_all_vox_image_names[i] for i in np.array(duplicated)[:,0]]\n", + " all_fwd_acc_subset0, all_bwd_acc_subset0, mst_2afc_subset0 = calculate_retrieval_metrics(unique_clip_voxels_subset0, unique_ground_truth_subset0, mst_image_names=mst_names_subset0, cache_dir=cache_dir)\n", "\n", - " print('calculating retrieval subset 1 (second set of repeats)')\n", + " print('\\n\\ncalculating retrieval subset 1 (second set of repeats)')\n", " unique_clip_voxels_subset1 = all_clipvoxels_save_tensor[np.array(duplicated)[:,1]]\n", " unique_ground_truth_subset1 = all_ground_truth_save_tensor[np.array(duplicated)[:,1]]\n", - " all_fwd_acc_subset1, all_bwd_acc_subset1 = calculate_retrieval_metrics(unique_clip_voxels_subset1, unique_ground_truth_subset1)\n", + " mst_names_subset1 = [eval_all_vox_image_names[i] for i in np.array(duplicated)[:,1]]\n", + " all_fwd_acc_subset1, all_bwd_acc_subset1, mst_2afc_subset1 = calculate_retrieval_metrics(unique_clip_voxels_subset1, unique_ground_truth_subset1, mst_image_names=mst_names_subset1, cache_dir=cache_dir)\n", + " \n", " pixcorr = calculate_pixcorr(all_recons_save_tensor, all_ground_truth_save_tensor)\n", " ssim_ = calculate_ssim(all_recons_save_tensor, all_ground_truth_save_tensor)\n", " alexnet2, alexnet5 = calculate_alexnet(all_recons_save_tensor, all_ground_truth_save_tensor)\n", @@ -1370,8 +1875,6 @@ " efficientnet = calculate_efficientnet_b1(all_recons_save_tensor, all_ground_truth_save_tensor)\n", " swav = calculate_swav(all_recons_save_tensor, all_ground_truth_save_tensor)\n", "\n", - "\n", - "# save the results to a csv file\n", "df_metrics = pd.DataFrame({\n", " \"Metric\": [\n", " \"alexnet2\",\n", @@ -1385,7 +1888,9 @@ " \"all_fwd_acc_subset0\",\n", " \"all_bwd_acc_subset0\",\n", " \"all_fwd_acc_subset1\",\n", - " \"all_bwd_acc_subset1\"\n", + " \"all_bwd_acc_subset1\",\n", + " \"mst_2afc_subset0\",\n", + " \"mst_2afc_subset1\"\n", " ],\n", " \"Value\": [\n", " alexnet2,\n", @@ -1393,25 +1898,139 @@ " inception,\n", " clip_,\n", " efficientnet,\n", - " swav,\n", + " swav, \n", " pixcorr,\n", " ssim_,\n", " all_fwd_acc_subset0,\n", " all_bwd_acc_subset0,\n", " all_fwd_acc_subset1,\n", - " all_bwd_acc_subset1\n", + " all_bwd_acc_subset1,\n", + " mst_2afc_subset0,\n", + " mst_2afc_subset1\n", " ]\n", "})" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 65, "id": "52e71597", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Formatted
Metric
alexnet264.23% ↑
alexnet569.01% ↑
inception56.19% ↑
clip_54.42% ↑
efficientnet0.93 ↓
swav0.60 ↓
pixcorr0.06 ↑
ssim0.33 ↑
all_fwd_acc_subset022.58% ↑
all_bwd_acc_subset029.03% ↑
all_fwd_acc_subset130.65% ↑
all_bwd_acc_subset137.10% ↑
mst_2afc_subset079.03% ↑
mst_2afc_subset179.03% ↑
\n", + "
" + ], + "text/plain": [ + " Formatted\n", + "Metric \n", + "alexnet2 64.23% ↑\n", + "alexnet5 69.01% ↑\n", + "inception 56.19% ↑\n", + "clip_ 54.42% ↑\n", + "efficientnet 0.93 ↓\n", + "swav 0.60 ↓\n", + "pixcorr 0.06 ↑\n", + "ssim 0.33 ↑\n", + "all_fwd_acc_subset0 22.58% ↑\n", + "all_bwd_acc_subset0 29.03% ↑\n", + "all_fwd_acc_subset1 30.65% ↑\n", + "all_bwd_acc_subset1 37.10% ↑\n", + "mst_2afc_subset0 79.03% ↑\n", + "mst_2afc_subset1 79.03% ↑" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "percentage_metrics = [\"alexnet2\", \"alexnet5\", \"inception\", \"clip_\", \"all_fwd_acc_subset0\", \"all_bwd_acc_subset0\", \"all_bwd_acc_subset1\", \"all_fwd_acc_subset1\"]\n", + "percentage_metrics = [\"alexnet2\", \"alexnet5\", \"inception\", \"clip_\", \"all_fwd_acc_subset0\", \"all_bwd_acc_subset0\", \"all_bwd_acc_subset1\", \"all_fwd_acc_subset1\", \"mst_2afc_subset0\", \"mst_2afc_subset1\"]\n", "lower_better_metrics = [\"efficientnet\", \"swav\"]\n", "higher_better_arrow = \"↑\"\n", "lower_better_arrow = \"↓\"\n", diff --git a/mindeye/scripts/utils_mindeye.py b/mindeye/scripts/utils_mindeye.py index 0f5e1b9..d549109 100644 --- a/mindeye/scripts/utils_mindeye.py +++ b/mindeye/scripts/utils_mindeye.py @@ -506,7 +506,7 @@ def create_design_matrix(images, starts, is_new_run, unique_images, n_runs, n_tr from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder from scipy import stats from tqdm import tqdm -def calculate_retrieval_metrics(all_clip_voxels, all_images): +def calculate_retrieval_metrics(all_clip_voxels, all_images, mst_image_names=None, cache_dir=None): print("Loading clip_img_embedder") try: print(clip_img_embedder) @@ -524,23 +524,36 @@ def calculate_retrieval_metrics(all_clip_voxels, all_images): all_fwd_acc = [] all_bwd_acc = [] - assert len(all_images) == len(all_clip_voxels) + assert len(all_images) == len(all_clip_voxels) print("The total pool of images and clip voxels to do retrieval on is: ", len(all_images)) all_percent_correct_fwds, all_percent_correct_bwds = [], [] with torch.cuda.amp.autocast(dtype=torch.float16): - print("Creating embeddings for images") - with torch.no_grad(): - all_emb = clip_img_embedder(all_images.to(torch.float16).to(device)).float() # CLIP-Image + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + cache_file = os.path.join(cache_dir, f"image_embeddings_{len(all_images)}.pt") + if os.path.exists(cache_file): + print(f"loading cached embeddings from {cache_file}") + all_emb = torch.load(cache_file).to(device) + else: + print("creating embeddings for images") + with torch.no_grad(): + all_emb = clip_img_embedder(all_images.to(torch.float16).to(device)).float() + torch.save(all_emb.cpu(), cache_file) + all_emb = all_emb.to(device) + else: + print("creating embeddings for images") + with torch.no_grad(): + all_emb = clip_img_embedder(all_images.to(torch.float16).to(device)).float() # CLIP-Image - all_emb_ = all_clip_voxels # CLIP-Brain + all_emb_ = all_clip_voxels.detach().clone() - print("Calculating retrieval metrics") + print("calculating retrieval metrics") # flatten if necessary all_emb = all_emb.reshape(len(all_emb),-1).to(device) - all_emb_ = all_emb_.reshape(len(all_emb_),-1).to(device) + all_emb_ = all_emb_.reshape(len(all_emb_),-1).to(device).contiguous() - # l2norm + # l2norm all_emb = nn.functional.normalize(all_emb,dim=-1) all_emb_ = nn.functional.normalize(all_emb_,dim=-1) @@ -554,7 +567,7 @@ def calculate_retrieval_metrics(all_clip_voxels, all_images): # else: # assert 
len(all_fwd_sim) == 50
 #             assert len(all_bwd_sim) == 50
-        
+
         all_percent_correct_fwds = topk(all_fwd_sim, all_labels, k=1).item()
         all_percent_correct_bwds = topk(all_bwd_sim, all_labels, k=1).item()
@@ -567,7 +580,86 @@ def calculate_retrieval_metrics(all_clip_voxels, all_images):
     print(f"overall fwd percent_correct: {all_fwd_acc[0]:.4f}")
     print(f"overall bwd percent_correct: {all_bwd_acc[0]:.4f}")
 
-    return all_fwd_acc[0], all_bwd_acc[0]
+    mst_2afc_score = None
+    if mst_image_names is not None:
+        import re
+        from collections import defaultdict
+
+        if len(mst_image_names) != len(set(mst_image_names)):
+            return all_fwd_acc[0], all_bwd_acc[0], None  # duplicate names make index-based pairing ambiguous; skip 2AFC
+
+        pair_groups = defaultdict(list)
+        for idx, name in enumerate(mst_image_names):
+            match1 = re.search(r'pair_(\d+)_(\d+)_\d+\.png', name)
+            match2 = re.search(r'pair_(\d+)_w_.*?[12]\.jpg', name)
+            if match1:
+                pair_id = f"{match1.group(1)}_{match1.group(2)}"
+            elif match2:
+                pair_id = f"w_{match2.group(1)}"
+            else:
+                continue
+            pair_groups[pair_id].append((name, idx))
+
+        pairs = []
+        pair_names = []
+        for pair_id, image_list in pair_groups.items():
+
+            seen_names = set()
+            unique_images = {}
+            for name, idx in image_list:
+                if name not in seen_names:
+                    unique_images[name] = idx
+                    seen_names.add(name)
+
+            if len(unique_images) == 2:
+                names = list(unique_images.keys())
+                idx1, idx2 = list(unique_images.values())
+                pairs.append([idx1, idx2])
+                pair_names.append((names[0], names[1]))
+
+        # check for incomplete pairs
+        incomplete_pairs = {k: v for k, v in pair_groups.items()
+                            if len({name for name, _ in v}) != 2}
+        if incomplete_pairs:
+            print(f"Warning: {len(incomplete_pairs)} incomplete MST pairs")
+
+        if len(pairs) > 0:
+            correct = 0
+            debug_results = []
+            for pair_idx, (pair, (name1, name2)) in enumerate(zip(pairs, pair_names)):
+
+                idx1, idx2 = pair[0], pair[1]
+
+                brain_1 = all_emb_[idx1:idx1+1]
+                brain_2 = all_emb_[idx2:idx2+1]
+                image_1 = all_emb[idx1:idx1+1]
+                image_2 = all_emb[idx2:idx2+1]
+
+                sim_1_to_1 = nn.functional.cosine_similarity(brain_1, image_1)
+                sim_1_to_2 = nn.functional.cosine_similarity(brain_1, image_2)
+                sim_2_to_2 = nn.functional.cosine_similarity(brain_2, image_2)
+                sim_2_to_1 = nn.functional.cosine_similarity(brain_2, image_1)
+
+                correct_1 = sim_1_to_1 > sim_1_to_2
+                correct_2 = sim_2_to_2 > sim_2_to_1
+
+                # print comparison details
+                print(f"\npair {pair_idx + 1}:")
+                print(f" image 1: {name1.split('/')[-1]}")
+                print(f" image 2: {name2.split('/')[-1]}")
+                print(f" img 1 brain embedding chose: {name1.split('/')[-1] if correct_1 else name2.split('/')[-1]} {'Correct' if correct_1 else 'Incorrect'}")
+                print(f" img 2 brain embedding chose: {name2.split('/')[-1] if correct_2 else name1.split('/')[-1]} {'Correct' if correct_2 else 'Incorrect'}")
+
+                if correct_1:
+                    correct += 1
+                if correct_2:
+                    correct += 1
+
+            mst_2afc_score = correct / (2 * len(pairs))
+            print(f"\ncorrect choices: {correct}/{2*len(pairs)}")
+            print(f"mst 2afc score: {mst_2afc_score:.4f} ({mst_2afc_score*100:.2f}%)")
+
+    return all_fwd_acc[0], all_bwd_acc[0], mst_2afc_score
 
 from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names