Commit

Merge remote-tracking branch 'origin/main' into lint_fixes
jainapurva committed Feb 7, 2025
2 parents 2a7d181 + e7aa4ca · commit 0c5c467
Showing 5 changed files with 55 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build_wheels_aarch64_linux.yml
@@ -29,7 +29,8 @@ jobs:
test-infra-repository: pytorch/test-infra
test-infra-ref: main
with-cuda: disable

# please note: excluding 3.13t for aarch64 builds for now
python-versions: '["3.9", "3.10", "3.11", "3.12", "3.13"]'
build:
needs: generate-matrix
permissions:
4 changes: 3 additions & 1 deletion .github/workflows/build_wheels_linux.yml
@@ -30,6 +30,8 @@ jobs:
with-cuda: enable
with-rocm: enable
with-xpu: enable
# please note: excluding 3.13t for aarch64 builds for now
python-versions: '["3.9", "3.10", "3.11", "3.12", "3.13"]'

build:
needs: generate-matrix
@@ -89,5 +91,5 @@ jobs:
Error Information:
${{ needs.generate-matrix.result == 'failure' && 'Matrix generation failed' || '' }}
${{ needs.build.result == 'failure' && 'Build job failed' || '' }}
This is an automated notification. Please check the GitHub Actions page for more details about the failure.
2 changes: 2 additions & 0 deletions torchao/float8/README.md
@@ -65,6 +65,8 @@ for _ in range(10):

## float8 linear with delayed scaling

:warning: <em>We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details.</em>

This is theoretically the most performant recipe as it minimizes memory reads.

```python
10 changes: 10 additions & 0 deletions torchao/float8/config.py
@@ -304,6 +304,16 @@ def __post_init__(self):
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
)

# Future deprecation warning for delayed scaling
if (
self.cast_config_input.scaling_type != ScalingType.DYNAMIC
or self.cast_config_weight.scaling_type != ScalingType.DYNAMIC
or self.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
):
logger.warning(
"Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details."
)


# Pre-made recipes for common configurations
# TODO(future PR): go through a round of design on this, and eventually expose
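For reference, a minimal sketch of a config that would trigger the new warning, assuming the `Float8LinearConfig` / `CastConfig` / `ScalingType` API already present in `torchao.float8.config` (the specific field values here are illustrative, not part of this diff):

```python
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType

# Using a non-dynamic scaling type for any of input/weight/grad_output
# should now log the deprecation notice added in __post_init__ above.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
)
# expected log: "Note: delayed and static scaling will be deprecated in a
# future release of torchao. Please see
# https://github.com/pytorch/ao/issues/1680 for more details."
```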
57 changes: 38 additions & 19 deletions torchao/utils.py
@@ -512,6 +512,27 @@ def _get_tensor_impl_constructor(
return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]


def _get_to_kwargs(self, *args, **kwargs):
# `torch._C._nn._parse_to` can't handle `layout` argument
for arg in args:
if isinstance(arg, torch.layout):
args.remove(arg)
if "layout" in kwargs:
kwargs.pop("layout")
# ignoring `non_blocking` and `memory_format` args since these are not
# very useful for most of the tensor subclasses
# if in the future there are use cases that need these, we'd recommend
# to override `_get_to_kwargs` and return these args
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
kwargs = {
"device": device,
"dtype": dtype,
}
return kwargs


class TorchAOBaseTensor(torch.Tensor):
"""A util tensor subclass that provides commonly used functions
new tensor subclass can inherit it to get all the utility functions
@@ -552,26 +573,24 @@ class PlainAQTTensorImpl(...):
__torch_function__ = classmethod(_dispatch__torch_function__)
register_layout = classmethod(_register_layout)
get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor)
_get_to_kwargs = _get_to_kwargs

def __tensor_flatten__(self):
raise NotImplementedError("Subclasses must implement __tensor_flatten__")

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")

def __repr__(self):
raise NotImplementedError("Subclasses must implement __repr__")

def _get_to_kwargs(self, *args, **kwargs):
# `torch._C._nn._parse_to` can't handle `layout` argument
for arg in args:
if isinstance(arg, torch.layout):
args.remove(arg)
if "layout" in kwargs:
kwargs.pop("layout")
# ignoring `non_blocking` and `memory_format` args since these are not
# very useful for most of the tensor subclasses
# if in the future there are use cases that need these, we'd recommend
# to override `_get_to_kwargs` and return these args
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
kwargs = {
"device": device,
"dtype": dtype,
}
return kwargs
def get_layout(self):
if not hasattr(self, "_layout"):
return None
return self._layout


def fill_defaults(args, n, defaults_tail):
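As a usage note (not part of this diff), here is a hedged sketch of what the now module-level `_get_to_kwargs` helper does when a tensor subclass forwards its `to()` arguments to it; the plain tensor below just stands in for a `TorchAOBaseTensor` subclass instance:

```python
import torch

from torchao.utils import _get_to_kwargs

# _get_to_kwargs normalizes Tensor.to(...) arguments into a plain
# {"device": ..., "dtype": ...} dict, dropping `layout`, ignoring
# `non_blocking` / `memory_format`, and falling back to the tensor's
# current device/dtype when an argument is not given.
x = torch.randn(4, 4)  # stands in for a tensor subclass instance
print(_get_to_kwargs(x, "cpu", torch.bfloat16))
# -> {'device': device(type='cpu'), 'dtype': torch.bfloat16}
print(_get_to_kwargs(x, dtype=torch.float16))
# -> {'device': device(type='cpu'), 'dtype': torch.float16}
```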
