Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] Remove extra pickle serialization for gRPCRequest #49943

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/ray/serve/_private/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Awaitable, Callable, Dict, List, Optional
from typing import Any, Awaitable, Callable, Dict, List, Optional

from starlette.types import Scope

Expand Down Expand Up @@ -574,7 +574,7 @@ class MultiplexedReplicaInfo:
class gRPCRequest:
"""Sent from the GRPC proxy to replicas on both unary and streaming codepaths."""

grpc_user_request: bytes
user_request_proto: Any


class RequestProtocol(str, Enum):
Expand Down
13 changes: 3 additions & 10 deletions python/ray/serve/_private/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,11 +708,8 @@ async def send_request_to_replica(
proxy_request: ProxyRequest,
app_is_cross_language: bool = False,
) -> ResponseGenerator:
handle_arg = proxy_request.request_object()
response_generator = ProxyResponseGenerator(
# NOTE(edoakes): it's important that the request is sent as raw bytes to
# skip the Ray cloudpickle serialization codepath for performance.
handle.remote(pickle.dumps(handle_arg)),
handle.remote(proxy_request.serialized_replica_arg()),
timeout_s=self.request_timeout_s,
)

Expand Down Expand Up @@ -956,12 +953,8 @@ async def send_request_to_replica(
# Response is returned as raw bytes, convert it to ASGI messages.
result_callback = convert_object_to_asgi_messages
else:
# NOTE(edoakes): it's important that the request is sent as raw bytes to
# skip the Ray cloudpickle serialization codepath for performance.
handle_arg_bytes = pickle.dumps(
proxy_request.request_object(
proxy_actor_name=self.self_actor_name,
)
handle_arg_bytes = proxy_request.serialized_replica_arg(
proxy_actor_name=self.self_actor_name,
)
# Messages are returned as pickled dictionaries.
result_callback = pickle.loads
Expand Down
30 changes: 13 additions & 17 deletions python/ray/serve/_private/proxy_request_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,14 @@ def set_path(self, path: str):
def set_root_path(self, root_path: str):
self.scope["root_path"] = root_path

def request_object(
self,
proxy_actor_name: str,
) -> StreamingHTTPRequest:
return StreamingHTTPRequest(
asgi_scope=self.scope,
proxy_actor_name=proxy_actor_name,
def serialized_replica_arg(self, proxy_actor_name: str) -> bytes:
# NOTE(edoakes): it's important that the request is sent as raw bytes to
# skip the Ray cloudpickle serialization codepath for performance.
return pickle.dumps(
StreamingHTTPRequest(
asgi_scope=self.scope,
proxy_actor_name=proxy_actor_name,
)
)


Expand All @@ -115,7 +116,7 @@ def __init__(
service_method: str,
stream: bool,
):
self.request = request_proto
self._request_proto = request_proto
self.context = context
self.service_method = service_method
self.stream = stream
Expand All @@ -131,7 +132,6 @@ def __init__(
def setup_variables(self):
if not self.is_route_request and not self.is_health_request:
service_method_split = self.service_method.split("/")
self.request = pickle.dumps(self.request)
self.method_name = service_method_split[-1]
for key, value in self.context.invocation_metadata():
if key == "application":
Expand Down Expand Up @@ -161,20 +161,16 @@ def is_route_request(self) -> bool:
def is_health_request(self) -> bool:
return self.service_method == "/ray.serve.RayServeAPIService/Healthz"

@property
def user_request(self) -> bytes:
return self.request

def send_request_id(self, request_id: str):
# Set the request ID as trailing metadata on the ray_serve_grpc_context object so
# that it does not override any trailing metadata set by the user; both are sent
# back to the client together.
self.ray_serve_grpc_context.set_trailing_metadata([("request_id", request_id)])

def request_object(self) -> gRPCRequest:
return gRPCRequest(
grpc_user_request=self.user_request,
)
def serialized_replica_arg(self) -> bytes:
# NOTE(edoakes): it's important that the request is sent as raw bytes to
# skip the Ray cloudpickle serialization codepath for performance.
return pickle.dumps(gRPCRequest(user_request_proto=self._request_proto))


@dataclass(frozen=True)
Expand Down
10 changes: 5 additions & 5 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,17 +1363,18 @@ def _prepare_args_for_grpc_request(
request_metadata: RequestMetadata,
user_method_params: Dict[str, inspect.Parameter],
) -> Tuple[Tuple[Any], Dict[str, Any]]:
"""Prepare arguments for a user method handling a gRPC request.
"""Prepare args and kwargs for a user method handling a gRPC request.

Returns (request_args, request_kwargs).
The sole argument is always the user request proto.

If the method has a "context" kwarg, we pass the gRPC context, else no kwargs.
"""
request_args = (pickle.loads(request.grpc_user_request),)
if GRPC_CONTEXT_ARG_NAME in user_method_params:
request_kwargs = {GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context}
else:
request_kwargs = {}

return request_args, request_kwargs
return (request.user_request_proto,), request_kwargs

async def _handle_user_method_result(
self,
Expand Down Expand Up @@ -1492,7 +1493,6 @@ async def call_user_method(
generator_result_callback=generator_result_callback,
)
elif request_metadata.is_grpc_request:
# Ensure the request args are a single gRPCRequest object.
assert len(request_args) == 1 and isinstance(
request_args[0], gRPCRequest
)
Expand Down
7 changes: 4 additions & 3 deletions python/ray/serve/tests/unit/test_proxy_request_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def test_calling_user_defined_method(self):
)
assert isinstance(proxy_request, ProxyRequest)
assert proxy_request.route_path == application
assert pickle.loads(proxy_request.request) == request_proto
assert proxy_request.method_name == method_name
assert proxy_request.app_name == application
assert proxy_request.request_id == request_id
Expand All @@ -232,9 +231,11 @@ def test_calling_user_defined_method(self):
("request_id", request_id)
]

request_object = proxy_request.request_object()
serialized_arg = proxy_request.serialized_replica_arg()
assert isinstance(serialized_arg, bytes)
request_object = pickle.loads(serialized_arg)
assert isinstance(request_object, gRPCRequest)
assert pickle.loads(request_object.grpc_user_request) == request_proto
assert request_object.user_request_proto == request_proto


if __name__ == "__main__":
Expand Down
9 changes: 2 additions & 7 deletions python/ray/serve/tests/unit/test_user_callable_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,7 @@ def test_grpc_unary_request(run_sync_methods_in_threadpool: bool):
)
user_callable_wrapper.initialize_callable().result()

grpc_request = gRPCRequest(
pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world"))
)

grpc_request = gRPCRequest(serve_pb2.UserDefinedResponse(greeting="world"))
request_metadata = _make_request_metadata(call_method="greet", is_grpc_request=True)
_, result_bytes = user_callable_wrapper.call_user_method(
request_metadata, (grpc_request,), dict()
Expand All @@ -579,9 +576,7 @@ def test_grpc_streaming_request(run_sync_methods_in_threadpool: bool):
)
user_callable_wrapper.initialize_callable()

grpc_request = gRPCRequest(
pickle.dumps(serve_pb2.UserDefinedResponse(greeting="world"))
)
grpc_request = gRPCRequest(serve_pb2.UserDefinedResponse(greeting="world"))

result_list = []

Expand Down
Loading