From 36c76766171e18569d25dd602dc29ec2e1fe820f Mon Sep 17 00:00:00 2001
From: Sanju C Sudhakaran
Date: Tue, 11 Feb 2025 10:08:28 +0530
Subject: [PATCH] Refactor long-context + LoRA flow (#807)

This PR refactors the long-context + LoRA flow to align with the
upstream main branch (https://github.com/vllm-project/vllm/pull/12812).

HPU requires special handling when creating `long_lora_offsets_tensor`
in `convert_mapping`
[(refer)](https://github.com/HabanaAI/vllm-fork/blob/b0a4e825370434ccaa30574ba9d20311fccdab36/vllm/lora/punica_wrapper/punica_hpu.py#L46).

As suggested by the vLLM team, this PR passes `long_lora_context` as
None when calling `convert_mapping`. This avoids HPU-specific
conditions inside `convert_mapping` and handles the HPU long-LoRA
logic explicitly inside the overridden `_update_base_metadata`.
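For illustration only (not part of the diff below), a minimal sketch of the
two ways of building the offsets tensor. The helper names
`offsets_per_element` and `offsets_from_list` are made up for this example;
`index_mapping` and `offsets_by_lora_id` mirror the fields read in
`_update_base_metadata`:

```python
# Hypothetical stand-alone sketch: why HPU builds the offsets on the host
# first. Both helpers produce the same tensor; only the second avoids
# per-element tensor writes (`strided_insert` ops under HPU lazy graphs).
from typing import Dict, List

import torch


def offsets_per_element(index_mapping: List[int],
                        offsets_by_lora_id: Dict[int, int],
                        device: str = "cpu") -> torch.Tensor:
    # Preallocate, then write each element (the pattern kept in
    # `convert_mapping` for non-HPU devices).
    offsets = torch.zeros(len(index_mapping), device=device, dtype=torch.long)
    for i, lora_id in enumerate(index_mapping):
        offsets[i] = offsets_by_lora_id.get(lora_id, 0)
    return offsets


def offsets_from_list(index_mapping: List[int],
                      offsets_by_lora_id: Dict[int, int],
                      device: str = "cpu") -> torch.Tensor:
    # Accumulate in a Python list, then materialize a single tensor
    # (the pattern moved into the HPU `_update_base_metadata`).
    offsets = [offsets_by_lora_id.get(lora_id, 0) for lora_id in index_mapping]
    return torch.tensor(offsets, device=device, dtype=torch.long)


if __name__ == "__main__":
    index_mapping = [1, 1, 2, 0, 2]
    offsets_by_lora_id = {1: 4096, 2: 8192}
    assert torch.equal(offsets_per_element(index_mapping, offsets_by_lora_id),
                       offsets_from_list(index_mapping, offsets_by_lora_id))
```

The diff moves the list-based variant out of `convert_mapping` and into the
HPU wrapper, so the shared utility keeps only the preallocated-tensor path.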
Co-authored-by: Vivek Goel
---
 vllm/lora/punica_wrapper/punica_hpu.py | 59 +++++++++++++++++++++++++-
 vllm/lora/punica_wrapper/utils.py      | 26 +++---------
 2 files changed, 65 insertions(+), 20 deletions(-)

diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py
index d9c4f44a1c282..3661a7214648a 100644
--- a/vllm/lora/punica_wrapper/punica_hpu.py
+++ b/vllm/lora/punica_wrapper/punica_hpu.py
@@ -1,10 +1,18 @@
-from typing import Optional, Tuple, Union, final
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
 
 import torch
 from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
                                     dispatch_bgmv_linear)
 
 from .punica_base import PunicaWrapperBase
+from .utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circuit import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext
 
 
 @final
@@ -17,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
         PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                    max_batches, device)
 
+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
+                            extra_vocab_size, self.device, None)
+        # Updating each element in `long_lora_offsets` with `lora_offset` slows
+        # down perf in HPU due to a series of `strided_insert` ops during lazy
+        # graph accumulation. Hence HPU appends `lora_offset` to a list and
+        # converts it to a tensor only after it is ready.
+        if long_lora_context:
+            index_mapping_indices: List[int] = list(
+                mapping.index_mapping).copy()
+            long_lora_offsets: List[int] = []
+            for i in range(len(index_mapping_indices)):
+                lora_offset: int = long_lora_context.offsets_by_lora_id.get(
+                    index_mapping_indices[i], 0)
+                long_lora_offsets.append(lora_offset)
+            long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
+                                                    device=self.device,
+                                                    dtype=torch.long)
+            indices_len[-1] = long_lora_offsets_tensor.shape[-1]
+
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded)
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices)
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor)
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
+
     def add_lora_embedding(self,
                            y: torch.Tensor,
                            x: torch.Tensor,
diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py
index 4504e19b20816..dbc2d27c597f2 100644
--- a/vllm/lora/punica_wrapper/utils.py
+++ b/vllm/lora/punica_wrapper/utils.py
@@ -1,9 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
-from vllm.platforms import current_platform
-
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
@@ -88,14 +88,10 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
-
     if long_lora_context:
-        if current_platform.is_hpu():
-            long_lora_offsets_list: List[int] = []
-        else:
-            long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                            device=device,
-                                            dtype=torch.long)
+        long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                        device=device,
+                                        dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -108,18 +104,10 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
+            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            if current_platform.is_hpu():
-                long_lora_offsets_list.append(lora_offset)
-            else:
-                assert long_lora_offsets is not None
-                long_lora_offsets[i] = lora_offset
-
-        if long_lora_context and current_platform.is_hpu():
-            long_lora_offsets = torch.tensor(long_lora_offsets_list,
-                                             device=device,
-                                             dtype=torch.long)
+            long_lora_offsets[i] = lora_offset
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,