diff --git a/src/transformers/models/lasr/configuration_lasr.py b/src/transformers/models/lasr/configuration_lasr.py
index 28051469be58..41f09ad9e00f 100644
--- a/src/transformers/models/lasr/configuration_lasr.py
+++ b/src/transformers/models/lasr/configuration_lasr.py
@@ -240,5 +240,9 @@ def from_encoder_config(cls, encoder_config: LasrEncoderConfig, **kwargs):
         return cls(encoder_config=encoder_config.to_dict(), **kwargs)
 
+    @property
+    def inputs_to_logits_ratio(self):
+        return self.encoder_config.subsampling_conv_stride**2
+
 
 __all__ = ["LasrEncoderConfig", "LasrCTCConfig"]
diff --git a/src/transformers/models/lasr/modular_lasr.py b/src/transformers/models/lasr/modular_lasr.py
index c02b2ae0f1c3..be4e1465370a 100644
--- a/src/transformers/models/lasr/modular_lasr.py
+++ b/src/transformers/models/lasr/modular_lasr.py
@@ -291,6 +291,10 @@ def __init__(
             **kwargs,
         )
 
+    @property
+    def inputs_to_logits_ratio(self):
+        return self.encoder_config.subsampling_conv_stride**2
+
 
 class LasrEncoderSubsampling(nn.Module):
     def __init__(self, config: LasrEncoderConfig):
diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index f09c529072f8..8e6f8b5cafcd 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -350,6 +350,20 @@ def _sanitize_parameters(
 
         return preprocess_params, forward_params, postprocess_params
 
+    @property
+    def _align_to(self):
+        """Number of raw input samples per output logit."""
+        # XXX: Careful, this attribute will not exist in the `seq2seq` setting.
+        # Currently chunking is not possible at this level for `seq2seq`, so
+        # it's ok.
+        align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
+        if self.model.config.model_type == "lasr_ctc":
+            # TODO: find a standard for this; not easy because the input length ->
+            # mel length mapping depends on the feature extractor. The model takes
+            # mel features as input, so we align according to the hop length.
+            align_to *= self.feature_extractor.hop_length
+        return align_to
+
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(inputs, str):
             if inputs.startswith("http://") or inputs.startswith("https://"):
@@ -444,10 +458,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(stride_length_s, (int, float)):
            stride_length_s = [stride_length_s, stride_length_s]
 
-        # XXX: Carefully, this variable will not exist in `seq2seq` setting.
-        # Currently chunking is not possible at this level for `seq2seq` so
-        # it's ok.
-        align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
+        align_to = self._align_to
         chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
         stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
         stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
@@ -567,7 +578,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
             # Send stride to `postprocess`.
             # it needs to be handled there where
             # the pieces are to be concatenated.
-            ratio = 1 / self.model.config.inputs_to_logits_ratio
+            ratio = 1 / self._align_to
             if isinstance(stride, tuple):
                 out["stride"] = rescale_stride([stride], ratio)[0]
             else:
@@ -650,11 +661,12 @@ def postprocess(
 
         if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
             chunks = []
+            align_to = self._align_to
             for item in offsets:
-                start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
+                start = item["start_offset"] * align_to
                 start /= self.feature_extractor.sampling_rate
-                stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
+                stop = item["end_offset"] * align_to
                 stop /= self.feature_extractor.sampling_rate
                 chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
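For context, here is a minimal sketch of the alignment arithmetic the change above wires together, from `inputs_to_logits_ratio` through `_align_to` to the chunk rounding in `preprocess` and the timestamp math in `postprocess`. All numeric values (`subsampling_conv_stride=2`, `hop_length=160`, the offset) are illustrative assumptions, not the actual Lasr defaults:

```python
# Sketch of the alignment arithmetic; all numbers are assumed for illustration.

sampling_rate = 16_000  # Hz (assumed)

# LasrCTCConfig.inputs_to_logits_ratio: the encoder subsamples the mel
# sequence by subsampling_conv_stride**2 frames per output logit.
subsampling_conv_stride = 2                          # hypothetical value
inputs_to_logits_ratio = subsampling_conv_stride**2  # 4 mel frames per logit

# _align_to: lasr_ctc consumes mel features, so the ratio is converted from
# mel frames to raw audio samples via the feature extractor's hop length.
hop_length = 160                                     # hypothetical value
align_to = inputs_to_logits_ratio * hop_length       # 640 samples per logit

# preprocess(): chunk/stride lengths are rounded to a multiple of align_to
# so that chunk boundaries land exactly on logit boundaries.
chunk_length_s = 30.0
chunk_len = int(round(chunk_length_s * sampling_rate / align_to) * align_to)
assert chunk_len % align_to == 0                     # 480_000 samples

# postprocess(): CTC token offsets (logit indices) map back to seconds.
start_offset = 25                                    # hypothetical logit index
start = start_offset * align_to / sampling_rate      # 25 * 640 / 16_000 = 1.0 s
print(chunk_len, start)
```

Rounding chunk and stride lengths to a multiple of `align_to` is what keeps the chunk stitching and the `postprocess` timestamp conversion exact: every chunk boundary falls on a whole number of logits, so strides expressed in samples rescale cleanly to logit counts via `ratio = 1 / self._align_to`.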