diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index c210292865b29..a4f45fc338164 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -240,8 +240,7 @@ def __init__(
             backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -279,6 +278,36 @@ def forward(
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.HPU_ATTN:
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+
+            from vllm_hpu_extension.flags import enabled_flags
+
+            if "fsdpa" in enabled_flags():
+                from habana_frameworks.torch.hpex.kernels import FusedSDPA
+                from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+                fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+
+                out = fsdpa_op(query,
+                               key,
+                               value,
+                               None,
+                               dropout_p=0.0,
+                               is_causal=False,
+                               scale=self.scale,
+                               softmax_mode="fast",
+                               recompute_mode=True,
+                               valid_sequence_lengths=None)
+            else:
+                out = F.scaled_dot_product_attention(query,
+                                                     key,
+                                                     value,
+                                                     scale=self.scale)
+
+            out = out.transpose(1, 2)
+
         return out.reshape(bsz, q_len, -1)
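
For context, a minimal standalone sketch of the dispatch the new `_Backend.HPU_ATTN` branch performs, pulled out of `MultiHeadAttention.forward`. It assumes an Intel Gaudi host with `habana_frameworks` and `vllm_hpu_extension` available; the `hpu_mha` helper name and the tensor shapes are illustrative only and not part of this patch.

```python
# Illustrative sketch only, not part of the patch: mirrors the new HPU_ATTN
# branch of MultiHeadAttention.forward as a free function.
import torch
import torch.nn.functional as F

from vllm_hpu_extension.flags import enabled_flags


def hpu_mha(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
            scale: float) -> torch.Tensor:
    """query/key/value: [batch_size, seq_len, num_heads, head_size]."""
    # Both FusedSDPA and torch SDPA expect [batch, heads, seq, head_size].
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))

    if "fsdpa" in enabled_flags():
        # Fused-kernel path, taken when the extension reports fsdpa enabled.
        from habana_frameworks.torch.hpex.kernels import FusedSDPA
        from vllm_hpu_extension.utils import ModuleFusedSDPA

        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
        out = fsdpa_op(query, key, value, None,
                       dropout_p=0.0,
                       is_causal=False,
                       scale=scale,
                       softmax_mode="fast",
                       recompute_mode=True,
                       valid_sequence_lengths=None)
    else:
        # Fallback: stock PyTorch scaled_dot_product_attention.
        out = F.scaled_dot_product_attention(query, key, value, scale=scale)

    # Back to [batch, seq, heads, head_size], then merge the heads.
    out = out.transpose(1, 2)
    bsz, q_len = out.shape[:2]
    return out.reshape(bsz, q_len, -1)
```

Note that the branch passes no attention mask and `is_causal=False`, matching the existing TORCH_SDPA and XFORMERS paths in this layer.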