Commit

Merge branch 'habana_main' into add_real_bs
kamil-kaczor authored Feb 11, 2025
2 parents b5e445a + 36c7676 commit f4d2173
Showing 2 changed files with 65 additions and 20 deletions.
59 changes: 58 additions & 1 deletion vllm/lora/punica_wrapper/punica_hpu.py
@@ -1,10 +1,18 @@
-from typing import Optional, Tuple, Union, final
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
 
 import torch
 from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
                                     dispatch_bgmv_linear)
 
 from .punica_base import PunicaWrapperBase
+from .utils import convert_mapping
 
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext
+
 
 @final
@@ -17,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
         PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                    max_batches, device)

+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
+                            extra_vocab_size, self.device, None)
+        # Updating each element in `long_lora_offsets` with `lora_offset` slows
+        # down perf in HPU due to a series of `strided_insert` ops during lazy
+        # graph accumulation. Hence HPU appends `lora_offset` to a list and
+        # converts it to a tensor only after it is ready.
+        if long_lora_context:
+            index_mapping_indices: List[int] = list(
+                mapping.index_mapping).copy()
+            long_lora_offsets: List[int] = []
+            for i in range(len(index_mapping_indices)):
+                lora_offset: int = long_lora_context.offsets_by_lora_id.get(
+                    index_mapping_indices[i], 0)
+                long_lora_offsets.append(lora_offset)
+            long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
+                                                    device=self.device,
+                                                    dtype=torch.long)
+            indices_len[-1] = long_lora_offsets_tensor.shape[-1]
+
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded)
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices)
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor)
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
+
     def add_lora_embedding(self,
                            y: torch.Tensor,
                            x: torch.Tensor,
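The in-code comment above captures the point of the new `_update_base_metadata` override: on HPU, filling a preallocated tensor element by element is recorded as a series of `strided_insert` ops during lazy graph accumulation, so the long-LoRA offsets are gathered in a plain Python list and converted to a tensor only once. A minimal, CPU-runnable sketch of that pattern follows; `index_mapping` and `offsets_by_lora_id` here are a made-up list and dict standing in for `mapping.index_mapping` and `long_lora_context.offsets_by_lora_id`, not the vLLM objects themselves.

import torch

# Hypothetical inputs standing in for mapping.index_mapping and
# long_lora_context.offsets_by_lora_id; the values are invented.
index_mapping = [1, 1, 2, 0]
offsets_by_lora_id = {1: 4096, 2: 8192}
device = "cpu"  # stand-in for the HPU device handle

# Per-element writes into a preallocated tensor: under HPU lazy mode each
# assignment would become a separate strided_insert in the accumulated graph.
per_element = torch.zeros(len(index_mapping), device=device, dtype=torch.long)
for i, lora_id in enumerate(index_mapping):
    per_element[i] = offsets_by_lora_id.get(lora_id, 0)

# List accumulation with a single tensor conversion at the end, the pattern
# the HPU override uses instead.
as_list = [offsets_by_lora_id.get(lora_id, 0) for lora_id in index_mapping]
single_conversion = torch.tensor(as_list, device=device, dtype=torch.long)

assert torch.equal(per_element, single_conversion)

Both variants produce the same tensor; the difference is only in how many ops the lazily accumulated HPU graph has to record.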
26 changes: 7 additions & 19 deletions vllm/lora/punica_wrapper/utils.py
@@ -1,9 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
-from vllm.platforms import current_platform
-
 if TYPE_CHECKING:
     # avoid circular import
     from vllm.lora.layers import LoRAMapping
@@ -88,14 +88,10 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
-
     if long_lora_context:
-        if current_platform.is_hpu():
-            long_lora_offsets_list: List[int] = []
-        else:
-            long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                            device=device,
-                                            dtype=torch.long)
+        long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                        device=device,
+                                        dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -108,18 +104,10 @@
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
+            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            if current_platform.is_hpu():
-                long_lora_offsets_list.append(lora_offset)
-            else:
-                assert long_lora_offsets is not None
-                long_lora_offsets[i] = lora_offset
-
-    if long_lora_context and current_platform.is_hpu():
-        long_lora_offsets = torch.tensor(long_lora_offsets_list,
-                                         device=device,
-                                         dtype=torch.long)
+            long_lora_offsets[i] = lora_offset
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
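With the HPU-specific branches removed, `convert_mapping` keeps only the preallocate-and-fill path for the long-LoRA offsets; the HPU wrapper now passes None for `long_lora_context` and rebuilds the offsets itself in the override shown above. A condensed sketch of the retained generic flow, using an invented helper name and made-up inputs rather than vLLM's actual signature:

from typing import Dict, List

import torch


def build_long_lora_offsets(index_mapping: List[int],
                            offsets_by_lora_id: Dict[int, int],
                            device: str = "cpu") -> torch.Tensor:
    # Preallocate once, then fill per token index, mirroring the torch.zeros
    # + per-index assignment that remains in convert_mapping.
    offsets = torch.zeros(len(index_mapping), device=device, dtype=torch.long)
    for i, lora_id in enumerate(index_mapping):
        # Tokens without a long-LoRA adapter fall back to offset 0.
        offsets[i] = offsets_by_lora_id.get(lora_id, 0)
    return offsets


print(build_long_lora_offsets([1, 1, 2, 0], {1: 4096, 2: 8192}))
# tensor([4096, 4096, 8192,    0])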
