AssertionError('initial value for logits error [FIXED] #1248

Open
daegonYu opened this issue Nov 6, 2024 · 10 comments
Labels: fixed - pending confirmation (Fixed, waiting for confirmation from poster), URGENT BUG (Urgent bug)

Comments

@daegonYu

daegonYu commented Nov 6, 2024

{
	"name": "CompilationError",
	"message": "at 53:4:
    loss_ptr      += row_idx
    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
    labels_ptr    += row_idx

    col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\"))

    # Go logit scaling for Cohere: t * x
    if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
    ^
AssertionError('initial value for `logits` is of type <[65536], bf16>, but the then block redefines it as <[65536], fp32>')",
	"stack": "---------------------------------------------------------------------------
CompilationError                          Traceback (most recent call last)
Cell In[28], line 1
----> 1 trainer_stats = trainer.train()

File <string>:156, in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)

File <string>:380, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File <string>:31, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/_utils.py:945, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
    943     pass
    944 pass
--> 945 return self._old_compute_loss(model, inputs, *args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3631         loss_kwargs[\"num_items_in_batch\"] = num_items_in_batch
   3632     inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
   3634 # Save past state if it exists
   3635 # TODO: this needs to be fixed and made cleaner later.
   3636 if self.args.past_index >= 0:

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/utils/operations.py:823, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    822 def forward(*args, **kwargs):
--> 823     return model_forward(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/utils/operations.py:811, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    810 def __call__(self, *args, **kwargs):
--> 811     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     41 @functools.wraps(func)
     42 def decorate_autocast(*args, **kwargs):
     43     with autocast_instance:
---> 44         return func(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_compile.py:32, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     29     disable_fn = torch._dynamo.disable(fn, recursive)
     30     fn.__dynamo_disable = disable_fn
---> 32 return disable_fn(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:632, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    630 prior = _maybe_set_eval_frame(callback)
    631 try:
--> 632     return fn(*args, **kwargs)
    633 finally:
    634     _maybe_set_eval_frame(prior)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/llama.py:1046, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, num_logits_to_keep, **kwargs)
   1031 @torch._disable_dynamo
   1032 def PeftModelForCausalLM_fast_forward(
   1033     self,
   (...)
   1044     **kwargs,
   1045 ):
-> 1046     return self.base_model(
   1047         input_ids=input_ids,
   1048         causal_mask=causal_mask,
   1049         attention_mask=attention_mask,
   1050         inputs_embeds=inputs_embeds,
   1051         labels=labels,
   1052         output_attentions=output_attentions,
   1053         output_hidden_states=output_hidden_states,
   1054         return_dict=return_dict,
   1055         num_logits_to_keep=num_logits_to_keep,
   1056         **kwargs,
   1057     )

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/llama.py:987, in CausalLM_fast_forward.<locals>._CausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep, *args, **kwargs)
    984     pass
    986     shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
--> 987     loss = fast_cross_entropy_loss(
    988         logits = shift_logits,
    989         labels = shift_labels,
    990         logit_softcapping = logit_softcapping,
    991         logit_scaling     = logit_scaling,
    992         n_items           = kwargs.get(\"num_items_in_batch\", None) or kwargs.get(\"n_items\", None),
    993     )
    994 else:
    995     if logit_scaling != 0:

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/kernels/cross_entropy_loss.py:386, in fast_cross_entropy_loss(logits, labels, logit_softcapping, logit_scaling, n_items)
    383 batch, seq_len, d = logits.shape
    384 assert(labels.shape == (batch, seq_len))
--> 386 loss = Fast_CrossEntropyLoss.apply(
    387     logits.view(batch*seq_len, d),
    388     labels.view(-1),
    389     logit_softcapping,
    390     logit_scaling,
    391 )
    392 if n_items is None:
    393     n_items = torch.count_nonzero(labels != -100)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         \"In order to use an autograd.Function with functorch transforms \"
    580         \"(vmap, grad, jvp, jacrev, ...), it must override the setup_context \"
    581         \"staticmethod. For more details, please see \"
    582         \"https://pytorch.org/docs/main/notes/extending.func.html\"
    583     )

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/kernels/cross_entropy_loss.py:311, in Fast_CrossEntropyLoss.forward(ctx, logits, labels, logit_softcapping, logit_scaling)
    307 else:
    308     # For large vocabs > 65336 like Gemma 256K
    309     logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = \"cuda:0\")
--> 311     _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
    312         logits, logits.stride(0),
    313         losses,
    314         logsumexp,
    315         labels,
    316         VOCAB_SIZE       = vocab_size,
    317         N_CHUNKS         = n_chunks,
    318         BLOCK_SIZE       = MAX_FUSED_SIZE,
    319         DO_SOFTCAPPING   = DO_SOFTCAPPING,
    320         SOFTCAP          = logit_softcapping,
    321         DO_LOGIT_SCALING = DO_LOGIT_SCALING,
    322         LOGIT_SCALE      = logit_scaling,
    323         num_warps        = 32,
    324     )
    325     # logsumexp(chunked_logsumexp) - x
    326     # Do the -x separately
    327     logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/jit.py:345, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    339 def __getitem__(self, grid) -> T:
    340     \"\"\"
    341     A JIT function is launched with: fn[grid](*args, **kwargs).
    342     Hence JITFunction.__getitem__ returns a callable proxy that
    343     memorizes the grid.
    344     \"\"\"
--> 345     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/autotuner.py:338, in Heuristics.run(self, *args, **kwargs)
    336 for v, heur in self.values.items():
    337     kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 338 return self.fn.run(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/runtime/jit.py:662, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    660     # compile the kernel
    661     src = self.ASTSource(self, signature, constants, configs[0])
--> 662     kernel = self.compile(
    663         src,
    664         target=target,
    665         options=options.__dict__,
    666     )
    667     self.cache[device][key] = kernel
    669 # Check that used global values have not changed.

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/compiler/compiler.py:276, in compile(src, target, options)
    274 codegen_fns = backend.get_codegen_implementation()
    275 try:
--> 276     module = src.make_ir(options, codegen_fns, context)
    277 except Exception as e:
    278     filter_traceback(e)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/triton/compiler/compiler.py:113, in ASTSource.make_ir(self, options, codegen_fns, context)
    112 def make_ir(self, options, codegen_fns, context):
--> 113     return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)

CompilationError: at 53:4:
    loss_ptr      += row_idx
    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
    labels_ptr    += row_idx

    col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float(\"inf\"))

    # Go logit scaling for Cohere: t * x
    if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
    ^
AssertionError('initial value for `logits` is of type <[65536], bf16>, but the then block redefines it as <[65536], fp32>')"
}
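The assertion comes from a compile-time rule in Triton: a variable reassigned inside an `if` that has no `else` must keep the type it had before the branch. Here `logits` is loaded as bf16, while `LOGIT_SCALE * logits` evaluates to fp32 (as the error message itself reports), so the two paths disagree. The snippet below is a toy, pure-Python model of that rule, not Triton's actual implementation; the helpers `promote` and `check_if_block` are hypothetical. It also shows one way to satisfy the check, upcasting before the branch (in the kernel, something like appending `.to(tl.float32)` to the `tl.load` call). That is an assumption about a possible fix, not necessarily the patch that was shipped.

```python
# Toy model (hypothetical helpers) of the Triton rule behind the error:
# a variable reassigned in an `if` without `else` must keep its type.


def promote(a: str, b: str) -> str:
    """Simplified result dtype of a binary op on two float dtypes."""
    if a == b:
        return a
    # Mixed float dtypes widen to fp32 in this toy model; this matches the
    # error message, where LOGIT_SCALE * <bf16 logits> became fp32.
    return "fp32"


def check_if_block(name: str, initial: str, then_type: str) -> str:
    """Reject a then-block that changes the variable's dtype."""
    if initial != then_type:
        raise AssertionError(
            f"initial value for `{name}` is of type {initial}, "
            f"but the then block redefines it as {then_type}"
        )
    return initial


SCALE_DTYPE = "fp32"  # dtype LOGIT_SCALE effectively has (inferred from the error)

# Failing pattern: bf16 logits, scaled only inside the branch.
try:
    check_if_block("logits", "bf16", promote(SCALE_DTYPE, "bf16"))
    failed = False
except AssertionError as err:
    failed = True
    print(err)

# Working pattern: upcast logits to fp32 right after loading, so the
# branch no longer changes the type.
logits_dtype = check_if_block("logits", "fp32", promote(SCALE_DTYPE, "fp32"))
print(failed, logits_dtype)
```

The design point is that both paths through the `if` must agree on one dtype, so doing the upcast once, before any conditional scaling or softcapping, keeps every branch consistent.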

Package Version


accelerate 1.1.0
aiohappyeyeballs 2.4.3
aiohttp 3.10.10
aiosignal 1.3.1
asttokens 2.4.1
attrs 24.2.0
bitsandbytes 0.44.1
certifi 2024.8.30
charset-normalizer 3.4.0
comm 0.2.2
datasets 3.1.0
debugpy 1.8.7
decorator 5.1.1
dill 0.3.8
docstring_parser 0.16
exceptiongroup 1.2.2
executing 2.1.0
filelock 3.13.1
frozenlist 1.5.0
fsspec 2024.9.0
gmpy2 2.1.2
hf_transfer 0.1.8
huggingface-hub 0.26.2
idna 3.10
importlib_metadata 8.5.0
ipykernel 6.29.5
ipython 8.29.0
jedi 0.19.1
Jinja2 3.1.4
jupyter_client 8.6.3
jupyter_core 5.7.2
markdown-it-py 3.0.0
MarkupSafe 2.1.3
matplotlib-inline 0.1.7
mdurl 0.1.2
mpmath 1.3.0
multidict 6.1.0
multiprocess 0.70.16
nest_asyncio 1.6.0
networkx 3.3
numpy 2.1.3
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
packaging 24.1
pandas 2.2.3
parso 0.8.4
peft 0.13.2
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.2.0
pip 24.3.1
platformdirs 4.3.6
prompt_toolkit 3.0.48
propcache 0.2.0
protobuf 3.20.3
psutil 6.1.0
ptyprocess 0.7.0
pure_eval 0.2.3
pyarrow 18.0.0
Pygments 2.18.0
python-dateutil 2.9.0
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
regex 2024.9.11
requests 2.32.3
rich 13.9.4
safetensors 0.4.5
sentencepiece 0.2.0
setuptools 75.3.0
shtab 1.7.1
six 1.16.0
stack-data 0.6.2
sympy 1.13.1
tokenizers 0.20.3
torch 2.5.1
tornado 6.4.1
tqdm 4.66.6
traitlets 5.14.3
transformers 4.46.2
triton 3.1.0
trl 0.12.0
typing_extensions 4.12.2
tyro 0.8.14
tzdata 2024.2
unsloth 2024.11.1
unsloth_zoo 2024.11.1
urllib3 2.2.3
wcwidth 0.2.13
wheel 0.44.0
xformers 0.0.28.post3
xxhash 3.5.0
yarl 1.17.1
zipp 3.20.2

@giantvision

I ran into the same problem. Temporary workaround: roll back to the previous version.

danielhanchen mentioned this issue on Nov 6, 2024
danielhanchen changed the title from "Errors that occur while learning with unsloth 2024.11.1 version" to "[FIXED] AssertionError('initial value for logits error" on Nov 6, 2024
danielhanchen added the labels "fixed - pending confirmation" (Fixed, waiting for confirmation from poster) and "URGENT BUG" (Urgent bug) on Nov 6, 2024
@danielhanchen
Contributor

danielhanchen commented Nov 6, 2024

Apologies, just fixed it! If you're on Colab / Kaggle, disconnect and delete the runtime, then rerun the notebook. For local installations, please update Unsloth without changing dependencies via:

pip install --upgrade --no-cache-dir --no-deps unsloth unsloth-zoo
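To confirm that an upgrade (or a rollback) actually landed in the environment your notebook uses, you can read the installed distribution metadata with the standard library's importlib.metadata. This is a minimal sketch; the helper name installed_versions is hypothetical. It reports what is installed on disk: a module that was already imported keeps running the old code until the Python process (for example, the Jupyter kernel) is restarted.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    report = installed_versions(
        ["unsloth", "unsloth_zoo", "torch", "triton", "transformers"]
    )
    for pkg, ver in report.items():
        print(f"{pkg:15} {ver or 'not installed'}")
```

Printing this list alongside a traceback gives maintainers the handful of versions that matter without pasting the full pip listing.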

danielhanchen pinned this issue on Nov 6, 2024
@daegonYu
Author

daegonYu commented Nov 6, 2024

Now the following error appears instead:


{
	"name": "RuntimeError",
	"message": "compiled_autograd.enable() requires no threads in backwards()",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[29], line 1
----> 1 trainer_stats = trainer.train()

File <string>:156, in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)

File <string>:380, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File <string>:64, in _unsloth_training_step(***failed resolving arguments***)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/accelerator.py:2241, in Accelerator.backward(self, loss, **kwargs)
   2239     self.lomo_backward(loss, learning_rate)
   2240 else:
-> 2241     loss.backward(**kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_tensor.py:581, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    571 if has_torch_function_unary(self):
    572     return handle_torch_function(
    573         Tensor.backward,
    574         (self,),
   (...)
    579         inputs=inputs,
    580     )
--> 581 torch.autograd.backward(
    582     self, gradient, retain_graph, create_graph, inputs=inputs
    583 )

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/autograd/__init__.py:347, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    342     retain_graph = create_graph
    344 # The reason we repeat the same comment below is that
    345 # some Python versions print out the first line of a multi-line function
    346 # calls in the traceback and some print out the last line
--> 347 _engine_run_backward(
    348     tensors,
    349     grad_tensors_,
    350     retain_graph,
    351     create_graph,
    352     inputs,
    353     allow_unreachable=True,
    354     accumulate_grad=True,
    355 )

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/autograd/graph.py:825, in _engine_run_backward(t_outputs, *args, **kwargs)
    823     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    824 try:
--> 825     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    826         t_outputs, *args, **kwargs
    827     )  # Calls into the C++ engine to run the backward pass
    828 finally:
    829     if attach_logging_hooks:

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/autograd/function.py:307, in BackwardCFunction.apply(self, *args)
    301     raise RuntimeError(
    302         \"Implementing both 'backward' and 'vjp' for a custom \"
    303         \"Function is not allowed. You should only implement one \"
    304         \"of them.\"
    305     )
    306 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 307 return user_fn(self, *args)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/amp/autocast_mode.py:511, in custom_bwd.<locals>.decorate_bwd(*args, **kwargs)
    504 @functools.wraps(bwd)
    505 def decorate_bwd(*args, **kwargs):
    506     with autocast(
    507         device_type=device_type,
    508         enabled=args[0]._fwd_used_autocast,
    509         dtype=args[0]._dtype,
    510     ):
--> 511         return bwd(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth_zoo/gradient_checkpointing.py:156, in Unsloth_Offloaded_Gradient_Checkpointer.backward(ctx, dY)
    154 hidden_states.requires_grad_(True)
    155 with torch.enable_grad():
--> 156     (output,) = ctx.forward_function(hidden_states, *ctx.args)
    157 torch.autograd.backward(output, dY)
    158 return (None, hidden_states.grad,) + (None,)*len(ctx.args)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/unsloth/models/gemma2.py:210, in Gemma2DecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
    208 else:
    209     residual = hidden_states
--> 210     hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
    211     hidden_states, self_attn_weights, present_key_value = self.self_attn(
    212         hidden_states=hidden_states,
    213         causal_mask=causal_mask,
   (...)
    219         padding_mask=padding_mask,
    220     )
    221     hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    448     else:
    449         return fn(*args, **kwargs)
--> 451 cleanups = [enter() for enter in self.enter_exit_hooks]
    452 prior = _maybe_set_eval_frame(callback)
    454 # Ensure that if an assertion occurs after graph pushes
    455 # something onto the DynamicLayerStack then we pop it off (the
    456 # constructed graph code isn't guarded with try/finally).
    457 #
    458 # This used to be a context but putting a `with` here is a noticible
    459 # perf regression (#126293)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:451, in <listcomp>(.0)
    448     else:
    449         return fn(*args, **kwargs)
--> 451 cleanups = [enter() for enter in self.enter_exit_hooks]
    452 prior = _maybe_set_eval_frame(callback)
    454 # Ensure that if an assertion occurs after graph pushes
    455 # something onto the DynamicLayerStack then we pop it off (the
    456 # constructed graph code isn't guarded with try/finally).
    457 #
    458 # This used to be a context but putting a `with` here is a noticible
    459 # perf regression (#126293)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:564, in OptimizeContext.__init__.<locals>.call_compiled_autograd()
    562 compiler_fn = rebuild_ctx()
    563 ctx = torch._dynamo.compiled_autograd.enable(compiler_fn)
--> 564 ctx.__enter__()
    565 return functools.partial(ctx.__exit__, None, None, None)

File ~/anaconda3/envs/unsloth_env/lib/python3.11/contextlib.py:137, in _GeneratorContextManager.__enter__(self)
    135 del self.args, self.kwds, self.func
    136 try:
--> 137     return next(self.gen)
    138 except StopIteration:
    139     raise RuntimeError(\"generator didn't yield\") from None

File ~/anaconda3/envs/unsloth_env/lib/python3.11/site-packages/torch/_dynamo/compiled_autograd.py:499, in enable(compiler_fn)
    497 @contextlib.contextmanager
    498 def enable(compiler_fn):
--> 499     prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
    500         functools.partial(AutogradCompilerInstance, compiler_fn)
    501     )
    502     if snapshot_verbose_logging_enabled():
    503         torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn)

RuntimeError: compiled_autograd.enable() requires no threads in backwards()"
}

package version

unsloth            2024.11.2
unsloth_zoo        2024.11.1
transformers       4.46.2
triton             3.1.0
torch              2.5.1
accelerate         1.1.0
xformers           0.0.28.post3

@remiconnesson

> Apologies just fixed it - If on Colab / Kaggle, disconnect and delete runtime and rerun the notebook. For local installations, please update Unsloth without new dependencies via:
>
> pip install --upgrade --no-cache-dir --no-deps unsloth unsloth-zoo

Unfortunately it didn't work; the same issue appears.

absl-py==2.1.0
accelerate==1.1.0
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aioprometheus==23.12.0
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==24.2.0
babel==2.16.0
beautifulsoup4==4.12.3
bitsandbytes==0.44.1
bleach==6.1.0
blinker==1.4
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.1.0
comm==0.2.2
compressed-tensors==0.6.0
cryptography==3.4.8
datasets==3.1.0
dbus-python==1.2.18
debugpy==1.8.5
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
distro==1.7.0
docstring_parser==0.16
einops==0.8.0
entrypoints==0.4
executing==2.1.0
fastapi==0.115.4
fastjsonschema==2.20.0
filelock==3.13.1
fqdn==1.5.1
frozenlist==1.5.0
fsspec==2024.2.0
gguf==0.10.0
grpcio==1.67.1
h11==0.14.0
hf_transfer==0.1.8
httpcore==1.0.5
httplib2==0.20.2
httptools==0.6.4
httpx==0.27.2
huggingface-hub==0.26.2
idna==3.10
importlib-metadata==4.6.4
interegular==0.3.3
ipykernel==6.29.5
ipython==8.27.0
ipython-genutils==0.2.0
ipywidgets==8.1.5
isoduration==20.11.0
jedi==0.19.1
jeepney==0.7.1
Jinja2==3.1.3
jiter==0.7.0
joblib==1.4.2
json5==0.9.25
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter-archive==3.4.0
jupyter-events==0.10.0
jupyter-highlight-selected-word==0.2.0
jupyter-lsp==2.2.5
jupyter_client==7.4.9
jupyter_contrib_core==0.4.2
jupyter_contrib_nbextensions==0.7.0
jupyter_core==5.7.2
jupyter_nbextensions_configurator==0.6.4
jupyter_server==2.14.2
jupyter_server_terminals==0.5.3
jupyterlab==4.2.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
keyring==23.5.0
lark==1.2.2
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
llvmlite==0.43.0
lm-format-enforcer==0.10.6
lxml==5.3.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mdurl==0.1.2
mistral_common==1.4.4
mistune==3.0.2
more-itertools==8.10.0
mpmath==1.3.0
msgpack==1.1.0
msgspec==0.18.6
multidict==6.1.0
multiprocess==0.70.16
nbclassic==1.1.0
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
ninja==1.11.1.1
notebook==6.5.5
notebook_shim==0.2.4
numba==0.60.0
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.560.30
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.0
openai==1.54.1
opencv-python-headless==4.10.0.84
orjson==3.10.11
outlines==0.0.46
overrides==7.7.0
packaging==24.1
pandas==2.2.3
pandocfilters==1.5.1
parso==0.8.4
partial-json-parser==0.2.1.1.post4
peft==0.13.2
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.3.6
prometheus-fastapi-instrumentator==7.0.0
prometheus_client==0.21.0
prompt_toolkit==3.0.47
propcache==0.2.0
protobuf==3.20.3
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
py-cpuinfo==9.0.0
pyairports==2.1.1
pyarrow==18.0.0
pycountry==24.6.1
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
Pygments==2.18.0
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==2.4.7
python-apt==2.4.0+ubuntu4
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytz==2024.2
PyYAML==6.0.2
pyzmq==24.0.1
quantile-python==1.1
ray==2.38.0
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.9.4
rpds-py==0.20.0
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
SecretStorage==3.3.1
Send2Trash==1.8.3
sentencepiece==0.2.0
shtab==1.7.1
six==1.16.0
sniffio==1.3.1
soupsieve==2.6
stack-data==0.6.3
starlette==0.41.2
sympy==1.13.1
tensorboard==2.18.0
tensorboard-data-server==0.7.2
terminado==0.18.1
threadpoolctl==3.5.0
tiktoken==0.7.0
tinycss2==1.3.0
tokenizers==0.20.3
torch==2.4.0
torchaudio==2.4.1+cu124
torchvision==0.19.0
tornado==6.4.1
tqdm==4.66.6
traitlets==5.14.3
transformers==4.46.2
triton==3.0.0
trl==0.12.0
types-python-dateutil==2.9.0.20240906
typing_extensions==4.12.2
tyro==0.8.14
tzdata==2024.2
unsloth==2024.11.2
unsloth_zoo==2024.11.1
uri-template==1.3.0
urllib3==2.2.3
uvicorn==0.32.0
uvloop==0.21.0
vllm==0.6.3.post1
wadllib==1.3.6
watchfiles==0.24.0
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
websockets==13.1
Werkzeug==3.1.2
widgetsnbextension==4.0.13
xformers==0.0.27.post2
xxhash==3.5.0
yarl==1.17.1
zipp==1.0.0

@remiconnesson

For anyone stuck, rolling back with

pip install --upgrade --no-cache-dir --no-deps unsloth==2024.10.7 unsloth-zoo==2024.10.5

worked for me.
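Unsloth's version numbers in this thread appear to follow a YEAR.MONTH.PATCH calendar scheme, so checking whether a pinned rollback really predates the broken release should compare numeric tuples rather than strings: as plain strings, "2024.10.7" sorts before "2024.9.1". The helpers below (parse_calver, predates_report) are hypothetical, and the 2024.11.1 cutoff is simply the version the original report was filed against, not an officially documented bad range.

```python
def parse_calver(text: str) -> tuple:
    """Turn a version like '2024.10.7' into (2024, 10, 7) for numeric comparison.

    Assumes plain dot-separated integers; suffixes such as '.post3' are not handled.
    """
    return tuple(int(part) for part in text.split("."))


REPORTED_BROKEN = parse_calver("2024.11.1")  # version from the original report


def predates_report(installed: str) -> bool:
    """True if `installed` is older than the version the bug was reported on."""
    return parse_calver(installed) < REPORTED_BROKEN


print(predates_report("2024.10.7"))  # True: the rollback pin from this thread
print(predates_report("2024.11.2"))  # False
print("2024.10.7" < "2024.9.1")      # True: string order gets the calendar wrong
```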

@R4ZZ3

R4ZZ3 commented Nov 6, 2024

> For anyone stuck, rolling back using pip install --upgrade --no-cache-dir --no-deps unsloth==2024.10.7 unsloth-zoo==2024.10.5 worked for me

This worked for me. Thanks!

@danielhanchen
Contributor

@R4ZZ3 @remiconnesson @daegonYu Apologies, Gemma broke. I just added a fix for it; sorry again!!

On Colab / Kaggle, just delete the runtime and start again. On local machines, please update Unsloth without changing dependencies:

pip uninstall unsloth unsloth-zoo -y && pip install --upgrade --no-cache-dir --no-deps unsloth unsloth-zoo
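In Jupyter, a bare pip on the shell PATH can belong to a different environment than the running kernel. A common safeguard is to invoke pip through the kernel's own interpreter (sys.executable -m pip). Below is a minimal sketch of the same uninstall-then-reinstall sequence; the function names pip_cmd and reinstall_unsloth are hypothetical, and nothing is executed unless run=True.

```python
import subprocess
import sys


def pip_cmd(*args: str) -> list:
    """Build a pip command bound to the current interpreter (sys.executable)."""
    return [sys.executable, "-m", "pip", *args]


def reinstall_unsloth(run: bool = False) -> list:
    """Uninstall, then reinstall unsloth and unsloth-zoo without touching dependencies."""
    commands = [
        pip_cmd("uninstall", "-y", "unsloth", "unsloth-zoo"),
        pip_cmd("install", "--upgrade", "--no-cache-dir", "--no-deps",
                "unsloth", "unsloth-zoo"),
    ]
    if run:
        for cmd in commands:
            # check=True raises on the first failing step, so the install
            # does not run if the uninstall fails (like `&&` in the shell).
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    for cmd in reinstall_unsloth(run=False):
        print(" ".join(cmd))
```

After running it with run=True, restart the kernel so the freshly installed code is what gets imported.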

@CurtiusSimplus

RuntimeError: Expected out tensor to have dtype c10::BFloat16, but got float instead

@daegonYu
Author

daegonYu commented Nov 7, 2024

I haven't run evaluation yet, but training is working now.

Discussion about evaluation is still ongoing in #853.

Thank you!

danielhanchen changed the title from "[FIXED] AssertionError('initial value for logits error" to "AssertionError('initial value for logits error [FIXED]" on Nov 14, 2024
@danielhanchen
Contributor

I've now uploaded the changes to PyPI, so pip install --upgrade --no-cache-dir --no-deps unsloth gets the latest fixes!
