diff --git a/torchao/utils.py b/torchao/utils.py
index f67463f9f..13b59c2e8 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -512,6 +512,26 @@ def _get_tensor_impl_constructor(
     return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]


+def _get_to_kwargs(self, *args, **kwargs):
+    # `torch._C._nn._parse_to` can't handle the `layout` argument, so strip
+    # it first (`args` is a tuple, so filter it rather than mutating in place)
+    args = [arg for arg in args if not isinstance(arg, torch.layout)]
+    if "layout" in kwargs:
+        kwargs.pop("layout")
+    # `non_blocking` and `memory_format` are ignored since they are not
+    # very useful for most tensor subclasses.
+    # If a future use case needs them, override `_get_to_kwargs` and
+    # return them as well.
+    device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
+    device = self.device if device is None else device
+    dtype = self.dtype if dtype is None else dtype
+    kwargs = {
+        "device": device,
+        "dtype": dtype,
+    }
+    return kwargs
+
+
 class TorchAOBaseTensor(torch.Tensor):
     """A util tensor subclass that provides commonly used functions
        new tensor subclass can inherit it to get all the utility functions
@@ -552,26 +572,24 @@ class PlainAQTTensorImpl(...):
     __torch_function__ = classmethod(_dispatch__torch_function__)
     register_layout = classmethod(_register_layout)
     get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor)
+    _get_to_kwargs = _get_to_kwargs
+
+    def __tensor_flatten__(self):
+        raise NotImplementedError("Subclasses must implement __tensor_flatten__")
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        raise NotImplementedError("Subclasses must implement __tensor_unflatten__")
+
+    def __repr__(self):
+        raise NotImplementedError("Subclasses must implement __repr__")

-    def _get_to_kwargs(self, *args, **kwargs):
-        # `torch._C._nn._parse_to` can't handle `layout` argument
-        for arg in args:
-            if isinstance(arg, torch.layout):
-                args.remove(arg)
-        if "layout" in kwargs:
-            kwargs.pop("layout")
-        # ignoring `non_blocking` and `memory_format` args since these are not
-        # very useful for most of the tensor subclasses
-        # if in the future there are use cases that need these, we'd recommend
-        # to override `_get_to_kwargs` and return these args
-        device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
-        device = self.device if device is None else device
-        dtype = self.dtype if dtype is None else dtype
-        kwargs = {
-            "device": device,
-            "dtype": dtype,
-        }
-        return kwargs
+    def get_layout(self):
+        if not hasattr(self, "_layout"):
+            return None
+        return self._layout


 def fill_defaults(args, n, defaults_tail):
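Taken together, the new base-class contract looks like this from a subclass's point of view. The sketch below is illustrative and not part of the diff: `MyQuantizedTensor` and its `packed_data` field are hypothetical names, and the `to()` override shows one way a subclass might consume the shared `_get_to_kwargs` helper while implementing the now-required `__tensor_flatten__` / `__tensor_unflatten__` / `__repr__` methods.

```python
import torch
from torchao.utils import TorchAOBaseTensor


class MyQuantizedTensor(TorchAOBaseTensor):
    """Hypothetical subclass wrapping a single packed data tensor."""

    @staticmethod
    def __new__(cls, packed_data, shape, dtype):
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, device=packed_data.device, dtype=dtype, requires_grad=False
        )

    def __init__(self, packed_data, shape, dtype):
        self.packed_data = packed_data

    # The base class now raises NotImplementedError for these three, so a
    # subclass must define them (torch.compile / FakeTensor tracing needs
    # flatten/unflatten to see through the subclass).
    def __tensor_flatten__(self):
        # inner tensor attribute names + non-tensor context
        return ["packed_data"], [self.dtype]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        (dtype,) = tensor_attributes
        return cls(tensor_data_dict["packed_data"], outer_size, dtype)

    def __repr__(self):
        return f"{self.__class__.__name__}(shape={tuple(self.shape)}, dtype={self.dtype})"

    def to(self, *args, **kwargs):
        # `_get_to_kwargs` strips `layout`, then fills in the current
        # device/dtype for anything the caller did not specify.
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
            self.packed_data.to(kwargs["device"]), self.shape, kwargs["dtype"]
        )
```

With this in place, `MyQuantizedTensor(data, data.shape, torch.float16).to("cpu")` keeps the float16 outer dtype, because `_get_to_kwargs` falls back to `self.dtype` when the caller passes only a device.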