Q dq layout (#1642)
* add q-dq layout for ET

* up

* up

* up

* up

* up

* up

* up
metascroy authored Feb 5, 2025
1 parent 4df4d03 commit bc1530b
Showing 4 changed files with 249 additions and 155 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/torchao_experimental_test.yml
@@ -35,8 +35,9 @@ jobs:
          conda activate venv
          pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104"
          pip install numpy
          pip install pytest
          USE_CPP=1 pip install .
      - name: Run tests
        run: |
          conda activate venv
          python torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py
          pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
61 changes: 61 additions & 0 deletions torchao/experimental/q_dq_layout.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import sys

import torch

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
    register_aqt_quantized_linear_dispatch,
)
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
from torchao.dtypes.utils import PlainLayout

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)


class QDQLayout(PlainLayout):
    """Plain layout whose linear dispatch keeps explicit quantize/dequantize
    (Q/DQ) ops in the graph instead of calling a fused kernel."""


@register_layout(QDQLayout)
class _Impl(PlainAQTTensorImpl):
    pass


def _linear_check(input_tensor, weight_tensor, bias):
    # Use this dispatch only when the weight is stored with QDQLayout.
    layout = weight_tensor.tensor_impl.get_layout()
    return isinstance(layout, QDQLayout)


def _linear_impl(input_tensor, weight_tensor, bias):
    # Dequantize both operands and fall back to a plain float linear; the
    # quant/dequant ops stay visible to torch.export instead of being fused.
    if isinstance(input_tensor, AffineQuantizedTensor):
        input_tensor = input_tensor.dequantize()
    if isinstance(weight_tensor, AffineQuantizedTensor):
        weight_tensor = weight_tensor.dequantize()
    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


register_aqt_quantized_linear_dispatch(
    _linear_check,
    _linear_impl,
)
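
For orientation, here is a minimal usage sketch of the new layout. It mirrors test_export_QDQLayout in the test file below; the layer shape and quantization settings are illustrative, not mandated by this commit.

import torch

from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

# Toy single-layer model; sizes are illustrative.
model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False))

# Quantize in place. With QDQLayout, linear does not call a fused kernel;
# it dequantizes activations and weights and runs a plain float linear,
# which leaves explicit q/dq ops in the exported graph.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(64),
        has_weight_zeros=False,
        layout=QDQLayout(),
    ),
)

out = model(torch.randn(1, 512))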
186 changes: 186 additions & 0 deletions torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import tempfile
import unittest

import torch
from torch.testing import FileCheck

from torchao.dtypes import PlainLayout
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import (
    int8_dynamic_activation_intx_weight,
)
from torchao.quantization.granularity import (
    PerGroup,
    PerRow,
)
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass

class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
    def test_accuracy(self):
        """
        Checks the accuracy of different layouts by comparing the results to PlainLayout()
        """
        m = 1
        n = 1071
        k = 4096
        activations = torch.randn(m, k)
        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

        reference_layout = PlainLayout()
        test_layouts = [
            PackedLinearInt8DynamicActivationIntxWeightLayout(),
            QDQLayout(),
        ]
        test_weight_dtypes = [
            torch.int1,
            torch.int2,
            torch.int3,
            torch.int4,
            torch.int5,
            torch.int6,
            torch.int7,
            torch.int8,
        ]
        test_has_weight_zeros = [True, False]
        test_granularities = [PerGroup(128), PerRow()]
        for layout, weight_dtype, has_weight_zeros, granularity in itertools.product(
            test_layouts, test_weight_dtypes, test_has_weight_zeros, test_granularities
        ):
            quantized_model = copy.deepcopy(model)
            quantize_(
                quantized_model,
                int8_dynamic_activation_intx_weight(
                    weight_dtype=weight_dtype,
                    granularity=granularity,
                    has_weight_zeros=has_weight_zeros,
                    layout=layout,
                ),
            )

            quantized_model_reference = copy.deepcopy(model)
            quantize_(
                quantized_model_reference,
                int8_dynamic_activation_intx_weight(
                    weight_dtype=weight_dtype,
                    granularity=granularity,
                    has_weight_zeros=has_weight_zeros,
                    layout=reference_layout,
                ),
            )

            with torch.no_grad():
                result = quantized_model(activations)
                expected_result = quantized_model_reference(activations)
            self.assertTrue(torch.allclose(result, expected_result, atol=1e-6))

    def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
        self,
    ):
        """
        Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with
        torch.export.export, torch.compile, and AOTI.
        """
        granularity = PerRow()
        m = 3
        k0 = 512
        k1 = 256
        k2 = 128
        k3 = 1024
        weight_dtype = torch.int4
        has_weight_zeros = True
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, k2, bias=False),
            torch.nn.Linear(k2, k3, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)

        quantize_(
            model,
            int8_dynamic_activation_intx_weight(
                weight_dtype=weight_dtype,
                granularity=granularity,
                has_weight_zeros=has_weight_zeros,
                layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
            ),
        )
        eager_results = model(activations)

        # Note: the deepcopy is taken before unwrapping, so `unwrapped_model`
        # still carries tensor subclasses and exercises the torch.compile path;
        # `model` itself is unwrapped for export and AOTI.
        unwrapped_model = copy.deepcopy(model)
        unwrap_tensor_subclass(model)

        # Export
        exported = torch.export.export(model, (activations,), strict=True)
        exported_results = exported.module()(activations)
        self.assertTrue(torch.allclose(eager_results, exported_results))

        # Compile
        compiled = torch.compile(unwrapped_model)
        with torch.no_grad():
            compiled_results = compiled(activations)
        self.assertTrue(torch.allclose(eager_results, compiled_results))

        # AOTI
        with tempfile.TemporaryDirectory() as tmpdirname:
            package_path = f"{tmpdirname}/model.pt2"
            torch._inductor.aoti_compile_and_package(
                exported, package_path=package_path
            )
            fn = torch._inductor.aoti_load_package(package_path)
            aoti_results = fn(activations)
            self.assertTrue(torch.allclose(eager_results, aoti_results))

    def test_export_QDQLayout(self):
        """
        Checks that models quantized with QDQLayout() export as expected
        """
        granularity = PerGroup(64)
        weight_dtype = torch.int4
        has_weight_zeros = False
        layers = [
            torch.nn.Linear(512, 256, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        activations = torch.randn(1, 512, dtype=torch.float32)

        quantize_(
            model,
            int8_dynamic_activation_intx_weight(
                weight_dtype=weight_dtype,
                granularity=granularity,
                has_weight_zeros=has_weight_zeros,
                layout=QDQLayout(),
            ),
        )
        eager_results = model(activations)

        unwrap_tensor_subclass(model)
        exported = torch.export.export(model, (activations,), strict=True)
        exported_results = exported.module()(activations)
        self.assertTrue(torch.allclose(eager_results, exported_results))

        expected_lines = [
            "torch.ops.quant.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int32, -128, 127, None, torch.float32, torch.int32)",
            "torch.ops.quant.quantize_affine.default(input_1, [1, 512], getitem, getitem_1, torch.int32, -128, 127)",
            "torch.ops.quant.dequantize_affine.default(quantize_affine, [1, 512], getitem, getitem_1, torch.int32, -128, 127)",
            "torch.ops.quant.dequantize_affine.default(p_fn_0_parametrizations_weight_original0, [1, 64], p_fn_0_parametrizations_weight_original1, None, torch.int32, -8, 7, 'NONE')",
            "torch.ops.aten.linear.default(dequantize_affine, dequantize_affine_1)",
        ]
        for line in expected_lines:
            FileCheck().check_count(line, 1, exactly=True).run(
                exported.graph_module.code
            )
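
The FileCheck strings above spell out the Q/DQ pattern the layout is named for: activations flow through choose_qparams_affine, quantize_affine, and dequantize_affine; the int4 weight gets a single dequantize_affine; and the matmul itself runs as a plain aten.linear. A quick way to eyeball that pattern by hand (assuming `exported` from the export call above):

# Print the exported graph; the quant/dequant ops feeding aten.linear are visible.
print(exported.graph_module.code)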