add a deprecation warning for float8 delayed and static scaling (#1681)

Update

[ghstack-poisoned]
vkuzo authored Feb 7, 2025
1 parent cc6244c commit e7aa4ca
Showing 2 changed files with 12 additions and 0 deletions.
torchao/float8/README.md (2 additions, 0 deletions)

@@ -65,6 +65,8 @@ for _ in range(10):
 
 ## float8 linear with delayed scaling
 
+:warning: <em>We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details.</em>
+
 This is theoretically the most performant recipe as it minimizes memory reads.
 
 ```python
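For context on what is being deprecated: delayed scaling is opted into through the cast configs on `Float8LinearConfig`. A minimal sketch of an affected setup, assuming the `CastConfig`, `Float8LinearConfig`, `ScalingType`, and `convert_to_float8_training` exports of `torchao.float8` at the time of this commit:

```python
import torch.nn as nn

# Assumed torchao.float8 exports; an illustrative sketch, not a pinned API.
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

# Delayed scaling for input, weight, and grad_output: the recipe the
# README note above now flags for deprecation.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)

# Swap nn.Linear modules for float8 linears using this config.
m = nn.Sequential(nn.Linear(64, 64, bias=False))
convert_to_float8_training(m, config=config)
```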
torchao/float8/config.py (10 additions, 0 deletions)

@@ -304,6 +304,16 @@ def __post_init__(self):
                 "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
             )
 
+        # Future deprecation warning for delayed scaling
+        if (
+            self.cast_config_input.scaling_type != ScalingType.DYNAMIC
+            or self.cast_config_weight.scaling_type != ScalingType.DYNAMIC
+            or self.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
+        ):
+            logger.warning(
+                "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details."
+            )
+
 
 # Pre-made recipes for common configurations
 # TODO(future PR): go through a round of design on this, and eventually expose
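Given the `or` chain above, a single non-DYNAMIC scaling type is enough to trigger the notice when the config is constructed. A hypothetical reproduction, under the same assumed exports as the sketch above:

```python
import logging

from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType

# Make sure torchao's logger output is visible.
logging.basicConfig(level=logging.WARNING)

# Only the weight cast is delayed, but the check in __post_init__
# warns as soon as any of the three casts is not DYNAMIC.
config = Float8LinearConfig(
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
)
# Expected log line (from the diff):
#   Note: delayed and static scaling will be deprecated in a future
#   release of torchao. Please see
#   https://github.com/pytorch/ao/issues/1680 for more details.
```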
