diff --git a/llava/train/train.py b/llava/train/train.py index 00235af47..4aa12c3d8 100755 --- a/llava/train/train.py +++ b/llava/train/train.py @@ -1224,6 +1224,7 @@ def get_model(model_args, training_args, bnb_model_from_pretrained_args): customized_kwargs.update(bnb_model_from_pretrained_args) overwrite_config = {} + cfg_pretrained = None if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None: cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path) overwrite_config["rope_scaling"] = { @@ -1246,6 +1247,9 @@ def get_model(model_args, training_args, bnb_model_from_pretrained_args): overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode if overwrite_config: + if cfg_pretrained is None: + cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path) + rank0_print(f"Overwriting config with {overwrite_config}") for k, v in overwrite_config.items(): setattr(cfg_pretrained, k, v)