6 changes: 5 additions & 1 deletion src/transformers/models/lasr/configuration_lasr.py
@@ -185,7 +185,7 @@ class LasrCTCConfig(PreTrainedConfig):
encoder_config (`Union[dict, LasrEncoderConfig]`, *optional*):
The config object or dictionary of the encoder.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id. Also used as blank token id.
Padding token id. Also used as blank token id..
Example:
```python
>>> from transformers import LasrForCTC, LasrCTCConfig
@@ -240,5 +240,9 @@ def from_encoder_config(cls, encoder_config: LasrEncoderConfig, **kwargs):

return cls(encoder_config=encoder_config.to_dict(), **kwargs)

@property
def inputs_to_logits_ratio(self):
return self.encoder_config.subsampling_conv_stride**2


__all__ = ["LasrEncoderConfig", "LasrCTCConfig"]
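Reviewer note: a minimal sketch of what the new `inputs_to_logits_ratio` property computes, using stand-in classes and an assumed `subsampling_conv_stride` of 2 (neither the classes nor the value come from this PR). Two stride-2 subsampling steps reduce the frame axis by stride², so each logit frame would cover roughly four mel frames.

```python
# Illustration only: stand-in objects exposing the attribute the property reads.
# The stride value of 2 is an assumption for the example, not taken from the PR.
class _FakeEncoderConfig:
    subsampling_conv_stride = 2


class _FakeCTCConfig:
    encoder_config = _FakeEncoderConfig()

    @property
    def inputs_to_logits_ratio(self):
        # Mirrors the property added in configuration_lasr.py.
        return self.encoder_config.subsampling_conv_stride**2


cfg = _FakeCTCConfig()
print(cfg.inputs_to_logits_ratio)  # 4 -> about 4 mel frames per logit frame
```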
6 changes: 5 additions & 1 deletion src/transformers/models/lasr/modular_lasr.py
@@ -258,7 +258,7 @@ class LasrCTCConfig(ParakeetCTCConfig):
encoder_config (`Union[dict, LasrEncoderConfig]`, *optional*):
The config object or dictionary of the encoder.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id. Also used as blank token id.
Padding token id. Also used as blank token id..
Example:
```python
>>> from transformers import LasrForCTC, LasrCTCConfig
@@ -291,6 +291,10 @@ def __init__(
**kwargs,
)

@property
def inputs_to_logits_ratio(self):
return self.encoder_config.subsampling_conv_stride**2


class LasrEncoderSubsampling(nn.Module):
def __init__(self, config: LasrEncoderConfig):
15 changes: 12 additions & 3 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -196,6 +196,8 @@ def __init__(
# set the model type so we can check we have the right pre- and post-processing parameters
if model.config.model_type == "whisper":
self.type = "seq2seq_whisper"
elif model.config.model_type == "lasr_ctc":
self.type = "lasr_ctc"
elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
self.type = "seq2seq"
elif (
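Reviewer note: a hedged usage sketch of how the new branch is selected; the checkpoint name is a placeholder, not a real model id. The pipeline only inspects `model.config.model_type`, and anything reporting `"lasr_ctc"` gets routed to the LASR-specific chunk alignment added in the next hunk.

```python
from transformers import pipeline

# Placeholder checkpoint name; any model whose config has model_type == "lasr_ctc"
# would take the new branch above and set the pipeline's type to "lasr_ctc".
asr = pipeline(
    "automatic-speech-recognition",
    model="org/lasr-ctc-checkpoint",  # hypothetical
    chunk_length_s=30,
    stride_length_s=5,
)
print(asr("speech.wav")["text"])
```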
@@ -448,9 +450,16 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
# Currently chunking is not possible at this level for `seq2seq` so
# it's ok.
align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)

if self.type == "lasr_ctc":
# TODO: find a standard for this; it is not straightforward because the input length -> mel length
# mapping depends on the feature extractor's specific behavior.
# The model takes mel features as input, so we align according to the hop length.
align_to *= self.feature_extractor.hop_length

chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to))
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to))
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to))

if chunk_len < stride_left + stride_right:
raise ValueError("Chunk length must be superior to stride length")
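Reviewer note: a worked sketch of the updated chunking arithmetic for the `lasr_ctc` branch, using assumed values (16 kHz audio, `hop_length=160`, `inputs_to_logits_ratio=4`, 30 s chunks with 5 s strides) that do not come from this PR.

```python
# Assumed values for illustration; the real ones come from the model config and
# the feature extractor of whatever LASR checkpoint is loaded.
sampling_rate = 16_000
hop_length = 160            # mel hop size of the feature extractor (assumption)
inputs_to_logits_ratio = 4  # subsampling_conv_stride**2 with stride 2 (assumption)

chunk_length_s = 30.0
stride_length_s = (5.0, 5.0)

# Mirrors the updated preprocess(): for lasr_ctc the alignment unit becomes
# samples per logit frame, i.e. inputs_to_logits_ratio * hop_length.
align_to = inputs_to_logits_ratio * hop_length  # 640 samples per logit frame

chunk_len = int(round(chunk_length_s * sampling_rate / align_to))         # 750
stride_left = int(round(stride_length_s[0] * sampling_rate / align_to))   # 125
stride_right = int(round(stride_length_s[1] * sampling_rate / align_to))  # 125

assert chunk_len >= stride_left + stride_right  # same check as the pipeline
print(chunk_len, stride_left, stride_right)
```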