
[V1][Spec Decode] Ngram Spec Decode #12193

Merged: 92 commits into vllm-project:main, Feb 16, 2025
Conversation

@LiuXiaoxuanPKU (Collaborator) commented Jan 19, 2025

This PR tries to add ngram spec decode to V1. Design doc: here.
Major changes:

  1. Since we only implement ngram spec decode, we did not add a separate scheduler for running the drafting method. Instead, we always check whether we need to do the ngram lookup before calling the scheduler.
  2. Add a new field _spec_token_ids in Request to track speculated tokens (a rough sketch of this flow follows this list).
  3. Changes to model_runner:
    3.1 Change the _prepare_input to also return the logits of speculated tokens.
    3.2 Change the _prepare_input to add speculated tokens as input tokens.
    3.3 Change the execute_model to generate multiple tokens per call. Concretely, it will add more than one token to input_batch and req_state.
  4. We only perform spec decode for requests in the running queue.
  5. We only support greedy decoding for now.
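
As a rough illustration of changes 1, 2, and 4 above, here is a minimal sketch (not the actual vLLM code; the class and function names are hypothetical stand-ins) of a request carrying speculated tokens and an engine step that runs the ngram proposal check before calling the scheduler:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Request:
    req_id: str
    output_token_ids: List[int] = field(default_factory=list)
    # Hypothetical field mirroring change 2 (the PR calls it _spec_token_ids):
    # tokens speculated for this request's next step.
    spec_token_ids: List[int] = field(default_factory=list)


def propose_ngram_tokens(req: Request, num_spec_tokens: int) -> List[int]:
    # Placeholder proposer; the PR also appends dummy tokens here for now.
    return [0] * num_spec_tokens


def engine_step(running: List[Request], num_spec_tokens: int = 3) -> None:
    # Change 1: run the ngram proposal check before invoking the scheduler,
    # and, per change 4, only for requests already in the running queue.
    for req in running:
        req.spec_token_ids = propose_ngram_tokens(req, num_spec_tokens)
    # scheduler.schedule() and model_runner.execute_model() would follow here;
    # per change 3, execute_model verifies the speculated tokens and may append
    # more than one token to input_batch and req_state in a single call.
```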

What is missing

  • Change scheduling to only propose tokens for decoding requests.
  • Stop checking for spec decode, where multiple tokens are generated in a single step.
  • For the ngram lookup logic: currently I just append dummy tokens directly instead of performing the lookup. We can move v0's lookup logic here (a rough sketch of such a lookup follows this list).
  • Check the correctness of this PR with chunked prefill (we only perform spec decode in the decoding phase).
  • More end-to-end tests & style.
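
For reference, a rough sketch of what an ngram (prompt-lookup) draft step could look like, in the spirit of V0's logic. This is an illustrative assumption, not the code this PR lands (the PR currently appends dummy tokens):

```python
from typing import List, Optional


def ngram_lookup(token_ids: List[int], n: int = 3, k: int = 4) -> Optional[List[int]]:
    """Propose up to k draft tokens by matching the last n tokens earlier in the context."""
    if len(token_ids) <= n:
        return None
    tail = token_ids[-n:]
    # Scan backwards so the most recent match wins; exclude the trailing ngram itself.
    for start in range(len(token_ids) - n - 1, -1, -1):
        if token_ids[start:start + n] == tail:
            continuation = token_ids[start + n:start + n + k]
            return continuation if continuation else None
    return None
```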

Tasks out of the scope of this PR

  1. Optimize the performance of ngram lookup.
  2. Support non-greedy decoding.
  3. Add other spec decode methods.

[Update]
I will move the following two features into follow-up PRs:

  1. Guarantee the correctness of prefix caching + spec decode, because it involves changing the behavior of the kv cache manager (@comaniac).
  2. Change the scheduling policy to guarantee that at least one token is scheduled for each request. This is separated out because it touches the scheduling code and needs more careful thought/testing.

Minor: there is a minimal example/test in tests/v1/e2e/test_basic_specdecode.py. You can check it for the current usage and verify correctness with pytest -s tests/v1/e2e/test_basic_specdecode.py.
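
For illustration only, here is a hypothetical skeleton of such a correctness test (not the actual contents of tests/v1/e2e/test_basic_specdecode.py; the speculative_model="[ngram]" style arguments are assumed to carry over from V0 and may differ in V1): greedy outputs with ngram spec decode enabled should match greedy outputs without it.

```python
import pytest
from vllm import LLM, SamplingParams

MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # assumes VLLM_USE_V1=1 is set in the environment


@pytest.mark.parametrize("prompt", ["The future of AI is"])
def test_ngram_matches_greedy_baseline(prompt):
    params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy only for now

    ref = LLM(model=MODEL).generate([prompt], params)[0].outputs[0].text

    # Assumed V0-style ngram arguments; the real V1 config may differ.
    spec_llm = LLM(
        model=MODEL,
        speculative_model="[ngram]",
        num_speculative_tokens=3,
        ngram_prompt_lookup_max=4,
    )
    out = spec_llm.generate([prompt], params)[0].outputs[0].text

    assert ref == out
```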

Follow-up work:

  1. benchmark the flashinfer rejection sampler
  2. ngram kernel
  3. proposer stop checking
  4. KV cache based draft model


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀


mergify bot commented Jan 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LiuXiaoxuanPKU.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 19, 2025
@mergify mergify bot removed the needs-rebase label Jan 20, 2025
@LiuXiaoxuanPKU (Collaborator, Author) commented Feb 12, 2025

Some benchmark results:
Model: meta-llama/Meta-Llama-3-8B-Instruct
Hardware: 1xH100
Number of requests: 500 for QPS 1 and 10; 1000 for QPS 20
export VLLM_USE_V1=1

Median TTFT/TPOT
(screenshot: median TTFT/TPOT comparison chart)

The median TTFT/TPOT looks OK when comparing the ngram branch without SD against the main branch, but the branch does slow down main a bit. My guess is that this comes from the change for delayed CPU <-> GPU synchronization.

P99 TTFT/TPOT
(screenshot: P99 TTFT/TPOT comparison chart)

Detailed settings are in the doc.

@njhill (Member) left a comment


Thanks @LiuXiaoxuanPKU I think it's looking a lot better!

@WoosukKwon (Collaborator) left a comment


@LiuXiaoxuanPKU Thanks for the work! I think we are almost there. Left some comments, mostly regarding the code style.

Comment on lines 950 to 952
if seq_len >= req_state.num_tokens:
    # We don't rewind the generator state for requests now
    # because spec decode only supports greedy decoding for now.
Collaborator:

Could you please explain more?

@LiuXiaoxuanPKU (Collaborator, Author) commented Feb 15, 2025

It's not related; deleted.
But my main concern here is how spec decode handles seeding, since rejected tokens might affect the generator state. I need more time to think about it and refer to V0.
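
For context, a tiny illustration (not from this PR) of why rejected tokens interact with per-request seeding: once the sampler draws from a seeded generator for speculated positions, the generator state has advanced even if those tokens are later rejected, so a deterministic run would need to snapshot and restore the state.

```python
import torch

gen = torch.Generator().manual_seed(42)
state = gen.get_state()                      # snapshot before sampling speculative positions

draft_draws = torch.rand(4, generator=gen)   # RNG advances for 4 speculated positions
# ... suppose some of those speculated tokens are rejected ...
gen.set_state(state)                         # rewinding the state restores determinism
replayed = torch.rand(4, generator=gen)

assert torch.equal(draft_draws, replayed)
```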

Collaborator:

Thanks for the fix!

@WoosukKwon added the ready label Feb 15, 2025
@WoosukKwon (Collaborator)

@LiuXiaoxuanPKU Could you take a look at the failed tests?

I'll approve the PR once the tests are green!


mergify bot commented Feb 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LiuXiaoxuanPKU.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 15, 2025
@mergify mergify bot removed the needs-rebase label Feb 15, 2025
@WoosukKwon (Collaborator) left a comment


@LiuXiaoxuanPKU Thanks for the great work! It's been quite a long journey. Really appreciate all the work!

@WoosukKwon (Collaborator)

@LiuXiaoxuanPKU Just wanted to double check. Is the PR ready for merge?

@WoosukKwon merged commit 80f63a3 into vllm-project:main Feb 16, 2025
40 of 42 checks passed
Labels: ready, v1