Fix: selecting correct backend for MultiHeadAttention (#645)
Fix for selecting the correct backend for MultiHeadAttention: the previous
code always defaulted to _Backend.TORCH_SDPA because _Backend.HPU_ATTN was
not in the accepted set.

---------

Co-authored-by: root <root@adobrzyniewicz-mcz3-g2-mpijob-worker-0.adobrzyniewicz-mcz3-g2-mpijob-worker.framework.svc.cluster.local>
adobrzyniewicz-habana and root authored Feb 3, 2025
1 parent 370953d commit e29b5f5
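
Why the old selection always fell back, as a minimal runnable sketch (the _Backend enum below is a hypothetical stand-in for vLLM's internal backend enum; only the membership test mirrors the diff): on HPU the requested backend is _Backend.HPU_ATTN, which the pre-fix allow-list did not contain, so the else branch always won.

    from enum import Enum, auto

    class _Backend(Enum):  # hypothetical stand-in for vLLM's internal enum
        TORCH_SDPA = auto()
        XFORMERS = auto()
        HPU_ATTN = auto()

    def select_old(backend):
        # Pre-fix allow-list: HPU_ATTN is absent, so it always falls through.
        return backend if backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS
        } else _Backend.TORCH_SDPA

    def select_new(backend):
        # Post-fix allow-list accepts HPU_ATTN (matching this commit).
        return backend if backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
        } else _Backend.TORCH_SDPA

    assert select_old(_Backend.HPU_ATTN) is _Backend.TORCH_SDPA  # the bug
    assert select_new(_Backend.HPU_ATTN) is _Backend.HPU_ATTN    # the fix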
Showing 1 changed file with 31 additions and 2 deletions.
vllm/attention/layer.py (33 changes: 31 additions & 2 deletions)
@@ -240,8 +240,7 @@ def __init__(
             backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -279,6 +278,36 @@ def forward(
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.HPU_ATTN:
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+
+            from vllm_hpu_extension.flags import enabled_flags
+
+            if "fsdpa" in enabled_flags():
+                from habana_frameworks.torch.hpex.kernels import FusedSDPA
+                from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+                fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+
+                out = fsdpa_op(query,
+                               key,
+                               value,
+                               None,
+                               dropout_p=0.0,
+                               is_causal=False,
+                               scale=self.scale,
+                               softmax_mode="fast",
+                               recompute_mode=True,
+                               valid_sequence_lengths=None)
+            else:
+                out = F.scaled_dot_product_attention(query,
+                                                     key,
+                                                     value,
+                                                     scale=self.scale)
+
+            out = out.transpose(1, 2)
+
         return out.reshape(bsz, q_len, -1)
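
Note on the tensor layout in the new branch: forward() works on (batch, seq, heads, head_dim) tensors, while both the HPU fused kernel and the torch SDPA fallback consume (batch, heads, seq, head_dim), hence the transposes around the kernel call and the final reshape. A minimal runnable sketch of the fallback path with illustrative shapes (the dimension values are arbitrary, not from the commit):

    import torch
    import torch.nn.functional as F

    bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
    scale = head_dim ** -0.5

    # forward() sees (batch, seq, heads, head_dim) at this point.
    q, k, v = (torch.randn(bsz, q_len, num_heads, head_dim) for _ in range(3))

    # The kernels expect (batch, heads, seq, head_dim), hence the transposes.
    q_t, k_t, v_t = (x.transpose(1, 2) for x in (q, k, v))
    out = F.scaled_dot_product_attention(q_t, k_t, v_t, scale=scale)

    # Transpose back and flatten the heads, matching the return statement above.
    out = out.transpose(1, 2)
    print(out.reshape(bsz, q_len, -1).shape)  # torch.Size([2, 8, 64])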

