diff --git a/backends/qualcomm/_passes/convert_bmm_to_matmul.py b/backends/qualcomm/_passes/convert_bmm_to_matmul.py index 262a3b9ef0f..9de7908e6fa 100644 --- a/backends/qualcomm/_passes/convert_bmm_to_matmul.py +++ b/backends/qualcomm/_passes/convert_bmm_to_matmul.py @@ -25,6 +25,7 @@ class ConvertBmmToMatmul(ExportPass): bmm = exir_ops.edge.aten.bmm.default matmul = exir_ops.edge.aten.matmul.default patterns = [ + {view_copy: 3, bmm: 1}, {expand_copy: 2, view_copy: 3, bmm: 1}, {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1}, {bmm: 1}, diff --git a/backends/qualcomm/_passes/remove_redundancy.py b/backends/qualcomm/_passes/remove_redundancy.py index d75637beaf5..ccc1cb127ce 100644 --- a/backends/qualcomm/_passes/remove_redundancy.py +++ b/backends/qualcomm/_passes/remove_redundancy.py @@ -23,7 +23,9 @@ def __init__(self, quantization_capture=False): torch.ops.aten.alias.default: self._default_condition, exir_ops.edge.aten.alias.default: self._default_condition, exir_ops.edge.aten.alias_copy.default: self._default_condition, + exir_ops.edge.aten.expand_copy.default: self._same_shape_condition, exir_ops.edge.aten.lift_fresh_copy.default: self._default_condition, + exir_ops.edge.aten.repeat.default: self._same_shape_condition, # remove this target if '_skip_dim_order' is set to False exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition, # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True @@ -43,6 +45,9 @@ def __init__(self, quantization_capture=False): def _dim_order_op_condition(self, node): return node.meta["val"].dtype == node.args[0].meta["val"].dtype + def _same_shape_condition(self, node): + return node.args[0].meta["val"].shape == node.meta["val"].shape + def _to_copy_op_condition(self, node): return "memory_format" in node.kwargs diff --git a/backends/qualcomm/builders/op_copy.py b/backends/qualcomm/builders/op_copy.py index a1caa1c98a2..070ae81d775 100644 --- a/backends/qualcomm/builders/op_copy.py 
+++ b/backends/qualcomm/builders/op_copy.py @@ -7,12 +7,13 @@ import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager +import numpy as np import torch from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from .node_visitor import NodeVisitor from .node_visitor_manager import register_node_visitor -from .qnn_constants import OpReshape, QNN_OP_PACKAGE_NAME_QTI_AISW +from .qnn_constants import OpReshape, OpTile, QNN_OP_PACKAGE_NAME_QTI_AISW @register_node_visitor @@ -27,38 +28,85 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper], ) -> PyQnnManager.PyQnnOpWrapper: + # The aten copy supports broadcasting, and is therefore translated to + # Reshape and Tile. + # e.g., torch.ops.aten.copy.default(torch.rand(3,4,5), torch.rand(4,1)) input_node = self.get_node(node.args[1]) input_tensor = self.get_tensor(input_node, node) - copy_inp_tensor_wrapper = self.define_tensor( + output_tensor = self.get_tensor(node, node) + should_insert_tile = input_tensor.numel() != output_tensor.numel() + reshape_node_name = ( + node.name + "_unsqueeze" if should_insert_tile else node.name + ) + reshape_inp_tensor_wrapper = self.define_tensor( input_node, node, input_tensor, PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - - copy_input_tensors = [copy_inp_tensor_wrapper] - + reshape_input_tensors = [reshape_inp_tensor_wrapper] if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS): quant_attrs = quant_attrs.copy() # Because there is no output after convert_pt2e, the QCOM_QUANT_ATTRS of node is none node.meta[QCOM_QUANT_ATTRS] = quant_attrs - output_tensor = self.get_tensor(node, node) - output_tensor_wrapper = self.define_tensor( - node, - node, - output_tensor, - PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, + + reshape_tensor = input_tensor + while len(reshape_tensor.shape) < len(output_tensor.shape): + reshape_tensor = 
reshape_tensor.unsqueeze(0) + input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf( + input_node, node + ) + reshape_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=reshape_node_name, + tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=self.get_data_type(reshape_tensor, input_quant_configs), + quant_encoding=input_quant_encoding, + quant_configs=input_quant_configs, + dims=reshape_tensor.size(), + tensor=reshape_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, ) - copy_output_tensors = [output_tensor_wrapper] + reshape_output_tensors = [reshape_tensor_wrapper] - copy_op = PyQnnManager.PyQnnOpWrapper( - node.name, + reshape_op = PyQnnManager.PyQnnOpWrapper( + reshape_node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpReshape.op_name, ) - copy_op.AddInputTensors(copy_input_tensors) - copy_op.AddOutputTensors(copy_output_tensors) - - return copy_op + reshape_op.AddInputTensors(reshape_input_tensors) + reshape_op.AddOutputTensors(reshape_output_tensors) + op_wrapper_list = [reshape_op] + if should_insert_tile: + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + tile_output_tensors = [output_tensor_wrapper] + tile_op = PyQnnManager.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTile.op_name, + ) + multiples = [] + for i in range(len(reshape_tensor.shape)): + assert ( + output_tensor.shape[i] % reshape_tensor.shape[i] == 0 + ), f"Shape mismatch at dim {i}: {output_tensor.shape[i]} not divisible by {reshape_tensor.shape[i]}" + multiples.append(output_tensor.shape[i] // reshape_tensor.shape[i]) + tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + 1, + [len(reshape_tensor.shape)], + np.array(multiples, dtype=np.uint32), + True, + ) + tile_op.AddInputTensors(reshape_output_tensors) + 
tile_op.AddOutputTensors(tile_output_tensors) + op_wrapper_list.append(tile_op) + return op_wrapper_list diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index 84eb2368967..512bb244954 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -198,7 +198,7 @@ def define_node( # noqa: C901 nodes_to_wrappers=nodes_to_wrappers, ) tile_op = PyQnnManager.PyQnnOpWrapper( - node.name, + node.name + f"_indices_tile_{i}", QNN_OP_PACKAGE_NAME_QTI_AISW, OpTile.op_name, ) @@ -231,7 +231,7 @@ def define_node( # noqa: C901 nodes_to_wrappers=nodes_to_wrappers, ) reshape_op = PyQnnManager.PyQnnOpWrapper( - node.name, + node.name + f"_reshape_{i}", QNN_OP_PACKAGE_NAME_QTI_AISW, OpReshape.op_name, ) @@ -265,7 +265,7 @@ def define_node( # noqa: C901 nodes_to_wrappers=nodes_to_wrappers, ) tile_op = PyQnnManager.PyQnnOpWrapper( - node.name, + node.name + f"_tile_{i}", QNN_OP_PACKAGE_NAME_QTI_AISW, OpTile.op_name, ) @@ -309,7 +309,7 @@ def define_node( # noqa: C901 nodes_to_wrappers=nodes_to_wrappers, ) concat_op = PyQnnManager.PyQnnOpWrapper( - node.name, + node.name + "_concat", QNN_OP_PACKAGE_NAME_QTI_AISW, OpConcat.op_name, ) @@ -367,7 +367,7 @@ def define_node( # noqa: C901 nodes_to_wrappers=nodes_to_wrappers, ) value_reshape_op = PyQnnManager.PyQnnOpWrapper( - node.name, + node.name + "_value_reshape", QNN_OP_PACKAGE_NAME_QTI_AISW, OpReshape.op_name, ) @@ -404,7 +404,7 @@ def define_node( # noqa: C901 nodes_to_wrappers=nodes_to_wrappers, ) value_tile_op = PyQnnManager.PyQnnOpWrapper( - node.name, + node.name + "_value_tile", QNN_OP_PACKAGE_NAME_QTI_AISW, OpTile.op_name, ) @@ -461,7 +461,7 @@ def define_node( # noqa: C901 nodes_to_wrappers=nodes_to_wrappers, ) target_index_reshape_op = PyQnnManager.PyQnnOpWrapper( - node.name, + node.name + "_target_index_reshape", QNN_OP_PACKAGE_NAME_QTI_AISW, OpReshape.op_name, ) diff --git a/backends/qualcomm/quantizer/qconfig.py 
b/backends/qualcomm/quantizer/qconfig.py index 77fb989ba44..23627e0d97c 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -28,6 +28,9 @@ QuantizationSpec, ) +DEFAULT_EPS_8BIT = 0.0001 / 255 +DEFAULT_EPS_16BIT = 0.0001 / 65535 + @dataclass(eq=True) class QuantizationConfig: @@ -104,10 +107,12 @@ def _derive_bias_qparams_fn( def get_8a8w_qnn_ptq_config( - act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver + act_symmetric: bool = False, + act_observer=MovingAverageMinMaxObserver, + eps: float = None, ) -> QuantizationConfig: - # the smallest scale: 0.0001 / 255 - extra_args: Dict[str, Any] = {"eps": 2**-21} + # the smallest scale defaults to DEFAULT_EPS_8BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_8BIT} act_quantization_spec = QuantizationSpec( dtype=torch.uint8, @@ -146,10 +151,12 @@ def get_8a8w_qnn_ptq_config( def get_8a4w_qnn_ptq_config( - act_symmetric: bool = True, act_observer=MovingAverageMinMaxObserver + act_symmetric: bool = True, + act_observer=MovingAverageMinMaxObserver, + eps: float = None, ) -> QuantizationConfig: - # the smallest scale: 0.0001 / 255 - extra_args: Dict[str, Any] = {"eps": 2**-21} + # the smallest scale defaults to DEFAULT_EPS_8BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_8BIT} if act_symmetric: # If zero_point is 128, htp can do optimizations. @@ -203,10 +210,10 @@ def get_8a4w_qnn_ptq_config( # 4 bits quantization only supports specific ops. 
def get_16a4w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, + act_observer=MovingAverageMinMaxObserver, eps: float = None ) -> QuantizationConfig: - # the smallest scale: 0.0001 / 65535 - extra_args: Dict[str, Any] = {"eps": 2**-29} + # the smallest scale defaults to DEFAULT_EPS_16BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT} act_quantization_spec = QuantizationSpec( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, @@ -243,10 +250,10 @@ def get_16a4w_qnn_ptq_config( def get_16a8w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, + act_observer=MovingAverageMinMaxObserver, eps: float = None ) -> QuantizationConfig: - # the smallest scale: 0.0001 / 65535 - extra_args: Dict[str, Any] = {"eps": 2**-29} + # the smallest scale defaults to DEFAULT_EPS_16BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT} act_quantization_spec = QuantizationSpec( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, @@ -281,10 +288,10 @@ def get_16a8w_qnn_ptq_config( def get_16a8w_qnn_qat_config( - act_observer=MovingAverageMinMaxObserver, + act_observer=MovingAverageMinMaxObserver, eps: float = None ) -> QuantizationConfig: - # the smallest scale: 0.0001 / 65535 - extra_args: Dict[str, Any] = {"eps": 2**-29} + # the smallest scale defaults to DEFAULT_EPS_16BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT} act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, @@ -339,10 +346,10 @@ def get_16a8w_qnn_qat_config( def get_16a16w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, + act_observer=MovingAverageMinMaxObserver, eps: float = None ) -> QuantizationConfig: - # the smallest scale: 0.0001 / 65535 - extra_args: Dict[str, Any] = {"eps": 2**-29} + # the smallest scale defaults to DEFAULT_EPS_16BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT} act_quantization_spec = QuantizationSpec( 
dtype=torch.int32, quant_min=torch.iinfo(torch.uint16).min, @@ -385,10 +392,10 @@ def get_ptq_per_channel_quant_config( act_observer=MovingAverageMinMaxObserver, act_symmetric: bool = False, ch_axis: int = 0, + eps: float = None, ) -> QuantizationConfig: - # the smallest scale: 0.0001 / 65535 - extra_args: Dict[str, Any] = {"eps": 2**-29} - + # the smallest scale defaults to DEFAULT_EPS_16BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT} supported_act_types = { torch.uint8, torch.uint16, @@ -457,9 +464,10 @@ def get_ptq_per_block_quant_config( act_observer=MovingAverageMinMaxObserver, act_symmetric: bool = False, ch_axis: int = 0, + eps: float = None, ) -> QuantizationConfig: - # the smallest scale: 0.0001 / 65535 - extra_args: Dict[str, Any] = {"eps": 2**-29} + # the smallest scale defaults to DEFAULT_EPS_16BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT} quantization_config = get_ptq_per_channel_quant_config( act_dtype=act_dtype, weight_dtype=weight_dtype, diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 0d54b250bfd..d95176c446c 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -160,6 +160,7 @@ class ModuleQConfig: is_conv_per_channel: bool = False is_linear_per_channel: bool = False act_observer: Optional[UniformQuantizationObserverBase] = None + eps: Optional[float] = None def __post_init__(self): if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT: @@ -172,9 +173,9 @@ def __post_init__(self): per_block_quant_config_func, ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)] self.quant_config = ( - quant_config_func(act_observer=self.act_observer) + quant_config_func(act_observer=self.act_observer, eps=self.eps) if self.act_observer - else quant_config_func() + else quant_config_func(eps=self.eps) ) # Assume per_channel_quant/per_block_quant only happen on axis_0 or axis_1, increase the range if 
there's a need @@ -185,10 +186,12 @@ def __post_init__(self): self.per_channel_quant_config_list.append( ( per_channel_quant_config_func( - act_observer=self.act_observer, ch_axis=i + act_observer=self.act_observer, + ch_axis=i, + eps=self.eps, ) if self.act_observer - else per_channel_quant_config_func(ch_axis=i) + else per_channel_quant_config_func(ch_axis=i, eps=self.eps) ) ) @@ -409,6 +412,7 @@ def set_default_quant_config( is_conv_per_channel=False, is_linear_per_channel=False, act_observer=None, + eps=None, ) -> None: """ Set the default quant config for quantizer. @@ -419,6 +423,7 @@ def set_default_quant_config( is_conv_per_channel (bool, optional): Enables per-channel quantization for convolution operations. is_linear_per_channel (bool, optional): Enables per-channel quantization for linear (fully connected) operations. act_observer (Optional[UniformQuantizationObserverBase], optional): Custom observer for activation quantization. If not specified, the default observer is determined by `QUANT_CONFIG_DICT`. + eps (float): Minimum scale for quantization. """ self.default_quant_config = ModuleQConfig( @@ -427,6 +432,7 @@ def set_default_quant_config( is_conv_per_channel=is_conv_per_channel, is_linear_per_channel=is_linear_per_channel, act_observer=act_observer, + eps=eps, ) def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None: diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 17dc6bf4e19..e64d58a8971 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -447,7 +447,6 @@ void QnnManager::DestroyContext() { bool QnnManager::IsNodeSupportedByBackend( std::vector>& op_wrappers) { Qnn_ErrorHandle_t error = QNN_SUCCESS; - for (std::shared_ptr& op_wrapper : op_wrappers) { for (const auto& param : op_wrapper->GetParams()) { // unused? 
@@ -516,14 +515,33 @@ Error QnnManager::CompileDlc() { std::vector> graph_inputs, graph_outputs, tensors; + // Mapping memory address for the input and output of mutable buffer + std::unordered_map mutable_buffer_id_to_memory_map; for (int i = 0; i < graphInfo.numInputTensors; ++i) { auto tw = CreateTensorWrapper(graphInfo.inputTensors[i]); tw->UpdateQnnTensorMeta(graphInfo.inputTensors[i]); + + int mutable_buffer_id = ExtractMutableBufferNumber(tw->GetName()); + if (mutable_buffer_id != -1) { + // Delegate maintains the memory for mutable buffer + tw->AllocateDataBuffer(); + mutable_buffer_id_to_memory_map[mutable_buffer_id] = + tw->GetStaticTensorData(); + } graph_inputs.push_back(tw); } for (int i = 0; i < graphInfo.numOutputTensors; ++i) { auto tw = CreateTensorWrapper(graphInfo.outputTensors[i]); tw->UpdateQnnTensorMeta(graphInfo.outputTensors[i]); + int mutable_buffer_id = ExtractMutableBufferNumber(tw->GetName()); + if (mutable_buffer_id != -1 && + mutable_buffer_id_to_memory_map.find(mutable_buffer_id) != + mutable_buffer_id_to_memory_map.end()) { + // Fill the same memory for I/O of mutable buffer + tw->FillDataBuffer( + mutable_buffer_id_to_memory_map[mutable_buffer_id], + false /* copy_data */); + } graph_outputs.push_back(tw); } diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 2b73e0c6dfb..006827a2785 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -828,6 +828,17 @@ def forward(self, x): return self.conv_transpose(x) +class Copy(torch.nn.Module): + def __init__(self, x): + super().__init__() + self.x = x + + def forward(self, y): + # +1 to workaround that copy has no quant config + x = torch.ops.aten.copy.default(self.x, y + 1) + return x + 1 + + class Cos(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index c57dbbcc332..381524135a9 100644 --- 
a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1982,6 +1982,17 @@ def test_qnn_backend_conv2d_topk(self): sample_input = (torch.randn(1, 3, 32, 32),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_copy(self): + sample_inputs = [ + (torch.randn(3, 4, 5),), + (torch.randn(4, 1),), + ] + for i, sample_input in enumerate(sample_inputs): + with self.subTest(i=i): + self.lower_module_and_test_output( + Copy(torch.randn(3, 4, 5)), sample_input # noqa: F405, + ) + def test_qnn_backend_einsum_outer_product_relu(self): module = EinsumOuterProductRelu() # noqa: F405 x = torch.randn(5) @@ -4342,6 +4353,18 @@ def test_qnn_backend_conv2d_topk(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_copy(self): + sample_inputs = [ + (torch.randn(3, 4, 5),), + (torch.randn(4, 1),), + ] + for i, sample_input in enumerate(sample_inputs): + with self.subTest(i=i): + module = self.get_qdq_module( + Copy(torch.randn(3, 4, 5)), sample_input # noqa: F405 + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_einsum_outer_product_relu(self): module = EinsumOuterProductRelu() # noqa: F405 x = torch.randn(5) diff --git a/examples/qualcomm/oss_scripts/t5/t5_model.py b/examples/qualcomm/oss_scripts/t5/t5_model.py index 0593feaa8b8..2d6e71c41af 100644 --- a/examples/qualcomm/oss_scripts/t5/t5_model.py +++ b/examples/qualcomm/oss_scripts/t5/t5_model.py @@ -443,15 +443,26 @@ def __init__( device="cpu", dtype=torch.float32, ) + head_dim = getattr( + self.config, + "head_dim", + self.config.hidden_size // self.config.num_attention_heads, + ) + num_heads = getattr( + self.config, "num_key_value_heads", self.config.num_attention_heads + ) + self.static_cache.early_initialization( + batch_size, num_heads, head_dim, torch.float32, "cpu" + ) # Register cache buffers to make them exportable - for i in 
range(len(self.static_cache.key_cache)): + for i in range(len(self.static_cache.layers)): self.register_buffer( - f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False + f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False ) self.register_buffer( f"value_cache_{i}", - self.static_cache.value_cache[i], + self.static_cache.layers[i].values, persistent=False, ) diff --git a/examples/qualcomm/oss_scripts/whisper/whisper.py b/examples/qualcomm/oss_scripts/whisper/whisper.py index 7323a34247f..a4821e64209 100644 --- a/examples/qualcomm/oss_scripts/whisper/whisper.py +++ b/examples/qualcomm/oss_scripts/whisper/whisper.py @@ -225,6 +225,7 @@ def quantize( per_channel_linear=True, act_observer=MinMaxObserver, custom_annotations=custom_annotations, + eps=2**-20, ) with torch.no_grad(): diff --git a/examples/qualcomm/oss_scripts/whisper/whisper_model.py b/examples/qualcomm/oss_scripts/whisper/whisper_model.py index 22437c51044..81eae0fa59a 100644 --- a/examples/qualcomm/oss_scripts/whisper/whisper_model.py +++ b/examples/qualcomm/oss_scripts/whisper/whisper_model.py @@ -58,6 +58,22 @@ def __init__(self, whisper_model, max_cache_length, batch_size): device="cpu", dtype=torch.float32, ) + head_dim = getattr( + self.config, + "head_dim", + self.config.hidden_size // self.config.num_attention_heads, + ) + num_heads = getattr( + self.config, "num_key_value_heads", self.config.num_attention_heads + ) + self.static_cache.early_initialization( + batch_size, num_heads, head_dim, torch.float32, "cpu" + ) + for idx in range(len(self.static_cache.layers)): + self.register_buffer(f"key_cache_{idx}", self.static_cache.layers[idx].keys) + self.register_buffer( + f"value_cache_{idx}", self.static_cache.layers[idx].values + ) self.cache = EncoderDecoderCache(self.static_cache, DynamicCache()) def forward( diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index ca1d655c0db..0ffe23d14f0 100755 --- a/examples/qualcomm/utils.py +++ 
b/examples/qualcomm/utils.py @@ -305,6 +305,7 @@ def make_quantizer( act_observer=MovingAverageMinMaxObserver, is_qat=False, submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None, + eps=None, ): quantizer = QnnQuantizer() quantizer.add_custom_quant_annotations(custom_annotations) @@ -314,6 +315,7 @@ def make_quantizer( is_conv_per_channel=per_channel_conv, is_linear_per_channel=per_channel_linear, act_observer=act_observer, + eps=eps, ) submodule_qconfig_list = submodule_qconfig_list or [] quantizer.set_submodule_qconfig_list(submodule_qconfig_list)