From 30605dfde0966940cd8161d4641d9082918c0739 Mon Sep 17 00:00:00 2001
From: Li Bo
Date: Thu, 16 May 2024 13:43:57 +0800
Subject: [PATCH] Merge pull request #1 from EvolvingLMMs-Lab/py/dev

fix resampler bug & Add pos skipping

(cherry picked from commit acba85f26cfdac1947f24d535643c9dc62752ab5)
---
 llava/model/llava_arch.py | 10 ++++++++--
 llava/train/train.py      | 26 ++++++++++++++++++++------
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py
index b63047073..3601ef6a6 100755
--- a/llava/model/llava_arch.py
+++ b/llava/model/llava_arch.py
@@ -28,7 +28,7 @@
 from llava.mm_utils import get_anyres_image_grid_shape
 from llava.utils import rank0_print
-
+import random
 
 class LlavaMetaModel:
@@ -421,7 +421,13 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
         if _position_ids is None:
             position_ids = None
-
+        if getattr(self.config, "use_pos_skipping", False) and self.training:
+            position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
+            split_position = random.randint(0, new_input_embeds.size(1))
+            left_add = random.randint(0, self.config.pos_skipping_range)
+            right_add = random.randint(left_add, self.config.pos_skipping_range)
+            position_ids[:, :split_position] += left_add
+            position_ids[:, split_position:] += right_add
         # import pdb; pdb.set_trace()
         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
diff --git a/llava/train/train.py b/llava/train/train.py
index 4aa12c3d8..bd7d02bcc 100755
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -106,7 +106,9 @@ class ModelArguments:
     s2: Optional[bool] = field(default=False)
     s2_scales: Optional[str] = field(default="336,672,1008")
-
+
+    use_pos_skipping: Optional[bool] = field(default=False)
+    pos_skipping_range: Optional[int] = field(default=4096)
 
 @dataclass
 class DataArguments:
@@ -1222,11 +1224,24 @@ def get_model(model_args, training_args, bnb_model_from_pretrained_args):
     customized_kwargs = dict()
     customized_kwargs.update(bnb_model_from_pretrained_args)
-
-    overwrite_config = {}
     cfg_pretrained = None
-    if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
+
+    overwrite_config = {}
+    if any([
+        model_args.rope_scaling_factor is not None,
+        model_args.rope_scaling_type is not None,
+        model_args.mm_spatial_pool_stride is not None,
+        model_args.mm_spatial_pool_out_channels is not None,
+        model_args.mm_spatial_pool_mode is not None,
+        model_args.mm_resampler_type is not None
+    ]):
         cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
+
+    if model_args.use_pos_skipping is not None and model_args.pos_skipping_range is not None:
+        overwrite_config["use_pos_skipping"] = model_args.use_pos_skipping
+        overwrite_config["pos_skipping_range"] = model_args.pos_skipping_range
+
+    if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
         overwrite_config["rope_scaling"] = {
             "factor": model_args.rope_scaling_factor,
             "type": model_args.rope_scaling_type,
@@ -1247,8 +1262,7 @@
             overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode
 
     if overwrite_config:
-        if cfg_pretrained is None:
-            cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
+        assert cfg_pretrained is not None, "cfg_pretrained is None"
         rank0_print(f"Overwriting config with {overwrite_config}")
         for k, v in overwrite_config.items():
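
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of what the pos-skipping hunk in prepare_inputs_labels_for_multimodal does to position_ids during training. The helper name apply_pos_skipping and the example arguments are made up for illustration; only the transform itself mirrors the patch.

# Illustrative sketch only; apply_pos_skipping is a hypothetical helper that
# reproduces the position-skipping transform added by this patch.
import random

import torch


def apply_pos_skipping(seq_len: int, pos_skipping_range: int = 4096) -> torch.Tensor:
    # Start from the ordinary positions 0..seq_len-1 (batch dimension of 1).
    position_ids = torch.arange(seq_len).unsqueeze(0)
    # Pick a random split point, then shift the left and right segments by
    # random offsets with right_add >= left_add. Positions stay increasing but
    # contain one large jump, so the model sees position values up to roughly
    # seq_len + pos_skipping_range without needing longer training samples.
    split_position = random.randint(0, seq_len)
    left_add = random.randint(0, pos_skipping_range)
    right_add = random.randint(left_add, pos_skipping_range)
    position_ids[:, :split_position] += left_add
    position_ids[:, split_position:] += right_add
    return position_ids


if __name__ == "__main__":
    # A short sequence still receives position ids spread over a wide range.
    print(apply_pos_skipping(seq_len=8, pos_skipping_range=64))

Since the two new ModelArguments fields are parsed like the existing ones, they can presumably be set from the training launch command as --use_pos_skipping True --pos_skipping_range 4096; treat the exact flag spelling as an assumption based on the dataclass field names.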