diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py
index 81cabd32..a005fa25 100644
--- a/mii/batching/data_classes.py
+++ b/mii/batching/data_classes.py
@@ -134,6 +134,12 @@ def is_done(self, is_done: bool) -> None:
     def generated_tokens(self) -> List[torch.Tensor]:
         return self._generated_tokens
 
+    @property
+    def all_tokens(self) -> torch.Tensor:
+        return torch.cat([self.prompt_tokens] +
+                         [t.unsqueeze(0) for t in self.generated_tokens],
+                         dim=0)
+
     @property
     def finish_reason(self) -> GenerationFinishReason:
         return self._finish_reason
diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py
index a4b49ff3..2724da7b 100644
--- a/mii/batching/ragged_batching.py
+++ b/mii/batching/ragged_batching.py
@@ -82,6 +82,8 @@ def __init__(self, inference_engine, tokenizer, model_config):
         self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
         self.socket.setsockopt(zmq.RCVTIMEO, ZMQ_RECV_TIMEOUT)
 
+        self.enable_prefix_cache = self.inference_engine._config.enable_prefix_cache
+
     @cached_property
     def local_rank(self) -> int:
         return get_accelerator().current_device()
@@ -122,6 +124,9 @@ def generate(self) -> None:
        # 5. Schedule requests while we wait for the forward pass to finish
        self._reset_scheduler_bookkeeping()
 
+        for r in running_requests:
+            self.inference_engine.update_prefix_cache(r.uid, r.all_tokens)
+
        # 6. Accumulate generated tokens, check completion, and generate output
        for r in running_requests.last_in_prompt:
            r.accumulate_generated_token()
@@ -274,24 +279,38 @@ def _schedule_prompts(self, requests: List[Request]) -> None:
 
             max_blocks = free_blocks - self.scheduled_req_blocks
 
-            if len(r.input_tokens) > 1:
+            input_tokens = r.input_tokens
+            if r.seq_length == 0:
+                cache_hit_length, block_ids = self.inference_engine.lookup_cache(r.input_tokens)
+                input_tokens = input_tokens[cache_hit_length:]
+            else:
+                cache_hit_length = 0
+                block_ids = []
+
+            if len(input_tokens) > 1:
                 # When the KV cache is out of capacity, we release KV cache blocks for a request.
                 # However, we can immediately schedule the request again if we split the request.
                 # So we make sure that we have capacity for the entire prompt (+tokens already generated).
-                req_tokens, _ = self.inference_engine.query(r.uid, len(r.input_tokens), max_blocks)
-                if req_tokens < len(r.input_tokens):
+                req_tokens, _ = self.inference_engine.query(r.uid, len(input_tokens), max_blocks)
+                if req_tokens < len(input_tokens):
                     break
 
-            req_tokens = min(len(r.input_tokens), max_batch_size)
+            req_tokens = min(len(input_tokens), max_batch_size)
             req_tokens, req_blocks = self.inference_engine.query(r.uid, req_tokens, max_blocks)
 
             if req_tokens <= 0:
                 continue
 
             # Decompose the prompt to fit to the max ragged batch size
-            decomposed = req_tokens < len(r.input_tokens)
-            remaining_tokens = r.input_tokens[req_tokens:]
-            r.input_tokens = r.input_tokens[:req_tokens]
+            if cache_hit_length > 0:
+                self.inference_engine.setup_cached_sequence(r.uid,
+                                                            cache_hit_length,
+                                                            block_ids)
+                r.seq_length = r.seq_length + cache_hit_length
+
+            decomposed = req_tokens < len(input_tokens)
+            remaining_tokens = input_tokens[req_tokens:]
+            r.input_tokens = input_tokens[:req_tokens]
             r.last_in_prompt = not decomposed
 
             # Schedule the request
@@ -571,11 +590,15 @@ def __call__(self,
 
         if self.is_rank_0:
             # Rank 0 runs generate() until all responses are returned
-            while uids_running:
+            while uids_running \
+                    or not self.request_queue.empty() \
+                    or self.scheduled_requests.requests_to_flush.uids:
                 self.generate()
                 while not self.result_queues[self.tid].empty():
                     uid, response = self._get_response()
                     outputs.append(response)
+                    # We can't directly call flush because the flush request is broadcasted
+                    # to other ranks after it is taken from the queue
                     self._queue_flush_request(uid)
                     uids_complete_order.append(uid)
                     uids_running.remove(uid)
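
Reviewer note on the scheduling change above: on a request's first scheduling pass (r.seq_length == 0) the engine is asked for the longest cached prefix of the prompt, the matching KV blocks are attached via setup_cached_sequence, and only the uncached suffix is fed through the ragged batch; after each forward pass the full sequence is re-registered with update_prefix_cache. Below is a minimal, self-contained sketch of that flow under stated assumptions: ToyEngine and its dict-based cache are hypothetical stand-ins rather than the engine implementation, block ids are faked, and setup_cached_sequence is omitted; only the lookup/slice bookkeeping mirrors the patch.

from typing import Dict, List, Tuple


class ToyEngine:
    """Hypothetical stand-in for the inference engine's prefix-cache API."""
    def __init__(self) -> None:
        # Maps a cached token prefix to the KV block ids that hold it.
        self._cache: Dict[Tuple[int, ...], List[int]] = {}

    def update_prefix_cache(self, uid: int, all_tokens: List[int]) -> None:
        # Called after each forward pass; block ids are faked as one block per 4 tokens.
        self._cache[tuple(all_tokens)] = list(range((len(all_tokens) + 3) // 4))

    def lookup_cache(self, input_tokens: List[int]) -> Tuple[int, List[int]]:
        # Return the length of the longest cached prefix of input_tokens and its block ids.
        best_len, best_blocks = 0, []
        for prefix, blocks in self._cache.items():
            n = len(prefix)
            if n > best_len and n <= len(input_tokens) and tuple(input_tokens[:n]) == prefix:
                best_len, best_blocks = n, blocks
        return best_len, best_blocks


def schedule_prompt(engine: ToyEngine, input_tokens: List[int], seq_length: int):
    # Mirrors the patch: only a request's first pass (seq_length == 0) consults the cache,
    # and only the uncached suffix goes into the ragged batch.
    if seq_length == 0:
        cache_hit_length, block_ids = engine.lookup_cache(input_tokens)
        input_tokens = input_tokens[cache_hit_length:]
    else:
        cache_hit_length, block_ids = 0, []
    return cache_hit_length, block_ids, input_tokens


engine = ToyEngine()
engine.update_prefix_cache(uid=0, all_tokens=[1, 2, 3, 4, 5, 6])  # a finished request
hit, blocks, suffix = schedule_prompt(engine, [1, 2, 3, 4, 5, 6, 7, 8], seq_length=0)
print(hit, blocks, suffix)  # 6 [0, 1] [7, 8]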