Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/convert_bmm_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ConvertBmmToMatmul(ExportPass):
bmm = exir_ops.edge.aten.bmm.default
matmul = exir_ops.edge.aten.matmul.default
patterns = [
{view_copy: 3, bmm: 1},
{expand_copy: 2, view_copy: 3, bmm: 1},
{expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
{bmm: 1},
Expand Down
5 changes: 5 additions & 0 deletions backends/qualcomm/_passes/remove_redundancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def __init__(self, quantization_capture=False):
torch.ops.aten.alias.default: self._default_condition,
exir_ops.edge.aten.alias.default: self._default_condition,
exir_ops.edge.aten.alias_copy.default: self._default_condition,
exir_ops.edge.aten.expand_copy.default: self._same_shape_condition,
exir_ops.edge.aten.lift_fresh_copy.default: self._default_condition,
exir_ops.edge.aten.repeat.default: self._same_shape_condition,
# remove this target if '_skip_dim_order' is set to False
exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
# remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
Expand All @@ -43,6 +45,9 @@ def __init__(self, quantization_capture=False):
def _dim_order_op_condition(self, node):
    """Report whether ``node`` has the same dtype as its first argument.

    The dtype stored under ``node.meta["val"]`` is compared with the dtype
    stored under ``node.args[0].meta["val"]``.

    Args:
        node: Graph node carrying a ``meta["val"]`` entry, whose first
            positional argument is another node with the same kind of entry.

    Returns:
        The result of the ``==`` comparison between the two dtypes.
    """
    result_dtype = node.meta["val"].dtype
    source_dtype = node.args[0].meta["val"].dtype
    return result_dtype == source_dtype

def _same_shape_condition(self, node):
    """Report whether ``node`` leaves the shape of its first argument unchanged.

    The shape stored under ``node.args[0].meta["val"]`` is compared with the
    shape stored under ``node.meta["val"]``.

    Args:
        node: Graph node carrying a ``meta["val"]`` entry, whose first
            positional argument is another node with the same kind of entry.

    Returns:
        The result of the ``==`` comparison between the two shapes.
    """
    input_shape = node.args[0].meta["val"].shape
    output_shape = node.meta["val"].shape
    return input_shape == output_shape

def _to_copy_op_condition(self, node):
    """Report whether ``node`` was called with a ``memory_format`` keyword.

    Args:
        node: Graph node whose ``kwargs`` mapping is inspected.

    Returns:
        True when ``"memory_format"`` is a key of ``node.kwargs``,
        False otherwise.
    """
    call_kwargs = node.kwargs
    return "memory_format" in call_kwargs

Expand Down
86 changes: 67 additions & 19 deletions backends/qualcomm/builders/op_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpReshape, QNN_OP_PACKAGE_NAME_QTI_AISW
from .qnn_constants import OpReshape, OpTile, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
Expand All @@ -27,38 +28,85 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
# The aten copy op supports broadcasting, so it is translated to
# Reshape and Tile.
# e.g., torch.ops.aten.copy.default(torch.rand(3,4,5), torch.rand(4,1))
input_node = self.get_node(node.args[1])
input_tensor = self.get_tensor(input_node, node)
copy_inp_tensor_wrapper = self.define_tensor(
output_tensor = self.get_tensor(node, node)
should_insert_tile = input_tensor.numel() != output_tensor.numel()
reshape_node_name = (
node.name + "_unsqueeze" if should_insert_tile else node.name
)
reshape_inp_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

copy_input_tensors = [copy_inp_tensor_wrapper]

reshape_input_tensors = [reshape_inp_tensor_wrapper]
if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
# Because there is no output after convert_pt2e, the QCOM_QUANT_ATTRS of node is none
node.meta[QCOM_QUANT_ATTRS] = quant_attrs
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,

reshape_tensor = input_tensor
while len(reshape_tensor.shape) < len(output_tensor.shape):
reshape_tensor = reshape_tensor.unsqueeze(0)
input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
input_node, node
)
reshape_tensor_wrapper = self.define_custom_tensor_wrapper(
node_name=reshape_node_name,
tensor_type=PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
dtype=self.get_data_type(reshape_tensor, input_quant_configs),
quant_encoding=input_quant_encoding,
quant_configs=input_quant_configs,
dims=reshape_tensor.size(),
tensor=reshape_tensor,
is_fake_tensor=True,
nodes_to_wrappers=nodes_to_wrappers,
)
copy_output_tensors = [output_tensor_wrapper]
reshape_output_tensors = [reshape_tensor_wrapper]

copy_op = PyQnnManager.PyQnnOpWrapper(
node.name,
reshape_op = PyQnnManager.PyQnnOpWrapper(
reshape_node_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpReshape.op_name,
)
copy_op.AddInputTensors(copy_input_tensors)
copy_op.AddOutputTensors(copy_output_tensors)

return copy_op
reshape_op.AddInputTensors(reshape_input_tensors)
reshape_op.AddOutputTensors(reshape_output_tensors)
op_wrapper_list = [reshape_op]
if should_insert_tile:
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
tile_output_tensors = [output_tensor_wrapper]
tile_op = PyQnnManager.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpTile.op_name,
)
multiples = []
for i in range(len(reshape_tensor.shape)):
assert (
output_tensor.shape[i] % reshape_tensor.shape[i] == 0
), f"Shape mismatch at dim {i}: {output_tensor.shape[i]} not divisible by {reshape_tensor.shape[i]}"
multiples.append(output_tensor.shape[i] // reshape_tensor.shape[i])
tile_op.AddTensorParam(
OpTile.param_multiples,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
1,
[len(reshape_tensor.shape)],
np.array(multiples, dtype=np.uint32),
True,
)
tile_op.AddInputTensors(reshape_output_tensors)
tile_op.AddOutputTensors(tile_output_tensors)
op_wrapper_list.append(tile_op)
return op_wrapper_list
14 changes: 7 additions & 7 deletions backends/qualcomm/builders/op_index_put.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def define_node( # noqa: C901
nodes_to_wrappers=nodes_to_wrappers,
)
tile_op = PyQnnManager.PyQnnOpWrapper(
node.name,
node.name + f"_indices_tile_{i}",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpTile.op_name,
)
Expand Down Expand Up @@ -231,7 +231,7 @@ def define_node( # noqa: C901
nodes_to_wrappers=nodes_to_wrappers,
)
reshape_op = PyQnnManager.PyQnnOpWrapper(
node.name,
node.name + f"_reshape_{i}",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpReshape.op_name,
)
Expand Down Expand Up @@ -265,7 +265,7 @@ def define_node( # noqa: C901
nodes_to_wrappers=nodes_to_wrappers,
)
tile_op = PyQnnManager.PyQnnOpWrapper(
node.name,
node.name + f"_tile_{i}",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpTile.op_name,
)
Expand Down Expand Up @@ -309,7 +309,7 @@ def define_node( # noqa: C901
nodes_to_wrappers=nodes_to_wrappers,
)
concat_op = PyQnnManager.PyQnnOpWrapper(
node.name,
node.name + "_concat",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpConcat.op_name,
)
Expand Down Expand Up @@ -367,7 +367,7 @@ def define_node( # noqa: C901
nodes_to_wrappers=nodes_to_wrappers,
)
value_reshape_op = PyQnnManager.PyQnnOpWrapper(
node.name,
node.name + "_value_reshape",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpReshape.op_name,
)
Expand Down Expand Up @@ -404,7 +404,7 @@ def define_node( # noqa: C901
nodes_to_wrappers=nodes_to_wrappers,
)
value_tile_op = PyQnnManager.PyQnnOpWrapper(
node.name,
node.name + "_value_tile",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpTile.op_name,
)
Expand Down Expand Up @@ -461,7 +461,7 @@ def define_node( # noqa: C901
nodes_to_wrappers=nodes_to_wrappers,
)
target_index_reshape_op = PyQnnManager.PyQnnOpWrapper(
node.name,
node.name + "_target_index_reshape",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpReshape.op_name,
)
Expand Down
54 changes: 31 additions & 23 deletions backends/qualcomm/quantizer/qconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
QuantizationSpec,
)

DEFAULT_EPS_8BIT = 0.0001 / 255
DEFAULT_EPS_16BIT = 0.0001 / 65535


@dataclass(eq=True)
class QuantizationConfig:
Expand Down Expand Up @@ -104,10 +107,12 @@ def _derive_bias_qparams_fn(


def get_8a8w_qnn_ptq_config(
act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
act_symmetric: bool = False,
act_observer=MovingAverageMinMaxObserver,
eps: float = None,
) -> QuantizationConfig:
# the smallest scale: 0.0001 / 255
extra_args: Dict[str, Any] = {"eps": 2**-21}
# the smallest scale defaults to DEFAULT_EPS_8BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_8BIT}

act_quantization_spec = QuantizationSpec(
dtype=torch.uint8,
Expand Down Expand Up @@ -146,10 +151,12 @@ def get_8a8w_qnn_ptq_config(


def get_8a4w_qnn_ptq_config(
act_symmetric: bool = True, act_observer=MovingAverageMinMaxObserver
act_symmetric: bool = True,
act_observer=MovingAverageMinMaxObserver,
eps: float = None,
) -> QuantizationConfig:
# the smallest scale: 0.0001 / 255
extra_args: Dict[str, Any] = {"eps": 2**-21}
# the smallest scale defaults to DEFAULT_EPS_8BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_8BIT}

if act_symmetric:
# If zero_point is 128, htp can do optimizations.
Expand Down Expand Up @@ -203,10 +210,10 @@ def get_8a4w_qnn_ptq_config(

# 4 bits quantization only supports specific ops.
def get_16a4w_qnn_ptq_config(
act_observer=MovingAverageMinMaxObserver,
act_observer=MovingAverageMinMaxObserver, eps: float = None
) -> QuantizationConfig:
# the smallest scale: 0.0001 / 65535
extra_args: Dict[str, Any] = {"eps": 2**-29}
# the smallest scale defaults to DEFAULT_EPS_16BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
act_quantization_spec = QuantizationSpec(
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
Expand Down Expand Up @@ -243,10 +250,10 @@ def get_16a4w_qnn_ptq_config(


def get_16a8w_qnn_ptq_config(
act_observer=MovingAverageMinMaxObserver,
act_observer=MovingAverageMinMaxObserver, eps: float = None
) -> QuantizationConfig:
# the smallest scale: 0.0001 / 65535
extra_args: Dict[str, Any] = {"eps": 2**-29}
# the smallest scale defaults to DEFAULT_EPS_16BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
act_quantization_spec = QuantizationSpec(
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
Expand Down Expand Up @@ -281,10 +288,10 @@ def get_16a8w_qnn_ptq_config(


def get_16a8w_qnn_qat_config(
act_observer=MovingAverageMinMaxObserver,
act_observer=MovingAverageMinMaxObserver, eps: float = None
) -> QuantizationConfig:
# the smallest scale: 0.0001 / 65535
extra_args: Dict[str, Any] = {"eps": 2**-29}
# the smallest scale defaults to DEFAULT_EPS_16BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
Expand Down Expand Up @@ -339,10 +346,10 @@ def get_16a8w_qnn_qat_config(


def get_16a16w_qnn_ptq_config(
act_observer=MovingAverageMinMaxObserver,
act_observer=MovingAverageMinMaxObserver, eps: float = None
) -> QuantizationConfig:
# the smallest scale: 0.0001 / 65535
extra_args: Dict[str, Any] = {"eps": 2**-29}
# the smallest scale defaults to DEFAULT_EPS_16BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
act_quantization_spec = QuantizationSpec(
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
Expand Down Expand Up @@ -385,10 +392,10 @@ def get_ptq_per_channel_quant_config(
act_observer=MovingAverageMinMaxObserver,
act_symmetric: bool = False,
ch_axis: int = 0,
eps: float = None,
) -> QuantizationConfig:
# the smallest scale: 0.0001 / 65535
extra_args: Dict[str, Any] = {"eps": 2**-29}

# the smallest scale defaults to DEFAULT_EPS_16BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
supported_act_types = {
torch.uint8,
torch.uint16,
Expand Down Expand Up @@ -457,9 +464,10 @@ def get_ptq_per_block_quant_config(
act_observer=MovingAverageMinMaxObserver,
act_symmetric: bool = False,
ch_axis: int = 0,
eps: float = None,
) -> QuantizationConfig:
# the smallest scale: 0.0001 / 65535
extra_args: Dict[str, Any] = {"eps": 2**-29}
# the smallest scale defaults to DEFAULT_EPS_16BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}
quantization_config = get_ptq_per_channel_quant_config(
act_dtype=act_dtype,
weight_dtype=weight_dtype,
Expand Down
14 changes: 10 additions & 4 deletions backends/qualcomm/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class ModuleQConfig:
is_conv_per_channel: bool = False
is_linear_per_channel: bool = False
act_observer: Optional[UniformQuantizationObserverBase] = None
eps: Optional[float] = None

def __post_init__(self):
if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
Expand All @@ -172,9 +173,9 @@ def __post_init__(self):
per_block_quant_config_func,
) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
self.quant_config = (
quant_config_func(act_observer=self.act_observer)
quant_config_func(act_observer=self.act_observer, eps=self.eps)
if self.act_observer
else quant_config_func()
else quant_config_func(eps=self.eps)
)

# Assume per_channel_quant/per_block_quant only happen on axis_0 or axis_1, increase the range if there's a need
Expand All @@ -185,10 +186,12 @@ def __post_init__(self):
self.per_channel_quant_config_list.append(
(
per_channel_quant_config_func(
act_observer=self.act_observer, ch_axis=i
act_observer=self.act_observer,
ch_axis=i,
eps=self.eps,
)
if self.act_observer
else per_channel_quant_config_func(ch_axis=i)
else per_channel_quant_config_func(ch_axis=i, eps=self.eps)
)
)

Expand Down Expand Up @@ -409,6 +412,7 @@ def set_default_quant_config(
is_conv_per_channel=False,
is_linear_per_channel=False,
act_observer=None,
eps=None,
) -> None:
"""
Set the default quant config for quantizer.
Expand All @@ -419,6 +423,7 @@ def set_default_quant_config(
is_conv_per_channel (bool, optional): Enables per-channel quantization for convolution operations.
is_linear_per_channel (bool, optional): Enables per-channel quantization for linear (fully connected) operations.
act_observer (Optional[UniformQuantizationObserverBase], optional): Custom observer for activation quantization. If not specified, the default observer is determined by `QUANT_CONFIG_DICT`.
eps (Optional[float], optional): Minimum scale for quantization. If not specified, the quant config function's own default is used.

"""
self.default_quant_config = ModuleQConfig(
Expand All @@ -427,6 +432,7 @@ def set_default_quant_config(
is_conv_per_channel=is_conv_per_channel,
is_linear_per_channel=is_linear_per_channel,
act_observer=act_observer,
eps=eps,
)

def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
Expand Down
Loading
Loading