[BUG]: Pytest with a specific config failed after PR #5868 #5949

Open
1 task done
GuangyaoZhang opened this issue Jul 29, 2024 · 0 comments
Labels
bug Something isn't working shardformer

@GuangyaoZhang
Contributor

Is there an existing issue for this bug?

  • I have searched the existing issues

🐛 Describe the bug

In the main repo, test_shard_llama fails with the following two configs:

{'tp_size': 2,
 'pp_size': 2,
 'sp_size': 2,
 'num_microbatches': 2,
 'enable_sequence_parallelism': True,
 'sequence_parallelism_mode': 'ring',
 'enable_flash_attention': True,
 'zero_stage': 1,
 'precision': 'fp16',
 'initial_scale': 1}

{'tp_size': 2,
 'sp_size': 2,
 'pp_size': 2,
 'num_microbatches': 2,
 'enable_sequence_parallelism': True,
 'sequence_parallelism_mode': 'split_gather',
 'enable_flash_attention': False,
 'precision': 'fp16',
 'initial_scale': 1}
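Both failing configs combine sp_size=2 with tp_size=2 and pp_size=2, but they use different sequence-parallelism modes and flash-attention settings. A quick way to see which settings they have in common is the standalone sketch below; split_settings is a hypothetical helper written for illustration, not part of ColossalAI or its test suite.

```python
# The two failing configurations, copied from the report above.
failing_configs = [
    {
        "tp_size": 2,
        "pp_size": 2,
        "sp_size": 2,
        "num_microbatches": 2,
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "ring",
        "enable_flash_attention": True,
        "zero_stage": 1,
        "precision": "fp16",
        "initial_scale": 1,
    },
    {
        "tp_size": 2,
        "sp_size": 2,
        "pp_size": 2,
        "num_microbatches": 2,
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "split_gather",
        "enable_flash_attention": False,
        "precision": "fp16",
        "initial_scale": 1,
    },
]


def split_settings(configs):
    """Hypothetical helper: split keys into (shared, differing).

    A key is shared if every config sets it to the same value; otherwise it
    is differing, and a missing key is reported as "<unset>".
    """
    all_keys = set().union(*configs)  # iterating a dict yields its keys
    shared, differing = {}, {}
    for key in sorted(all_keys):
        values = [cfg.get(key, "<unset>") for cfg in configs]
        if all(v == values[0] for v in values):
            shared[key] = values[0]
        else:
            differing[key] = values
    return shared, differing


shared, differing = split_settings(failing_configs)
print("shared:", shared)
print("differing:", differing)
```

The shared keys only narrow down where to look; they do not by themselves show which setting interacts with the change from PR #5868.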

The failure message is:

E         File "/home/nvme-share/home/zhangguangyao/ColossalAI/colossalai/shardformer/modeling/llama.py", line 530, in forward
E           query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
E         File "/home/nvme-share/home/zhangguangyao/hf_transformers/src/transformers/models/llama/modeling_llama.py", line 206, in apply_rotary_pos_emb
E           q_embed = (q * cos) + (rotate_half(q) * sin)
E       RuntimeError: The size of tensor a (16) must match the size of tensor b (8) at non-singleton dimension 2
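Assuming query_states uses the usual [batch, num_heads, seq_len, head_dim] layout, dimension 2 is the sequence dimension, so q appears to carry 16 positions while cos carries 8. Since 16 = sp_size * 8 in these configs, one possible (unconfirmed) reading is that the query holds the full sequence while cos/sin were computed for a single sequence-parallel shard. The pure-Python sketch below reproduces the same broadcasting error from shapes alone. broadcast_shapes is a hypothetical helper that mimics PyTorch broadcasting rules; the batch, head and head_dim sizes are made up; and cos is assumed to be unsqueezed to [batch, 1, seq_len, head_dim] before the multiplication, as transformers' apply_rotary_pos_emb does with its default unsqueeze_dim=1.

```python
def broadcast_shapes(a, b):
    """Hypothetical helper: broadcast two shapes with PyTorch-style rules.

    Raises ValueError with a PyTorch-like message on a size mismatch.
    """
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have ndim dimensions.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    result = []
    for i, (dim_a, dim_b) in enumerate(zip(a, b)):
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(
                f"The size of tensor a ({dim_a}) must match the size of tensor b "
                f"({dim_b}) at non-singleton dimension {i}"
            )
    return tuple(result)


# Made-up sizes; only the sequence lengths 16 and 8 come from the error message.
BATCH, NUM_HEADS, HEAD_DIM = 1, 16, 128
SP_SIZE = 2
LOCAL_SEQ_LEN = 8
FULL_SEQ_LEN = SP_SIZE * LOCAL_SEQ_LEN  # 16

q_shape = (BATCH, NUM_HEADS, FULL_SEQ_LEN, HEAD_DIM)
cos_shape = (BATCH, 1, LOCAL_SEQ_LEN, HEAD_DIM)  # cos after unsqueeze(1), assumed

try:
    broadcast_shapes(q_shape, cos_shape)  # mirrors `q * cos`
except ValueError as err:
    print(err)

# If cos covered the full sequence, the product would broadcast cleanly.
print(broadcast_shapes(q_shape, (BATCH, 1, FULL_SEQ_LEN, HEAD_DIM)))
```

This only shows that the reported sizes are consistent with a full-versus-shard sequence-length mismatch; confirming it would require checking the shapes of query_states and cos in the actual failing run.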

I have found that this failure was introduced when PR #5868 was merged. Please take a look.

Environment

No response
