
Commit

only override forward if using cuda-graph (#2291)
jeffra authored Sep 14, 2022
1 parent 95d1151 commit cf638be
Showing 3 changed files with 16 additions and 50 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/nv-inference.yml
@@ -40,8 +40,6 @@ jobs:
run: |
git clone https://github.com/huggingface/transformers
cd transformers
- # if needed switch to the last known good SHA until transformers@master is fixed
- git checkout v4.21.2
git rev-parse --short HEAD
pip uninstall --yes transformers
pip install .
46 changes: 9 additions & 37 deletions deepspeed/inference/engine.py
@@ -162,10 +162,7 @@ def __init__(self,
torch.cuda.set_rng_state(_rng_state.cpu())

if self.mp_world_size > 1:
-            self.model_orig_fwd = self.module.forward
-            self.module.forward = self.forward
-        else:
-            self.module.register_forward_pre_hook(self._pre_forward_hook)
+            assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"

def _get_model_config_generate(self, config):
self.config = getattr(self.module, 'config', None) if config is None else config
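
With this hunk, the model-parallel path in __init__ no longer swaps out module.forward or registers a pre-forward hook; it simply rejects the unsupported combination of model parallelism and CUDA graphs at construction time. A minimal sketch of the resulting guard (the class below is a simplified, hypothetical stand-in for InferenceEngine; only the attribute names and the assertion text come from the diff):

class _EngineInitSketch:
    """Simplified stand-in for deepspeed.inference.engine.InferenceEngine.__init__."""

    def __init__(self, module, mp_world_size=1, enable_cuda_graph=False):
        self.module = module
        self.mp_world_size = mp_world_size
        self.enable_cuda_graph = enable_cuda_graph
        # Post-commit behavior: no forward override here, only a guard.
        if self.mp_world_size > 1:
            assert not self.enable_cuda_graph, \
                "Cuda graph is not supported for model parallelism"

# The combination now fails fast when the engine is built.
try:
    _EngineInitSketch(module=None, mp_world_size=2, enable_cuda_graph=True)
except AssertionError as err:
    print(f"rejected: {err}")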
@@ -475,14 +472,6 @@ def _convert_to_dtype(self):
elif self.dtype == torch.float:
self.module.float()

-    def _pre_forward_hook(self, module, *inputs, **kwargs):
-        for input in inputs:
-            if torch.is_tensor(input):
-                input = input.to(torch.cuda.current_device())
-        for k in kwargs:
-            if torch.is_tensor(kwargs[k]):
-                kwargs[k] = kwargs[k].to(torch.cuda.current_device())
-
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
@@ -519,30 +508,13 @@ def forward(self, *inputs, **kwargs):
*inputs: Variable length input list
**kwargs: variable length keyword arguments
"""

-        if self.mp_world_size > 1:
-            if self.mpu is None:
-                for input in inputs:
-                    if torch.is_tensor(input):
-                        input = input.to(torch.cuda.current_device())
-                        if not input.is_contiguous():
-                            input = input.contiguous()
-                        dist.broadcast(input, 0)
-                for k in kwargs:
-                    if torch.is_tensor(kwargs[k]):
-                        kwargs[k] = kwargs[k].to(torch.cuda.current_device())
-                        if not kwargs[k].is_contiguous():
-                            kwargs[k] = kwargs[k].contiguous()
-                        dist.broadcast(kwargs[k], 0)
-            outputs = self.model_orig_fwd(*inputs, **kwargs)
-        else:
-            if self.enable_cuda_graph:
-                if self.cuda_graph_created:
-                    outputs = self._graph_replay(*inputs, **kwargs)
-                else:
-                    self._create_cuda_graph(*inputs, **kwargs)
-                    outputs = self._graph_replay(*inputs, **kwargs)
-            else:
-                outputs = self.module(*inputs, **kwargs)
-        #outputs = self.module(*inputs, **kwargs)
+        if self.enable_cuda_graph:
+            if self.cuda_graph_created:
+                outputs = self._graph_replay(*inputs, **kwargs)
+            else:
+                self._create_cuda_graph(*inputs, **kwargs)
+                outputs = self._graph_replay(*inputs, **kwargs)
+        else:
+            outputs = self.module(*inputs, **kwargs)

return outputs
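
After this hunk, forward no longer contains the tensor-broadcast path for model parallelism: it either replays a captured CUDA graph or calls the wrapped module directly. The bodies of _create_cuda_graph and _graph_replay are elided above; the sketch below is a generic illustration of that capture-and-replay pattern using PyTorch's public torch.cuda.CUDAGraph API (it assumes a CUDA-enabled PyTorch >= 1.10 and a GPU, and is not DeepSpeed's exact implementation):

import torch

class GraphedModuleSketch:
    """Hypothetical wrapper: capture the module's forward once, then replay it."""

    def __init__(self, module, enable_cuda_graph=True):
        self.module = module
        self.enable_cuda_graph = enable_cuda_graph
        self.cuda_graph_created = False

    def _create_cuda_graph(self, inputs):
        # Warm up on a side stream so cuBLAS/cuDNN workspaces exist before capture.
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream), torch.no_grad():
            for _ in range(3):
                self.module(inputs)
        torch.cuda.current_stream().wait_stream(stream)

        # Record a single forward pass into a CUDA graph with static buffers.
        self.static_inputs = inputs.clone()
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph), torch.no_grad():
            self.static_outputs = self.module(self.static_inputs)
        self.cuda_graph_created = True

    def _graph_replay(self, inputs):
        # Copy fresh data into the captured input buffer and replay the recorded kernels.
        self.static_inputs.copy_(inputs)
        self.graph.replay()
        return self.static_outputs

    def forward(self, inputs):
        # Mirrors the post-commit dispatch: replay a captured graph when enabled,
        # otherwise fall through to the wrapped module.
        if self.enable_cuda_graph:
            if not self.cuda_graph_created:
                self._create_cuda_graph(inputs)
            return self._graph_replay(inputs)
        return self.module(inputs)

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda().eval()
    wrapped = GraphedModuleSketch(model)
    out = wrapped.forward(torch.randn(4, 16, device="cuda"))
    print(out.shape)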
18 changes: 7 additions & 11 deletions tests/unit/inference/test_inference.py
@@ -292,13 +292,13 @@ def test(

@pytest.mark.seq_inference
@pytest.mark.parametrize("model_w_task",
[("gpt2",
[("EleutherAI/gpt-neo-1.3B",
"text-generation"),
("EleutherAI/gpt-neox-20b",
"text-generation"),
("bigscience/bloom-3b",
"text-generation")],
ids=["gpt2",
ids=["gpt-neo",
"gpt-neox",
"bloom"])
class TestMPSize(DistributedTest):
@@ -308,7 +308,6 @@ def test(
self,
model_w_task,
dtype,
- enable_cuda_graph,
query,
inf_kwargs,
assert_fn,
@@ -325,14 +324,11 @@ def test(
pipe = pipeline(task, model=model, device=-1, framework="pt")
bs_output = pipe(query, **inf_kwargs)

- pipe.model = deepspeed.init_inference(
- pipe.model,
- mp_size=self.world_size,
- dtype=dtype,
- replace_method="auto",
- replace_with_kernel_inject=True,
- enable_cuda_graph=enable_cuda_graph,
- )
+ pipe.model = deepspeed.init_inference(pipe.model,
+ mp_size=self.world_size,
+ dtype=dtype,
+ replace_method="auto",
+ replace_with_kernel_inject=True)
# Switch device to GPU so that input tensors are not on CPU
pipe.device = torch.device(f"cuda:{local_rank}")
ds_output = pipe(query, **inf_kwargs)
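The updated test also shows the intended calling pattern: the pipeline model is wrapped with deepspeed.init_inference without enable_cuda_graph, consistent with the new assertion in engine.py. A condensed usage sketch of that flow (model name, dtype, and prompt are illustrative; it assumes the script is launched with the deepspeed launcher so that two ranks exist for mp_size=2):

import os

import torch
import deepspeed
from transformers import pipeline

local_rank = int(os.getenv("LOCAL_RANK", "0"))

# Build the baseline Hugging Face pipeline on CPU, as the test does.
pipe = pipeline("text-generation",
                model="EleutherAI/gpt-neo-1.3B",
                device=-1,
                framework="pt")

# Wrap the model for model-parallel inference; enable_cuda_graph is no longer
# passed here because CUDA graphs are rejected when mp_size > 1.
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=2,
                                      dtype=torch.float16,
                                      replace_method="auto",
                                      replace_with_kernel_inject=True)

# Switch device to GPU so that input tensors are not on CPU.
pipe.device = torch.device(f"cuda:{local_rank}")
print(pipe("DeepSpeed is", do_sample=False, max_new_tokens=20))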
