Skip to content

Commit

Permalink
Add boiler plate code to Tensor subclass (#1663)
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva authored Feb 7, 2025
1 parent d1e6c03 commit cc6244c
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,27 @@ def _get_tensor_impl_constructor(
return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]


def _get_to_kwargs(self, *args, **kwargs):
# `torch._C._nn._parse_to` can't handle `layout` argument
for arg in args:
if isinstance(arg, torch.layout):
args.remove(arg)
if "layout" in kwargs:
kwargs.pop("layout")
# ignoring `non_blocking` and `memory_format` args since these are not
# very useful for most of the tensor subclasses
# if in the future there are use cases that need these, we'd recommend
# to override `_get_to_kwargs` and return these args
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
kwargs = {
"device": device,
"dtype": dtype,
}
return kwargs


class TorchAOBaseTensor(torch.Tensor):
"""A util tensor subclass that provides commonly used functions
new tensor subclass can inherit it to get all the utility functions
Expand Down Expand Up @@ -552,26 +573,24 @@ class PlainAQTTensorImpl(...):
__torch_function__ = classmethod(_dispatch__torch_function__)
register_layout = classmethod(_register_layout)
get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor)
_get_to_kwargs = _get_to_kwargs

def __tensor_flatten__(self):
raise NotImplementedError("Subclasses must implement __tensor_flatten__")

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")

def __repr__(self):
raise NotImplementedError("Subclasses must implement __repr__")

def _get_to_kwargs(self, *args, **kwargs):
# `torch._C._nn._parse_to` can't handle `layout` argument
for arg in args:
if isinstance(arg, torch.layout):
args.remove(arg)
if "layout" in kwargs:
kwargs.pop("layout")
# ignoring `non_blocking` and `memory_format` args since these are not
# very useful for most of the tensor subclasses
# if in the future there are use cases that need these, we'd recommend
# to override `_get_to_kwargs` and return these args
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
kwargs = {
"device": device,
"dtype": dtype,
}
return kwargs
def get_layout(self):
if not hasattr(self, "_layout"):
return None
return self._layout


def fill_defaults(args, n, defaults_tail):
Expand Down

0 comments on commit cc6244c

Please sign in to comment.