From cf638be99803682933cb4040850765d46832ee78 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 14 Sep 2022 13:51:29 -0700
Subject: [PATCH] only override forward if using cuda-graph (#2291)

---
 .github/workflows/nv-inference.yml     |  2 --
 deepspeed/inference/engine.py          | 46 +++++---------------------
 tests/unit/inference/test_inference.py | 18 ++++------
 3 files changed, 16 insertions(+), 50 deletions(-)

diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml
index dc00682edd46..d2551b0f79d2 100644
--- a/.github/workflows/nv-inference.yml
+++ b/.github/workflows/nv-inference.yml
@@ -40,8 +40,6 @@ jobs:
         run: |
           git clone https://github.com/huggingface/transformers
           cd transformers
-          # if needed switch to the last known good SHA until transformers@master is fixed
-          git checkout v4.21.2
           git rev-parse --short HEAD
           pip uninstall --yes transformers
           pip install .
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index a4b57a05f37b..81566e7165c5 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -162,10 +162,7 @@ def __init__(self,
             torch.cuda.set_rng_state(_rng_state.cpu())
 
         if self.mp_world_size > 1:
-            self.model_orig_fwd = self.module.forward
-            self.module.forward = self.forward
-        else:
-            self.module.register_forward_pre_hook(self._pre_forward_hook)
+            assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
 
     def _get_model_config_generate(self, config):
         self.config = getattr(self.module, 'config', None) if config is None else config
@@ -475,14 +472,6 @@ def _convert_to_dtype(self):
         elif self.dtype == torch.float:
             self.module.float()
 
-    def _pre_forward_hook(self, module, *inputs, **kwargs):
-        for input in inputs:
-            if torch.is_tensor(input):
-                input = input.to(torch.cuda.current_device())
-        for k in kwargs:
-            if torch.is_tensor(kwargs[k]):
-                kwargs[k] = kwargs[k].to(torch.cuda.current_device())
-
     def _create_cuda_graph(self, *inputs, **kwargs):
         # warmup to create the workspace and cublas handle
         cuda_stream = torch.cuda.Stream()
@@ -519,30 +508,13 @@ def forward(self, *inputs, **kwargs):
             *inputs: Variable length input list
             **kwargs: variable length keyword arguments
         """
-
-        if self.mp_world_size > 1:
-            if self.mpu is None:
-                for input in inputs:
-                    if torch.is_tensor(input):
-                        input = input.to(torch.cuda.current_device())
-                        if not input.is_contiguous():
-                            input = input.contiguous()
-                        dist.broadcast(input, 0)
-                for k in kwargs:
-                    if torch.is_tensor(kwargs[k]):
-                        kwargs[k] = kwargs[k].to(torch.cuda.current_device())
-                        if not kwargs[k].is_contiguous():
-                            kwargs[k] = kwargs[k].contiguous()
-                        dist.broadcast(kwargs[k], 0)
-            outputs = self.model_orig_fwd(*inputs, **kwargs)
-        else:
-            if self.enable_cuda_graph:
-                if self.cuda_graph_created:
-                    outputs = self._graph_replay(*inputs, **kwargs)
-                else:
-                    self._create_cuda_graph(*inputs, **kwargs)
-                    outputs = self._graph_replay(*inputs, **kwargs)
+        if self.enable_cuda_graph:
+            if self.cuda_graph_created:
+                outputs = self._graph_replay(*inputs, **kwargs)
             else:
-                outputs = self.module(*inputs, **kwargs)
-            #outputs = self.module(*inputs, **kwargs)
+                self._create_cuda_graph(*inputs, **kwargs)
+                outputs = self._graph_replay(*inputs, **kwargs)
+        else:
+            outputs = self.module(*inputs, **kwargs)
+
         return outputs
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 1b1efdc595fe..04e9320accc8 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -292,13 +292,13 @@ def test(
 
 @pytest.mark.seq_inference
 @pytest.mark.parametrize("model_w_task",
-                         [("gpt2",
+                         [("EleutherAI/gpt-neo-1.3B",
                            "text-generation"),
                           ("EleutherAI/gpt-neox-20b",
                            "text-generation"),
                           ("bigscience/bloom-3b",
                            "text-generation")],
-                         ids=["gpt2",
+                         ids=["gpt-neo",
                               "gpt-neox",
                               "bloom"])
 class TestMPSize(DistributedTest):
@@ -308,7 +308,6 @@ def test(
         self,
         model_w_task,
         dtype,
-        enable_cuda_graph,
         query,
         inf_kwargs,
         assert_fn,
@@ -325,14 +324,11 @@ def test(
         pipe = pipeline(task, model=model, device=-1, framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(
-            pipe.model,
-            mp_size=self.world_size,
-            dtype=dtype,
-            replace_method="auto",
-            replace_with_kernel_inject=True,
-            enable_cuda_graph=enable_cuda_graph,
-        )
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=self.world_size,
+                                              dtype=dtype,
+                                              replace_method="auto",
+                                              replace_with_kernel_inject=True)
         # Switch device to GPU so that input tensors are not on CPU
         pipe.device = torch.device(f"cuda:{local_rank}")
         ds_output = pipe(query, **inf_kwargs)
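
Note: after this patch, InferenceEngine only overrides the forward path when enable_cuda_graph is set, and it asserts that CUDA graphs are not combined with model parallelism (mp_size > 1). A minimal usage sketch under that assumption follows; the model name, prompt, and generation kwargs are placeholders and not part of this patch, while the init_inference arguments mirror the ones used in the test above.

    import deepspeed
    import torch
    from transformers import pipeline

    # Single-GPU path: kernel injection plus CUDA graph capture/replay handled
    # inside InferenceEngine.forward (_create_cuda_graph / _graph_replay).
    pipe = pipeline("text-generation",
                    model="EleutherAI/gpt-neo-1.3B",
                    device=0,
                    framework="pt")
    pipe.model = deepspeed.init_inference(pipe.model,
                                          mp_size=1,
                                          dtype=torch.float16,
                                          replace_method="auto",
                                          replace_with_kernel_inject=True,
                                          enable_cuda_graph=True)
    print(pipe("DeepSpeed is", do_sample=False, max_length=30))

    # Model-parallel path (mp_size > 1): leave enable_cuda_graph at its default
    # (False); with this patch the engine asserts if both are requested together.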