Note: I'll work on seeing if this reproduces with a non-torchchat example.
While working on migrating torchchat's WeightOnlyInt8Quantizer to AO's quantize_(model, int8_weight_only()) API, I ran into an issue where activation values go to NaN after a few layers if the model's dtype is initially float16. This occurs across multiple platforms (tested on MPS, Mac CPU, and x86 CPU), so it does not appear to be hardware-specific.
Interestingly, setting the model dtype to bfloat16 avoids the error.
To repro, you can check out this PR with the migration in torchchat and run a model using:
python3 torchchat.py generate llama3.1 --quantize '{"linear:int8": {"groupsize": 256}, "executor":{"accelerator":"mps"}}' --prompt "King in the castle, king in the castle, i have a chair." --num-samples 3 --dtype float16
You'll notice the model just outputs "!" tokens, which correspond to NaN values. If you add a debug hook to the model, you can see that some values in the intermediate tensors get very close to 0 just before the NaN values appear.
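For anyone who wants to reproduce the debugging step, here is a minimal sketch of such a hook; the helper name and the near-zero threshold are arbitrary choices, not the exact hook used in this investigation:

```python
import torch

def attach_nan_hooks(model, tiny=1e-6):
    """Register forward hooks that flag NaN or near-zero module outputs."""
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and output.numel() > 0:
                if torch.isnan(output).any():
                    print(f"NaN in output of {name}")
                else:
                    smallest = output.abs().min().item()
                    if smallest < tiny:
                        print(f"near-zero values in output of {name}: min |x| = {smallest:.3e}")
        return hook

    # one hook per named submodule; keep the handles so they can be removed later
    return [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]

# usage: handles = attach_nan_hooks(model); model(x); [h.remove() for h in handles]
```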
I can confirm this. I also noticed it the other day but did not dig deeper.
If the base weights are in float16, int8_weight_only completely breaks the outputs. If the base weights are in bfloat16, the output is as expected in inference-only mode.
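For reference, a minimal standalone (non-torchchat) sketch of this comparison; the toy model, layer sizes, and seed are arbitrary assumptions rather than anything from the original report:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

for dtype in (torch.float16, torch.bfloat16):
    torch.manual_seed(0)
    # arbitrary toy model; the original report uses llama3.1 via torchchat
    model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).to(dtype)
    x = torch.randn(8, 4096, dtype=dtype)
    with torch.no_grad():
        ref = model(x)                        # unquantized reference output
        quantize_(model, int8_weight_only())  # the AO API being migrated to
        out = model(x)
    print(dtype,
          "NaNs:", torch.isnan(out).any().item(),
          "max |out - ref|:", (out - ref).abs().max().item())
```

If the reported behavior carries over to this toy setup, the float16 run should show NaNs or a large divergence from the unquantized reference, while the bfloat16 run should stay close to it.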
vmpuri changed the title from "[Needs more investigation] int8_weight_only via quantize_() API results in NaN values across multiple CPU architectures" to "[Needs more investigation] int8_weight_only via quantize_() API on torch.float16 models results in NaN values across multiple CPU architectures" on Feb 5, 2025.