Load balancing and multiple replicas (#147)
Co-authored-by: Michael Wyatt <[email protected]>
tohtana and mrwyattii authored Mar 1, 2023
1 parent 6116e98 commit 9ec2f12
Showing 16 changed files with 738 additions and 210 deletions.
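As a rough usage sketch (not part of this commit's diff), the replica options introduced here might be driven from the existing mii.deploy() entry point as shown below; the task, model, and deployment name are illustrative, while enable_load_balancing and replica_num are the new MIIConfig fields added in mii/config.py.

import mii

# Hypothetical deployment: two replicas of one model behind the new load-balancer proxy.
# "enable_load_balancing" and "replica_num" are added by this commit;
# "tensor_parallel" and "port_number" are pre-existing MIIConfig fields.
mii.deploy(task="text-generation",
           model="gpt2",
           deployment_name="gpt2-lb-deployment",
           mii_config={
               "tensor_parallel": 1,
               "port_number": 50050,
               "enable_load_balancing": True,
               "replica_num": 2
           })
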
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -42,4 +42,4 @@ repos:
rev: 4.0.1
hooks:
- id: flake8
args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401,mii/grpc_related/proto/modelresponse_pb2.py:F821']
args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401,mii/grpc_related/proto/modelresponse_pb2.py:F821,F401']
2 changes: 1 addition & 1 deletion mii/__init__.py
@@ -6,7 +6,7 @@
from .constants import DeploymentType, Tasks
from .aml_related.utils import aml_output_path

from .config import MIIConfig
from .config import MIIConfig, LoadBalancerConfig
from .grpc_related.proto import modelresponse_pb2_grpc

__version__ = "0.0.0"
123 changes: 71 additions & 52 deletions mii/client.py
@@ -4,102 +4,117 @@
import asyncio
import grpc
import mii
from mii.utils import get_num_gpus
from mii.grpc_related.proto import modelresponse_pb2_grpc
from mii.utils import get_task
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
from mii.constants import GRPC_MAX_MSG_SIZE
from mii.method_table import GRPC_METHOD_TABLE


def _get_deployment_info(deployment_name):
configs = mii.utils.import_score_file(deployment_name).configs
task = configs[mii.constants.TASK_NAME_KEY]
mii_configs_dict = configs[mii.constants.MII_CONFIGS_KEY]
mii_configs = mii.config.MIIConfig(**mii_configs_dict)

assert task is not None, "The task name should be set before calling init"
return task, mii_configs


def mii_query_handle(deployment_name):
"""Get a query handle for a local deployment:
mii/examples/local/gpt2-query-example.py
mii/examples/local/roberta-qa-query-example.py
Arguments:
deployment_name: Name of the deployment. Used as an identifier for posting queries for ``LOCAL`` deployment.
Returns:
query_handle: A query handle with a single method `.query(request_dictionary)` using which queries can be sent to the model.
"""
task_name, mii_configs = _get_deployment_info(deployment_name)
if mii_configs.enable_load_balancing:
return MIIClient(task_name, "localhost", mii_configs.port_number)
else:
return MIITensorParallelClient(
task_name,
"localhost",
[mii_configs.port_number + i for i in range(mii_configs.tensor_parallel)])

configs = mii.utils.import_score_file(deployment_name).configs

task = configs[mii.constants.TASK_NAME_KEY]
def create_channel(host, port):
return grpc.aio.insecure_channel(f'{host}:{port}',
options=[('grpc.max_send_message_length',
GRPC_MAX_MSG_SIZE),
('grpc.max_receive_message_length',
GRPC_MAX_MSG_SIZE)])

assert task is not None, "The task name should be set before calling init"

return mii.MIIClient(task, mii_configs=configs[mii.constants.MII_CONFIGS_KEY])
class MIIClient():
"""
Client to send queries to a single endpoint.
"""
def __init__(self, task_name, host, port):
self.asyncio_loop = asyncio.get_event_loop()
channel = create_channel(host, port)
self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
self.task = get_task(task_name)

async def _request_async_response(self, request_dict, **query_kwargs):
if self.task not in GRPC_METHOD_TABLE:
raise ValueError(f"unknown task: {self.task}")

class MIIClient():
'''Setup the client for the model'''
def __init__(self, task_name, mii_configs={}):
conversions = GRPC_METHOD_TABLE[self.task]
proto_request = conversions["pack_request_to_proto"](request_dict,
**query_kwargs)
proto_response = await getattr(self.stub, conversions["method"])(proto_request)
return conversions["unpack_response_from_proto"](
proto_response
) if "unpack_response_from_proto" in conversions else proto_response

mii_configs = mii.config.MIIConfig(**mii_configs)
def query(self, request_dict, **query_kwargs):
return self.asyncio_loop.run_until_complete(
self._request_async_response(request_dict,
**query_kwargs))

self.task = mii.utils.get_task(task_name)
async def terminate_async(self):
await self.stub.Terminate(
modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())

self.num_gpus = get_num_gpus(mii_configs)
assert self.num_gpus > 0, "GPU count must be greater than 0"
def terminate(self):
self.asyncio_loop.run_until_complete(self.terminate_async())

self.port_number = mii_configs.port_number

self.stubs = []
class MIITensorParallelClient():
"""
Client to send queries to multiple endpoints in parallel.
This is used to call multiple servers deployed for tensor parallelism.
"""
def __init__(self, task_name, host, ports):
self.task = get_task(task_name)
self.clients = [MIIClient(task_name, host, port) for port in ports]
self.asyncio_loop = asyncio.get_event_loop()
self._initialize_grpc_client()

def _initialize_grpc_client(self):
channels = []
for i in range(self.num_gpus):
channel = grpc.aio.insecure_channel(f'localhost:{self.port_number + i}',
options=[
('grpc.max_send_message_length',
GRPC_MAX_MSG_SIZE),
('grpc.max_receive_message_length',
GRPC_MAX_MSG_SIZE)
])
stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
channels.append(channel)
self.stubs.append(stub)

# runs task in parallel and return the result from the first task
async def _query_in_tensor_parallel(self, request_string, query_kwargs):
responses = []
for i in range(self.num_gpus):
for client in self.clients:
responses.append(
self.asyncio_loop.create_task(
self._request_async_response(i,
request_string,
query_kwargs)))
client._request_async_response(request_string,
**query_kwargs)))

await responses[0]

return responses[0]

async def _request_async_response(self, stub_id, request_dict, query_kwargs):
if self.task not in GRPC_METHOD_TABLE:
raise ValueError(f"unknown task: {self.task}")

conversions = GRPC_METHOD_TABLE[self.task]
proto_request = conversions["pack_request_to_proto"](request_dict,
**query_kwargs)
proto_response = await getattr(self.stubs[stub_id],
conversions["method"])(proto_request)
return conversions["unpack_response_from_proto"](
proto_response
) if "unpack_response_from_proto" in conversions else proto_response

def query(self, request_dict, **query_kwargs):
"""Query a local deployment:
mii/examples/local/gpt2-query-example.py
mii/examples/local/roberta-qa-query-example.py
Arguments:
request_dict: A task specific request dictionary consistinging of the inputs to the models
request_dict: A task specific request dictionary consisting of the inputs to the models
query_kwargs: additional query parameters for the model
Returns:
@@ -109,5 +124,9 @@ def query(self, request_dict, **query_kwargs):
self._query_in_tensor_parallel(request_dict,
query_kwargs))
ret = response.result()

return ret

def terminate(self):
"""Terminates the deployment"""
for client in self.clients:
client.terminate()
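
Taken together, mii_query_handle() now returns a MIIClient pointed at the single load-balancer port when enable_load_balancing is set, and otherwise a MIITensorParallelClient that fans each request out to one port per tensor-parallel rank. A hedged usage sketch, with an illustrative deployment name and inputs:

import mii

generator = mii.mii_query_handle("gpt2-lb-deployment")
# For a text-generation deployment the request dictionary carries the prompts
# under the "query" key; extra keyword arguments are forwarded to the model.
result = generator.query({"query": ["DeepSpeed is", "Seattle is"]},
                         do_sample=True,
                         max_new_tokens=30)
print(result)
# terminate() is added in this commit and shuts down every endpoint behind the handle.
generator.terminate()
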
25 changes: 25 additions & 0 deletions mii/config.py
@@ -3,6 +3,8 @@
from enum import Enum
from pydantic import BaseModel, validator

from deepspeed.launcher.runner import DLTS_HOSTFILE


class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype)
@@ -45,6 +47,9 @@ class MIIConfig(BaseModel):
profile_model_time: bool = False
skip_model_check: bool = False
max_tokens: int = 1024
enable_load_balancing: bool = False
replica_num: int = 1
hostfile: str = DLTS_HOSTFILE

@validator("deploy_rank")
def deploy_valid(cls, field_value, values):
@@ -85,3 +90,23 @@ class Config:
use_enum_values = True
extra = 'forbid'
json_encoders = {torch.dtype: lambda x: str(x)}


class ReplicaConfig(BaseModel):
hostname: str = ""
tensor_parallel_ports: List[int] = []
torch_dist_port: int = None
gpu_indices: List[int] = []

class Config:
validate_all = True
validate_assignment = True


class LoadBalancerConfig(BaseModel):
port: int = None
replica_configs: List[ReplicaConfig] = []

class Config:
validate_all = True
validate_assignment = True
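
For reference, these two new config classes describe the serving topology that deploy() computes in mii/deployment.py below. A hand-built example for a single host with two replicas and tensor_parallel=1 (all hostnames, ports, and GPU indices are illustrative):

from mii.config import ReplicaConfig, LoadBalancerConfig

# Illustrative only: deploy() normally derives these from the hostfile and MIIConfig.
replicas = [
    ReplicaConfig(hostname="localhost",
                  tensor_parallel_ports=[50051],
                  torch_dist_port=29500,
                  gpu_indices=[0]),
    ReplicaConfig(hostname="localhost",
                  tensor_parallel_ports=[50052],
                  torch_dist_port=29501,
                  gpu_indices=[1]),
]
# The load-balancer proxy itself listens on the base port and forwards to the replicas.
lb_config = LoadBalancerConfig(port=50050, replica_configs=replicas)
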
7 changes: 7 additions & 0 deletions mii/constants.py
@@ -89,6 +89,7 @@ class ModelProvider(enum.Enum):
MODEL_NAME_KEY = 'model_name'
TASK_NAME_KEY = 'task_name'
MODEL_PATH_KEY = 'model_path'
LOAD_BALANCER_CONFIG_KEY = 'load_balancer_config'

ENABLE_DEEPSPEED_KEY = 'ds_optimize'
ENABLE_DEEPSPEED_ZERO_KEY = 'ds_zero'
@@ -109,3 +110,9 @@ class ModelProvider(enum.Enum):
MII_MODEL_PATH_DEFAULT = "/tmp/mii_models"

GRPC_MAX_MSG_SIZE = 2**27 # ~100MB

TERMINATE_METHOD = "Terminate"

LB_MAX_WORKER_THREADS = 32

SERVER_SHUTDOWN_TIMEOUT = 10
65 changes: 64 additions & 1 deletion mii/deployment.py
@@ -6,9 +6,12 @@

import mii

from deepspeed.launcher.runner import fetch_hostfile

from .constants import DeploymentType, MII_MODEL_PATH_DEFAULT
from .utils import logger
from .models.score import create_score_file
from .config import ReplicaConfig, LoadBalancerConfig


def deploy(task,
@@ -98,6 +101,29 @@ def deploy(task,
elif model_path is None and deployment_type == DeploymentType.AML:
model_path = "model"

# add fields for replica deployment
lb_config = None
if mii_config.enable_load_balancing:
replica_pool = _allocate_processes(mii_config.hostfile,
mii_config.tensor_parallel,
mii_config.replica_num)
replica_configs = []
for i, (hostname, gpu_indices) in enumerate(replica_pool):
            # Reserve a port for the LB proxy when replication is enabled
port_offset = 1 if mii_config.replica_num > 1 else 0
base_port = mii_config.port_number + i * mii_config.tensor_parallel + port_offset
tensor_parallel_ports = list(
range(base_port,
base_port + mii_config.tensor_parallel))
torch_dist_port = mii_config.torch_dist_port + i
replica_configs.append(
ReplicaConfig(hostname=hostname,
tensor_parallel_ports=tensor_parallel_ports,
torch_dist_port=torch_dist_port,
gpu_indices=gpu_indices))
lb_config = LoadBalancerConfig(port=mii_config.port_number,
replica_configs=replica_configs)

create_score_file(deployment_name=deployment_name,
deployment_type=deployment_type,
task=task,
@@ -106,7 +132,8 @@ def deploy(task,
ds_zero=enable_zero,
ds_config=ds_config,
mii_config=mii_config,
model_path=model_path)
model_path=model_path,
lb_config=lb_config)

if deployment_type == DeploymentType.AML:
_deploy_aml(deployment_name=deployment_name, model_name=model, version=version)
@@ -130,3 +157,39 @@ def _deploy_aml(deployment_name, model_name, version):
f"AML deployment assets at {mii.aml_related.utils.aml_output_path(deployment_name)}"
)
print("Please run 'deploy.sh' to bring your deployment online")


def _allocate_processes(hostfile_path, tensor_parallel, num_replicas):
resource_pool = fetch_hostfile(hostfile_path)
assert resource_pool is not None and len(
resource_pool) > 0, f'No hosts found in {hostfile_path}'

replica_pool = []
allocated_num = 0
for host, slots in resource_pool.items():
available_on_host = slots
while available_on_host >= tensor_parallel:
if allocated_num >= num_replicas:
break
if slots < tensor_parallel:
raise ValueError(
f'Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required'
)

allocated_num_on_host = slots - available_on_host
replica_pool.append(
(host,
[
i for i in range(allocated_num_on_host,
allocated_num_on_host + tensor_parallel)
]))
allocated_num += 1

available_on_host -= tensor_parallel

if allocated_num < num_replicas:
raise ValueError(
f'No sufficient GPUs for {num_replicas} replica(s), only {allocated_num} replica(s) can be deployed'
)

return replica_pool
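
_allocate_processes() walks a DeepSpeed-style hostfile (lines of the form "hostname slots=N", defaulting to DLTS_HOSTFILE) and assigns each replica a contiguous block of tensor_parallel GPU indices on a single host. A hedged illustration of the expected allocation, using a hypothetical hostfile:

# Suppose /job/hostfile (the usual DLTS_HOSTFILE default) contains:
#   worker-0 slots=4
#   worker-1 slots=4
# With tensor_parallel=2 and three replicas, the pool is filled host by host:
replica_pool = _allocate_processes("/job/hostfile", tensor_parallel=2, num_replicas=3)
# -> [("worker-0", [0, 1]), ("worker-0", [2, 3]), ("worker-1", [0, 1])]
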
