Add cross-attention to output hypotheses #15229
base: main
Signed-off-by: Marco Gaido <[email protected]>
2de6160 to 21d5bb8
Signed-off-by: mgaido91 <[email protected]>
nithinraok left a comment
Thanks Marco, great work. I added some comments. Also, could you add an option, something like preserve_xattn_scores, so that xattn_scores are stored and returned only when it is enabled, for example through
decoding_cfg = MultiTaskDecodingConfig(
    strategy="beam",  # or "greedy"
    preserve_xattn_scores=True,
)
Leaving it disabled by default would save memory.
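To illustrate the suggestion above, here is a minimal sketch of an opt-in flag. `SketchDecodingConfig`, `SketchHypothesis` and `finalize` are hypothetical stand-ins written for this example and are not NeMo classes; only the option name `preserve_xattn_scores` and the `strategy` values come from the comment.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchDecodingConfig:
    """Hypothetical stand-in for a decoding config (not the NeMo class)."""

    strategy: str = "beam"  # or "greedy"
    preserve_xattn_scores: bool = False  # off by default to save memory


@dataclass
class SketchHypothesis:
    """Hypothetical hypothesis container used only in this sketch."""

    text: str
    xatt_scores: Optional[List] = None


def finalize(cfg: SketchDecodingConfig, text: str, xatt_scores: Optional[List]) -> SketchHypothesis:
    # Keep the scores only when the caller explicitly opted in.
    kept = xatt_scores if cfg.preserve_xattn_scores else None
    return SketchHypothesis(text=text, xatt_scores=kept)


default_hyp = finalize(SketchDecodingConfig(), "hello", [[0.1, 0.9]])
opt_in_hyp = finalize(SketchDecodingConfig(preserve_xattn_scores=True), "hello", [[0.1, 0.9]])
```

With the default config the hypothesis carries no scores; with the flag enabled the scores are passed through unchanged.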
last_frame (Optional): Index of the last decoding step hypothesis was updated including blank token prediction.
xatt_scores (Optional): List of cross-attention scores for each decoder layer. Each element of the list
Shouldn't the shape be List[BxHxT1xT2]? Also, it would be best to add that this is used with AED models.
)
if xatt_scores_list is not None:
    for layer in range(len(xatt_scores_list)):
        xatt_scores_list[layer] = torch.cat((xatt_scores_list[layer], new_xatt_scores_list[layer]), dim=2)
What about the case when new_xatt_scores_list is None? torch.cat would fail.
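One way to address the concern above is to guard both lists before concatenating. The sketch below is an assumption about the desired behavior, not the PR's code: when either list is None it simply keeps the existing value. Nested Python lists indexed as [batch][head][target_step][source_frame] stand in for B x H x T1 x T2 tensors so the example runs without PyTorch, and `cat_target_axis` and `append_step` are hypothetical helper names.

```python
from typing import List, Optional

# Nested lists stand in for B x H x T1 x T2 tensors (assumption for this sketch).
Scores = List[List[List[List[float]]]]


def cat_target_axis(old: Scores, new: Scores) -> Scores:
    """Concatenate along the target-step axis (dim=2 in the tensor layout)."""
    return [
        [old_head + new_head for old_head, new_head in zip(old_batch, new_batch)]
        for old_batch, new_batch in zip(old, new)
    ]


def append_step(
    xatt_scores_list: Optional[List[Scores]],
    new_xatt_scores_list: Optional[List[Scores]],
) -> Optional[List[Scores]]:
    # If either side is missing there is nothing to concatenate,
    # so return the existing value untouched (which may itself be None).
    if xatt_scores_list is None or new_xatt_scores_list is None:
        return xatt_scores_list
    return [cat_target_axis(o, n) for o, n in zip(xatt_scores_list, new_xatt_scores_list)]


one_layer_old = [[[[0.2, 0.8]]]]  # B=1, H=1, T1=1, T2=2
one_layer_new = [[[[0.6, 0.4]]]]
merged = append_step([one_layer_old], [one_layer_new])
```

After the call, the single layer in `merged` has two target steps, and passing None for either argument no longer raises.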
    pos=0,
    return_scores: bool = True,
):
    log_probs, decoder_mems_list, _ = super()._one_step_forward(
Could you update this here as well, and also include it in the returned tuple?
# select xatt scores corresponding to chosen hypotheses
if next_xatt_scores_list is not None:
    num_heads = xatt_scores_list[0].shape[1]
Please also check whether xatt_scores_list is None.
-     return prefixes, scores * len_penalties, tgt
+     return prefixes, scores * len_penalties, tgt, xatt_scores_list
  else:
      return tgt
We might also return xatt_scores_list here, since return_beam_scores is independent of whether the xattn scores are returned.
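The point above could be sketched as follows, with the attention scores appended in both branches. `pack_outputs` is a hypothetical helper, not a function from the PR, and `scores` is assumed to be already multiplied by the length penalties. Note that returning a tuple from the second branch changes its return type, so callers would need to unpack it.

```python
from typing import Any, Optional, Tuple


def pack_outputs(
    tgt: Any,
    prefixes: Any = None,
    scores: Any = None,
    xatt_scores_list: Optional[list] = None,
    return_beam_scores: bool = False,
) -> Tuple:
    # xatt_scores_list is always the last element, whichever branch is taken,
    # so it is available regardless of return_beam_scores.
    if return_beam_scores:
        return prefixes, scores, tgt, xatt_scores_list
    return tgt, xatt_scores_list
```

For example, `pack_outputs("tgt", xatt_scores_list=[1])` yields the target together with the scores even though beam scores were not requested.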
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do?
The PR adds the encoder-decoder cross-attention to the output hypotheses returned by ASR models.
Collection: ASR
Changelog
Usage
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
@nithinraok @andrusenkoau
Additional Information