
Wrong result and no speedup with SemiSparseLinear from Torchao compared to torch.nn.Linear #1617

Open
lin-ht opened this issue Jan 24, 2025 · 3 comments

Comments

@lin-ht

lin-ht commented Jan 24, 2025

Hi, I tested a modified version of the tutorial's sample code to check the performance gain and accuracy of SemiSparseLinear. I found that SemiSparseLinear produces wrong results and is much slower than torch.nn.Linear on an H100 GPU. The test code is attached below. Is there anything I did incorrectly here?

```python
import torch
# from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
# SparseSemiStructuredTensor._FORCE_CUTLASS = False
# Modification: use SemiSparseLinear from torchao.
from torchao.sparsity.training import (
    SemiSparseLinear,
)

# Problem scale
in_f = 10240
out_f = 3072

# mask Linear weight to be 2:4 sparse
# Modification: torchao SemiSparseLinear will jointly sparsify A and A', so we construct the mask with 4x4 blocks.
mask = torch.Tensor([[0, 0, 1, 1],[0, 0, 1, 1],[1, 1, 0, 0],[1, 1, 0, 0]]).tile((out_f//4, in_f//4)).cuda().bool()
linear = torch.nn.Linear(in_f, out_f).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(out_f, in_f).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # Error on accelerate via SparseSemiStructuredTensor:
    # RuntimeError: sparse_semi_structured_mad_op : Supported only on GPUs with compute capability 8.x
    # linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
    # Modification: Use the SemiSparseLinear from torchao instead
    linear_sparse = SemiSparseLinear.from_dense(linear)
    # The sparsification is dynamic in the forward func of SemiSparseLinear, so the weight is identical to linear's.
    assert id(linear_sparse.weight)==id(linear.weight)

    sparse_output = linear_sparse(x)
    sparse_t = Timer(stmt="linear_sparse(x)",
                     globals={"linear_sparse": linear_sparse,
                              "x": x}).blocked_autorange().median * 1e3

    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")

    abs_diff = torch.abs(sparse_output - dense_output)
    max_error = torch.max(abs_diff)
    max_error_index = torch.argmax(abs_diff)
    max_error_coords = torch.unravel_index(max_error_index, sparse_output.shape)
    print(f"Max error: {max_error.item()} at index {max_error_coords}")

    # sparse and dense matmul are numerically equivalent
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
```
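The 4x4 tile used to build the mask is intended to be 2:4 sparse along both rows and columns, so that the weight and its transpose both satisfy the 2:4 pattern. Below is a minimal pure-Python sketch that checks this property. The helpers `is_2_4_sparse`, `transpose`, and `tile` are hypothetical and are not part of torch or torchao. `tile` mimics what `torch.Tensor.tile` does for a 2-D input, and a small 8x12 mask stands in for the real `(out_f, in_f)` one.

```python
def is_2_4_sparse(matrix):
    """True if every contiguous group of 4 entries in each row has at most 2 nonzeros."""
    for row in matrix:
        if len(row) % 4 != 0:
            raise ValueError("row length must be a multiple of 4")
        for start in range(0, len(row), 4):
            group = row[start:start + 4]
            if sum(1 for v in group if v != 0) > 2:
                return False
    return True


def transpose(matrix):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*matrix)]


def tile(block, reps_rows, reps_cols):
    """Repeat a 2-D block reps_rows times vertically and reps_cols times horizontally."""
    return [list(row) * reps_cols for _ in range(reps_rows) for row in block]


# The 4x4 tile from the issue's mask.
TILE = [[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0], [1, 1, 0, 0]]
mask = tile(TILE, 2, 3)  # small 8x12 stand-in for the (out_f, in_f) mask

print(is_2_4_sparse(mask), is_2_4_sparse(transpose(mask)))  # → True True

# A row-only 2:4 tile: valid for W, but each column group has 4 nonzeros.
ROW_ONLY = [[1, 1, 0, 0]] * 4
print(is_2_4_sparse(ROW_ONLY), is_2_4_sparse(transpose(ROW_ONLY)))  # → True False
```

The second pair shows why a purely row-wise 2:4 mask would not be enough when the transpose is also sparsified. This matches the reason the issue gives for building the mask from 4x4 blocks.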
@lin-ht lin-ht changed the title from "No speedup with SemiSparseLinear from Torchao compared to torch.nn.Linear" to "Wrong result and no speedup with SemiSparseLinear from Torchao compared to torch.nn.Linear" on Jan 24, 2025
@lin-ht
Author

lin-ht commented Jan 24, 2025

The output from my run:

```
/opt/venv/lib/python3.10/site-packages/torch/sparse/semi_structured.py:114: UserWarning: The PyTorch API of SparseSemiStructuredTensor is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.sparse module for further information about the project.
  warnings.warn(
Dense: 0.309ms Sparse: 0.927ms | Speedup: 0.334x
Max error: 5.94140625 at index (tensor(1169, device='cuda:0'), tensor(1892, device='cuda:0'))
Traceback (most recent call last):
  File "/mnt/localssd/halin/colligo/contrib/Mori/tests/test_wild.py", line 50, in <module>
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
AssertionError
```

@jerryzh168
Contributor

cc @jcaip

@jcaip
Contributor

jcaip commented Jan 31, 2025

cc @lin-ht SemiSparseLinear is unfortunately only supported on the Ampere architecture. It should throw an error or warning, though, rather than failing silently.
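A caller can also guard the swap explicitly. The sketch below only uses SemiSparseLinear when the current GPU reports compute capability 8.x. It assumes that the rule from the `to_sparse_semi_structured` error message quoted in the reproduction also applies to SemiSparseLinear. The H100 used in the issue is a Hopper GPU with compute capability 9.0, so the guard would keep the dense module there. `semi_sparse_supported` and `maybe_sparsify` are hypothetical helpers, not torchao APIs.

```python
import warnings


def semi_sparse_supported(capability):
    """Hypothetical helper: assume only compute capability 8.x is supported,
    per the error message quoted in the issue."""
    major, _minor = capability
    return major == 8


def maybe_sparsify(linear):
    """Return SemiSparseLinear.from_dense(linear) on supported GPUs, else the dense module."""
    # Imported lazily so semi_sparse_supported can be used without torch installed.
    import torch
    from torchao.sparsity.training import SemiSparseLinear

    if not torch.cuda.is_available():
        warnings.warn("CUDA is not available; keeping the dense torch.nn.Linear")
        return linear

    capability = torch.cuda.get_device_capability()  # (major, minor) of the current device
    if not semi_sparse_supported(capability):
        warnings.warn(
            f"SemiSparseLinear assumed unsupported on compute capability "
            f"{capability[0]}.{capability[1]}; keeping the dense torch.nn.Linear"
        )
        return linear
    return SemiSparseLinear.from_dense(linear)


print(semi_sparse_supported((8, 0)), semi_sparse_supported((9, 0)))  # → True False
```

In the reproduction, `linear_sparse = maybe_sparsify(linear)` would then fall back to the dense layer on the H100 and emit a warning, instead of silently returning wrong results.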
