Skip to content

Conversation

@mgaido91
Copy link

Important

The Update branch button must only be pressed in very rare occassions.
An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do ?

The PR adds the encoder-decoder cross-attention to the output hypotheses returned by ASR models.

Collection: ASR

Changelog

  • Returns the cross-attention scores in the output of the greedy generator
  • Returns the cross-attention scores in the output of the beam search generator

Usage

  • You can potentially add a usage example below
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.models.aed_multitask_models import MultiTaskTranscriptionConfig
model = ASRModel.from_pretrained(model_name="nvidia/canary-1b-v2")
config = MultiTaskTranscriptionConfig(
    batch_size=4,
    return_hypotheses=True,
    num_workers=0,
    verbose=False,
    prompt={'source_lang': 'en', 'target_lang': 'en'},
    enable_chunking=False
)
output = model.transcribe("/Users/mgaido/Downloads/vp-test/aa.wav", override_config=config)
assert output[0].xatt_scores is not None

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

@nithinraok @andrusenkoau

Additional Information

@github-actions github-actions bot added the ASR label Dec 24, 2025
@mgaido91 mgaido91 force-pushed the add_attention_to_output_hypo branch from 2de6160 to 21d5bb8 Compare December 24, 2025 14:09
Copy link
Member

@nithinraok nithinraok left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Marco. great work. Added comments. Also,
Could you add an option something like preserve_xattn_scores, so when enabled through

decoding_cfg = MultiTaskDecodingConfig(
    strategy="beam",  # or "greedy"
    preserve_xattn_scores=True,
)

only store and return xattn_scores (to save memory by default)

last_frame (Optional): Index of the last decoding step hypothesis was updated including blank token prediction.
xatt_scores (Optional): List of cross-attention scores for each decoder layer. Each element of the list
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn;t shape is List[BxHxT1xT2] . Also best to add: this is used with AED models

)
if xatt_scores_list is not None:
for layer in range(len(xatt_scores_list)):
xatt_scores_list[layer] = torch.cat((xatt_scores_list[layer], new_xatt_scores_list[layer]), dim=2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about condition when new_xattn_scores_list is None? cat would fail

pos=0,
return_scores: bool = True,
):
log_probs, decoder_mems_list, _ = super()._one_step_forward(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you update here as well and also include in returns tuple


# select xatt scores corresponding to chosen hypotheses
if next_xatt_scores_list is not None:
num_heads = xatt_scores_list[0].shape[1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check for xatt_scores_list if None

return prefixes, scores * len_penalties, tgt
return prefixes, scores * len_penalties, tgt, xatt_scores_list
else:
return tgt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might also return xatt_scores_list here as return_beam_scores is independent of return xattn_scores.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants