Commit

cr
linoybu committed Feb 3, 2025
1 parent a2357af commit 4b2fd74
Showing 2 changed files with 8 additions and 6 deletions.
@@ -7,8 +7,8 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale, is_gaudi2)
+    apply_fp8_linear, cutlass_fp8_supported, get_gaudi2_scale_factor, is_gaudi2, normalize_e4m3fn_to_e4m3fnuz,
+    requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -84,8 +84,7 @@ def process_weights_after_loading(self, layer) -> None:
         if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
             input_scale = layer.input_scale.max()
             if is_gaudi2():
-                input_scale = input_scale * (torch.finfo(torch.float8_e4m3fn).max /
-                                             torch.finfo(torch.float8_e4m3fnuz).max)
+                input_scale = input_scale * get_gaudi2_scale_factor()
             layer.input_scale = Parameter(input_scale,
                                           requires_grad=False)
         else:
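For context, the factor that both hunks apply on Gaudi2 reflects the different dynamic ranges of the two fp8 dtypes: torch.float8_e4m3fn has a maximum representable value of 448, while torch.float8_e4m3fnuz tops out at 240, so a scale taken from an e4m3fn checkpoint is enlarged by 448 / 240 ≈ 1.867 before use. A minimal standalone sketch of that arithmetic (plain PyTorch, outside vLLM; the tensor value is illustrative):

import torch

# Maxima of the two fp8 formats: 448.0 for e4m3fn, 240.0 for e4m3fnuz.
scale_factor = (torch.finfo(torch.float8_e4m3fn).max /
                torch.finfo(torch.float8_e4m3fnuz).max)
print(scale_factor)  # 1.8666...

# Illustrative input scale loaded from an e4m3fn checkpoint.
input_scale = torch.tensor(0.05)
print(input_scale * scale_factor)  # tensor(0.0933)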
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -77,14 +77,17 @@ def is_gaudi2():
     return current_platform.is_hpu() and htexp._get_device_type(
     ) == htexp.synDeviceType.synDeviceGaudi2
 
+def get_gaudi2_scale_factor():
+    return (torch.finfo(torch.float8_e4m3fn).max /
+            torch.finfo(torch.float8_e4m3fnuz).max)
+
 def requantize_with_max_scale(
         weight: torch.Tensor, weight_scale: torch.Tensor,
         logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
     max_w_scale = weight_scale.max()
     if is_gaudi2():
-        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
-                                     torch.finfo(torch.float8_e4m3fnuz).max)
+        max_w_scale = max_w_scale * get_gaudi2_scale_factor()
     # QKV / MLP is fused in the on disk checkpoint if any of the
     # weight scales are still set to the default since we initialize
     # N weight scales for N shards but we only load 1 weight scale
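Taken together, the change hoists the duplicated ratio into a single helper and calls it from both call sites. A rough standalone sketch of the resulting flow (not the vLLM code itself; the on_gaudi2 flag stands in for the is_gaudi2() platform check, and the sample scales are made up):

import torch

def get_gaudi2_scale_factor() -> float:
    # Same ratio as the new helper in w8a8_utils.py: e4m3fn max / e4m3fnuz max.
    return (torch.finfo(torch.float8_e4m3fn).max /
            torch.finfo(torch.float8_e4m3fnuz).max)

def requantize_max_scale(weight_scale: torch.Tensor,
                         on_gaudi2: bool) -> torch.Tensor:
    # Collapse per-shard scales to one max scale, then widen it for Gaudi2's
    # narrower fp8 range, mirroring requantize_with_max_scale in the diff.
    max_w_scale = weight_scale.max()
    if on_gaudi2:
        max_w_scale = max_w_scale * get_gaudi2_scale_factor()
    return max_w_scale

scales = torch.tensor([0.010, 0.020, 0.015])
print(requantize_max_scale(scales, on_gaudi2=False))  # tensor(0.0200)
print(requantize_max_scale(scales, on_gaudi2=True))   # tensor(0.0373)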
