diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 555caaf9..9b57e110 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -186,17 +186,19 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # Use the user-provided block sizes only for self attention; for cross attention they are overridden below.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
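
For reviewers, a minimal, self-contained sketch of the selection logic this hunk implements. `BlockSizes` here is a hypothetical stand-in for `splash_attention_kernel.BlockSizes`, and `select_block_sizes` is an illustrative helper, not a function in the file; the shape indexing mirrors the hunk (`shape[1]` is the dimension used for the self/cross check, `shape[2]` feeds the block-size clamps).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BlockSizes:
  """Stand-in for splash_attention_kernel.BlockSizes (illustration only)."""
  block_q: int
  block_kv_compute: int
  block_kv: int
  block_q_dkv: int
  block_kv_dkv: int
  block_kv_dkv_compute: int
  block_q_dq: Optional[int]
  block_kv_dq: Optional[int]
  use_fused_bwd_kernel: bool


def select_block_sizes(
    query_shape: tuple,
    key_shape: tuple,
    q_max_block_size: int,
    kv_max_block_size: int,
    flash_block_sizes: Optional[BlockSizes],
    attention_kernel: str,
) -> BlockSizes:
  """Hypothetical helper mirroring the hunk's block-size selection."""
  # Self attention: q and kv lengths match, so user overrides apply as-is.
  if flash_block_sizes and key_shape[1] == query_shape[1]:
    return flash_block_sizes
  # Cross attention (or no overrides): keep only the user's block_q, if any,
  # and re-derive the kv-side blocks from the actual shapes.
  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
  is_tokamax = attention_kernel == "tokamax_flash"
  return BlockSizes(
      block_q=block_size_q,
      block_kv_compute=min(kv_max_block_size, key_shape[2]),
      block_kv=min(kv_max_block_size, key_shape[2]),
      block_q_dkv=block_size_q,
      block_kv_dkv=min(kv_max_block_size, key_shape[2]),
      block_kv_dkv_compute=min(kv_max_block_size, query_shape[2]),
      # The fused backward kernel computes dq in the same pass, so the
      # standalone dq block sizes are disabled for tokamax_flash.
      block_q_dq=None if is_tokamax else block_size_q,
      block_kv_dq=None if is_tokamax else min(kv_max_block_size, query_shape[2]),
      use_fused_bwd_kernel=is_tokamax,
  )
```

Note that in the cross-attention path only `block_q` survives from the user's overrides; every kv-side block is re-derived from the runtime shapes, so an override tuned for self attention can no longer exceed the (typically much shorter) cross-attention kv length.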