diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 7e658699b47e..a4386f064567 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -7,46 +7,11 @@ option csharp_namespace = "Microsoft.AutoGen.Contracts"; import "cloudevent.proto"; import "google/protobuf/any.proto"; -message TopicId { - string type = 1; - string source = 2; -} - message AgentId { string type = 1; string key = 2; } -message Payload { - string data_type = 1; - string data_content_type = 2; - bytes data = 3; -} - -message RpcRequest { - string request_id = 1; - optional AgentId source = 2; - AgentId target = 3; - string method = 4; - Payload payload = 5; - map metadata = 6; -} - -message RpcResponse { - string request_id = 1; - Payload payload = 2; - string error = 3; - map metadata = 4; -} - -message Event { - string topic_type = 1; - string topic_source = 2; - optional AgentId source = 3; - Payload payload = 4; - map metadata = 5; -} - message RegisterAgentTypeRequest { string request_id = 1; string type = 2; @@ -115,13 +80,11 @@ message SaveStateResponse { message Message { oneof message { - RpcRequest request = 1; - RpcResponse response = 2; - io.cloudevents.v1.CloudEvent cloudEvent = 3; - RegisterAgentTypeRequest registerAgentTypeRequest = 4; - RegisterAgentTypeResponse registerAgentTypeResponse = 5; - AddSubscriptionRequest addSubscriptionRequest = 6; - AddSubscriptionResponse addSubscriptionResponse = 7; + io.cloudevents.v1.CloudEvent cloudEvent = 1; + RegisterAgentTypeRequest registerAgentTypeRequest = 2; + RegisterAgentTypeResponse registerAgentTypeResponse = 3; + AddSubscriptionRequest addSubscriptionRequest = 4; + AddSubscriptionResponse addSubscriptionResponse = 5; } } diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py index 572592f2d32b..c63075f5aa4c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py @@ -42,9 +42,9 @@ def __init__(self, description: str) -> None: super().__init__(description=description) self._fifo_lock = FIFOLock() - async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: await self._fifo_lock.acquire() try: - return await super().on_message_impl(message, ctx) + await super().on_message_impl(message, ctx) finally: self._fifo_lock.release() diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb index 400985577698..fda8b90f1917 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb @@ -185,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -222,7 +222,6 @@ "await runtime.send_message(\n", " Message(\"Joe, tell me a joke.\"),\n", " recipient=AgentId(joe, \"default\"),\n", - " sender=AgentId(cathy, \"default\"),\n", ")\n", "await runtime.stop_when_idle()" ] diff --git a/python/packages/autogen-core/samples/slow_human_in_loop.py b/python/packages/autogen-core/samples/slow_human_in_loop.py index 9c4476d06b5c..c8a70cca607b 100644 --- a/python/packages/autogen-core/samples/slow_human_in_loop.py +++ b/python/packages/autogen-core/samples/slow_human_in_loop.py @@ -31,7 +31,6 @@ from typing import Any, Mapping, Optional from autogen_core import ( - AgentId, CancellationToken, DefaultTopicId, FunctionCall, @@ -41,7 +40,6 @@ message_handler, type_subscription, ) -from autogen_core.base.intervention import DefaultInterventionHandler from autogen_core.model_context import BufferedChatCompletionContext from autogen_core.models import ( AssistantMessage, @@ -207,11 +205,11 @@ async def load_state(self, state: Mapping[str, Any]) -> None: self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]}) -class NeedsUserInputHandler(DefaultInterventionHandler): +class NeedsUserInputHandler: def __init__(self): self.question_for_user: GetSlowUserMessage | None = None - async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any: + async def __call__(self, message: Any, message_context: MessageContext) -> Any: if isinstance(message, GetSlowUserMessage): self.question_for_user = message return message @@ -227,11 +225,11 @@ def user_input_content(self) -> str | None: return self.question_for_user.content -class TerminationHandler(DefaultInterventionHandler): +class TerminationHandler: def __init__(self): self.terminateMessage: TerminateMessage | None = None - async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any: + async def __call__(self, message: Any, message_context: MessageContext) -> Any: if isinstance(message, TerminateMessage): self.terminateMessage = message return message diff --git a/python/packages/autogen-core/src/autogen_core/__init__.py b/python/packages/autogen-core/src/autogen_core/__init__.py index 0f085d29bdfe..ffd85d16a18d 100644 --- a/python/packages/autogen-core/src/autogen_core/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/__init__.py @@ -24,6 +24,7 @@ from ._default_subscription import DefaultSubscription, default_subscription, type_subscription from ._default_topic import DefaultTopicId from ._image import Image +from ._intervention import DropMessage, InterventionFunction from ._message_context import MessageContext from ._message_handler_context import MessageHandlerContext from ._routed_agent import RoutedAgent, event, message_handler, rpc @@ -99,4 +100,6 @@ "ROOT_LOGGER_NAME", "EVENT_LOGGER_NAME", "TRACE_LOGGER_NAME", + "DropMessage", + "InterventionFunction", ] diff --git a/python/packages/autogen-core/src/autogen_core/_agent.py b/python/packages/autogen-core/src/autogen_core/_agent.py index edb5e59b1ce3..0202522d08a3 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_agent.py @@ -17,16 +17,13 @@ def id(self) -> AgentId: """ID of the agent.""" ... - async def on_message(self, message: Any, ctx: MessageContext) -> Any: + async def on_message(self, message: Any, ctx: MessageContext) -> None: """Message handler for the agent. This should only be called by the runtime, not by other agents. Args: message (Any): Received message. Type is one of the types in `subscriptions`. ctx (MessageContext): Context of the message. - Returns: - Any: Response to the message. Can be None. - Raises: asyncio.CancelledError: If the message was cancelled. CantHandleException: If the agent cannot handle the message. diff --git a/python/packages/autogen-core/src/autogen_core/_agent_proxy.py b/python/packages/autogen-core/src/autogen_core/_agent_proxy.py index f3eb70f28270..5b8ddc314a0c 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_proxy.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_proxy.py @@ -29,13 +29,11 @@ async def send_message( self, message: Any, *, - sender: AgentId, cancellation_token: CancellationToken | None = None, ) -> Any: return await self._runtime.send_message( message, recipient=self._agent, - sender=sender, cancellation_token=cancellation_token, ) diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index 8156d378275a..9e55a94de264 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -26,7 +26,6 @@ async def send_message( message: Any, recipient: AgentId, *, - sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> Any: """Send a message to an agent and get a response. diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 79bffd36d62b..b5cf4a94a924 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -1,13 +1,25 @@ from __future__ import annotations +import asyncio import inspect +import uuid import warnings from abc import ABC, abstractmethod +from asyncio import Future from collections.abc import Sequence -from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, final +from typing import Any, Awaitable, Callable, ClassVar, Dict, List, Mapping, Tuple, Type, TypeVar, final from typing_extensions import Self +from autogen_core._types import ( + CancelledRpc, + CancelRpc, + CantHandleMessageResponse, + RpcMessageDroppedResponse, + RpcNoneResponse, +) +from autogen_core.exceptions import CantHandleException + from ._agent import Agent from ._agent_id import AgentId from ._agent_instantiation import AgentInstantiationContext @@ -21,6 +33,15 @@ from ._subscription_context import SubscriptionInstantiationContext from ._topic import TopicId from ._type_prefix_subscription import TypePrefixSubscription +from ._well_known_topics import ( + format_error_topic, + format_rpc_request_topic, + format_rpc_response_topic, + is_error_message, + is_rpc_cancel, + is_rpc_request, + is_rpc_response, +) T = TypeVar("T", bound=Agent) @@ -81,7 +102,17 @@ def metadata(self) -> AgentMetadata: assert self._id is not None return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) - def __init__(self, description: str) -> None: + def __init__(self, description: str, *, forward_unbound_rpc_responses_to_handler: bool = False) -> None: + """Base agent that all agents should inherit from. Puts in place assumed common functionality. + + Args: + description (str): Description of the agent. + forward_unbound_rpc_responses_to_handler (bool, optional): If an rpc request ID is not know to the agent, should the rpc request be forwarded to the handler. Defaults to False. + + Raises: + RuntimeError: If the agent is not instantiated within the context of an AgentRuntime. + ValueError: If there is an argument type error. + """ try: runtime = AgentInstantiationContext.current_runtime() id = AgentInstantiationContext.current_agent_id() @@ -95,6 +126,15 @@ def __init__(self, description: str) -> None: if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description + self._pending_rpc_requests: Dict[str, Future[Any]] = {} + self._self_rpc_handlers_in_progress: Dict[str, Future[Any]] = {} + + # TODO: find a way to clean this up over time. + # Essentially, the reason for this existing is if a response is sent but we get an error back for that response + # We need to forward this error back to the original sender, so they can fail their RPC. + # Map of request_id -> (rpc_request_message_id, agent_type_of_rpc_sender) + self._sent_rpc_responses: Dict[str, Tuple[str, str]] = {} + self._forward_unbound_rpc_responses_to_handler = forward_unbound_rpc_responses_to_handler @property def type(self) -> str: @@ -108,12 +148,71 @@ def id(self) -> AgentId: def runtime(self) -> AgentRuntime: return self._runtime + @abstractmethod + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: ... + @final - async def on_message(self, message: Any, ctx: MessageContext) -> Any: - return await self.on_message_impl(message, ctx) + async def on_message(self, message: Any, ctx: MessageContext) -> None: + # Intercept errors for outstanding rpc requests, let the others pass through + if (request_id := is_error_message(ctx.topic_id.type)) is not None: + # Check if this error corresponds to an RPC response we have sent + if request_id in self._sent_rpc_responses: + # The recipient we were trying to send a response to never got this response, so we're going to send an error to them instead of the original message + # If this message gets dropped, we're just going to ignore things + original_rpc_request_message_id, agent_type_of_rpc_sender = self._sent_rpc_responses[request_id] + error_topic = format_error_topic( + error_recipient_agent_type=agent_type_of_rpc_sender, request_id=original_rpc_request_message_id + ) + await self.publish_message( + RpcMessageDroppedResponse(original_rpc_request_message_id), TopicId(error_topic, self.id.key) + ) + # Check if we have a pending RPC that is error corresponds to + elif request_id in self._pending_rpc_requests: + self._pending_rpc_requests[request_id].set_exception(message) + del self._pending_rpc_requests[request_id] + else: + await self.on_message_impl(message, ctx) + + return None + + # Intercept RPC cancel + if (request_id := is_rpc_cancel(ctx.topic_id.type)) is not None: + if request_id in self._self_rpc_handlers_in_progress: + if isinstance(message, CancelRpc): + self._self_rpc_handlers_in_progress[request_id].cancel() + del self._self_rpc_handlers_in_progress[request_id] + + return None + + # Intercept RPC responses + if (request_id := is_rpc_response(ctx.topic_id.type)) is not None: + if request_id in self._pending_rpc_requests: + if isinstance(message, RpcNoneResponse): + message = None + if isinstance(message, CancelledRpc): + self._pending_rpc_requests[request_id].cancel() + self._pending_rpc_requests[request_id].set_result(message) + del self._pending_rpc_requests[request_id] + elif self._forward_unbound_rpc_responses_to_handler: + await self.on_message_impl(message, ctx) + else: + warnings.warn( + f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", + stacklevel=2, + ) + return None - @abstractmethod - async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: ... + try: + await self.on_message_impl(message, ctx) + # If the agent signalled it cannot handle this message, and it was an RPC request. Let's deliver this error to the RPC sender so they know. + except CantHandleException: + if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: + error_topic = format_error_topic(error_recipient_agent_type=requestor_type, request_id=ctx.message_id) + await self.publish_message( + CantHandleMessageResponse(message_id=ctx.message_id), TopicId(error_topic, self.id.key) + ) + else: + raise async def send_message( self, @@ -121,26 +220,66 @@ async def send_message( recipient: AgentId, *, cancellation_token: CancellationToken | None = None, + timeout: float | None = None, ) -> Any: """See :py:meth:`autogen_core.AgentRuntime.send_message` for more information.""" if cancellation_token is None: cancellation_token = CancellationToken() - return await self._runtime.send_message( + recipient_topic = TopicId( + type=format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=self.id.type), + source=recipient.key, + ) + request_id = str(uuid.uuid4()) + + future = Future[Any]() + + await self._runtime.publish_message( message, sender=self.id, - recipient=recipient, + topic_id=recipient_topic, cancellation_token=cancellation_token, + message_id=request_id, ) + self._pending_rpc_requests[request_id] = future + + async with asyncio.timeout(timeout): + return await future + + async def _rpc_response(self, handler_return_value: Any, ctx: MessageContext) -> None: + if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: + if handler_return_value is None: + handler_return_value = RpcNoneResponse() + + response_topic_id = TopicId( + type=format_rpc_response_topic(rpc_sender_agent_type=requestor_type, request_id=ctx.message_id), + source=self.id.key, + ) + message_id = str(uuid.uuid4()) + # Intentionally accessing a private attribute here + # We store this so that if the response is dropped, we can send an error to the client instead. + # request_id -> (rpc_request_message_id, agent_type_of_rpc_sender) + self._sent_rpc_responses[message_id] = (ctx.message_id, requestor_type) # type: ignore + + await self.publish_message( + message=handler_return_value, + topic_id=response_topic_id, + cancellation_token=ctx.cancellation_token, + message_id=message_id, + ) + async def publish_message( self, message: Any, topic_id: TopicId, *, cancellation_token: CancellationToken | None = None, + message_id: str | None = None, ) -> None: - await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token) + await self._runtime.publish_message( + message, topic_id, sender=self.id, cancellation_token=cancellation_token, message_id=message_id + ) async def save_state(self) -> Mapping[str, Any]: warnings.warn("save_state not implemented", stacklevel=2) diff --git a/python/packages/autogen-core/src/autogen_core/_cancellation_token.py b/python/packages/autogen-core/src/autogen_core/_cancellation_token.py index 5aa44903963f..a4a089b3113c 100644 --- a/python/packages/autogen-core/src/autogen_core/_cancellation_token.py +++ b/python/packages/autogen-core/src/autogen_core/_cancellation_token.py @@ -1,26 +1,29 @@ +import inspect import threading from asyncio import Future -from typing import Any, Callable, List +from typing import Any, Awaitable, Callable, List class CancellationToken: def __init__(self) -> None: self._cancelled: bool = False self._lock: threading.Lock = threading.Lock() - self._callbacks: List[Callable[[], None]] = [] + self._callbacks: List[Callable[[], None] | Callable[[], Awaitable[None]]] = [] - def cancel(self) -> None: + async def cancel(self) -> None: with self._lock: if not self._cancelled: self._cancelled = True for callback in self._callbacks: - callback() + res = callback() + if inspect.isawaitable(res): + await res def is_cancelled(self) -> bool: with self._lock: return self._cancelled - def add_callback(self, callback: Callable[[], None]) -> None: + def add_callback(self, callback: Callable[[], None] | Callable[[], Awaitable[None]]) -> None: with self._lock: if self._cancelled: callback() diff --git a/python/packages/autogen-core/src/autogen_core/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/_closure_agent.py index 03206d18feff..66a990a29204 100644 --- a/python/packages/autogen-core/src/autogen_core/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_closure_agent.py @@ -15,7 +15,7 @@ from ._subscription import Subscription from ._subscription_context import SubscriptionInstantiationContext from ._topic import TopicId -from ._type_helpers import get_types +from ._type_helpers import AnyType, get_types from .exceptions import CantHandleException T = TypeVar("T") @@ -73,7 +73,11 @@ async def publish_message( class ClosureAgent(BaseAgent, ClosureContext): def __init__( - self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]] + self, + description: str, + closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], + *, + forward_unbound_rpc_responses_to_handler: bool = False, ) -> None: try: runtime = AgentInstantiationContext.current_runtime() @@ -89,7 +93,7 @@ def __init__( handled_types = get_handled_types_from_closure(closure) self._expected_types = handled_types self._closure = closure - super().__init__(description) + super().__init__(description, forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler) @property def metadata(self) -> AgentMetadata: @@ -109,7 +113,7 @@ def runtime(self) -> AgentRuntime: return self._runtime async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: - if type(message) not in self._expected_types: + if AnyType not in self._expected_types and type(message) not in self._expected_types: raise CantHandleException( f"Message type {type(message)} not in target types {self._expected_types} of {self.id}" ) @@ -130,19 +134,23 @@ async def register_closure( type: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, - skip_class_subscriptions: bool = False, skip_direct_message_subscription: bool = False, + forward_unbound_rpc_responses_to_handler: bool = False, description: str = "", subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None, ) -> AgentType: def factory() -> ClosureAgent: - return ClosureAgent(description=description, closure=closure) + return ClosureAgent( + description=description, + closure=closure, + forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler, + ) agent_type = await cls.register( runtime=runtime, type=type, factory=factory, # type: ignore - skip_class_subscriptions=skip_class_subscriptions, + skip_class_subscriptions=True, skip_direct_message_subscription=skip_direct_message_subscription, ) diff --git a/python/packages/autogen-core/src/autogen_core/_intervention.py b/python/packages/autogen-core/src/autogen_core/_intervention.py new file mode 100644 index 000000000000..e13359747188 --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/_intervention.py @@ -0,0 +1,17 @@ +from typing import Any, Awaitable, Callable, final + +from autogen_core._message_context import MessageContext + +__all__ = [ + "DropMessage", + "InterventionFunction", +] + + +@final +class DropMessage: ... + + +InterventionFunction = Callable[ + [Any, MessageContext], Any | Awaitable[Any] | type[DropMessage] | Awaitable[type[DropMessage]] +] diff --git a/python/packages/autogen-core/src/autogen_core/_message_context.py b/python/packages/autogen-core/src/autogen_core/_message_context.py index c5c00559ed0e..65cbbb64d4bb 100644 --- a/python/packages/autogen-core/src/autogen_core/_message_context.py +++ b/python/packages/autogen-core/src/autogen_core/_message_context.py @@ -8,7 +8,6 @@ @dataclass class MessageContext: sender: AgentId | None - topic_id: TopicId | None - is_rpc: bool + topic_id: TopicId cancellation_token: CancellationToken message_id: str diff --git a/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py b/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py new file mode 100644 index 000000000000..6d8171dae52b --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py @@ -0,0 +1,101 @@ +import asyncio +import uuid +import warnings +from typing import Any + +from autogen_core._types import CancelledRpc, CancelRpc, CantHandleMessageResponse, RpcMessageDroppedResponse +from autogen_core.exceptions import CantHandleException, MessageDroppedException + +from ._agent_id import AgentId +from ._agent_runtime import AgentRuntime +from ._cancellation_token import CancellationToken +from ._closure_agent import ClosureAgent, ClosureContext +from ._message_context import MessageContext +from ._topic import TopicId +from ._well_known_topics import ( + format_error_topic, + format_rpc_cancel_topic, + format_rpc_request_topic, + format_rpc_response_topic, +) + + +class PublishBasedRpcMixin(AgentRuntime): + async def send_message( + self: AgentRuntime, + message: Any, + recipient: AgentId, + *, + cancellation_token: CancellationToken | None = None, + timeout: float | None = None, + ) -> Any: + if cancellation_token is None: + cancellation_token = CancellationToken() + + rpc_request_id = str(uuid.uuid4()) + # TODO add "-" to topic and agent type allowed characters in spec + closure_agent_type = f"rpc_receiver_{recipient.type}_{rpc_request_id}" + + future: asyncio.Future[Any] = asyncio.Future() + expected_response_topic_type = format_rpc_response_topic( + rpc_sender_agent_type=closure_agent_type, request_id=rpc_request_id + ) + expected_error_topic_type = format_error_topic(closure_agent_type, request_id=rpc_request_id) + + async def set_result(closure_context: ClosureContext, message: Any, ctx: MessageContext) -> None: + assert ctx.topic_id is not None + if ctx.topic_id.type == expected_response_topic_type: + if isinstance(message, CancelledRpc): + future.cancel() + else: + future.set_result(message) + elif ctx.topic_id.type == expected_error_topic_type: + # Well known things we handle - dropped message, cant handle + # If the message is for a dropped message + if isinstance(message, CantHandleMessageResponse): + future.set_exception(CantHandleException()) + if isinstance(message, RpcMessageDroppedResponse): + future.set_exception(MessageDroppedException()) + else: + warnings.warn( + f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}.", + stacklevel=2, + ) + else: + warnings.warn( + f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}. Expected {expected_response_topic_type}", + stacklevel=2, + ) + + # TODO: remove agent after response is received + + await ClosureAgent.register_closure( + runtime=self, + type=closure_agent_type, + closure=set_result, + forward_unbound_rpc_responses_to_handler=True, + ) + + rpc_request_topic_id = format_rpc_request_topic( + rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=closure_agent_type + ) + await self.publish_message( + message=message, + topic_id=TopicId(type=rpc_request_topic_id, source=recipient.key), + message_id=rpc_request_id, + sender=AgentId(type=closure_agent_type, key=recipient.key), + ) + + async def send_cancel(): + cancel_topic = format_rpc_cancel_topic(rpc_recipient_agent_type=recipient.type, request_id=rpc_request_id) + await self.publish_message( + message=CancelRpc(), + topic_id=TopicId(cancel_topic, recipient.key), + ) + + cancellation_token.add_callback(send_cancel) + + async with asyncio.timeout(timeout): + return await future + + # register a closure agent... diff --git a/python/packages/autogen-core/src/autogen_core/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/_routed_agent.py index a5908278cab9..a7cfd5ed55a0 100644 --- a/python/packages/autogen-core/src/autogen_core/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_routed_agent.py @@ -1,4 +1,7 @@ +import asyncio import logging +import warnings +from asyncio import CancelledError from functools import wraps from typing import ( Any, @@ -18,15 +21,18 @@ runtime_checkable, ) +from autogen_core._types import CancelledRpc + from ._base_agent import BaseAgent from ._message_context import MessageContext from ._serialization import MessageSerializer, try_get_known_serializers_for_type from ._type_helpers import AnyType, get_types +from ._well_known_topics import is_rpc_request from .exceptions import CantHandleException logger = logging.getLogger("autogen_core") -AgentT = TypeVar("AgentT") +AgentT = TypeVar("AgentT", bound=BaseAgent) ReceivesT = TypeVar("ReceivesT") ProducesT = TypeVar("ProducesT", covariant=True) @@ -139,7 +145,7 @@ def decorator( # Convert target_types to list and stash @wraps(func) - async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") @@ -154,7 +160,16 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> Prod else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") - return return_value + # Dont return, but publish it if you need to... + # Any return is treated as a response to the RPC request and is published accordingly + + if return_value is not None and is_rpc_request(ctx.topic_id.type) is None: + warnings.warn( + "Returning a value from a message handler that is not an RPC request. This value will be ignored.", + stacklevel=2, + ) + else: + await self._rpc_response(return_value, ctx) # type: ignore wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) @@ -279,8 +294,10 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True - # Wrap the match function with a check on the is_rpc flag. - wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True) + # Wrap the match function with a check on the topic for rpc + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is None) and ( + match(_message, _ctx) if match else True + ) return wrapper_handler @@ -379,14 +396,25 @@ def decorator( # Convert target_types to list and stash @wraps(func) - async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") else: logger.warning(f"Message type {type(message)} not in target types {target_types}") - return_value = await func(self, message, ctx) + # Should be an rpc request, as the match function should have filtered it + assert is_rpc_request(ctx.topic_id.type) is not None + + try: + future = asyncio.ensure_future(func(self, message, ctx)) + self._self_rpc_handlers_in_progress[ctx.message_id] = future # type: ignore + return_value = await future + except CancelledError: + await self._rpc_response(CancelledRpc(), ctx) # type: ignore + return + finally: + del self._self_rpc_handlers_in_progress[ctx.message_id] # type: ignore if AnyType not in return_types and type(return_value) not in return_types: if strict: @@ -394,13 +422,17 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> Prod else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") - return return_value + # Dont return, but publish + # Any return is treated as a response to the RPC request and is published accordingly + await self._rpc_response(return_value, ctx) # type: ignore wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True - wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True) + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is not None) and ( + match(_message, _ctx) if match else True + ) return wrapper_handler @@ -471,7 +503,7 @@ def __init__(self, description: str) -> None: super().__init__(description) - async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: """Handle a message by routing it to the appropriate message handler. Do not override this method in subclasses. Instead, add message handlers as methods decorated with either the :func:`event` or :func:`rpc` decorator.""" @@ -482,8 +514,13 @@ async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None # Call the first handler whose router returns True and then return the result. for h in handlers: if h.router(message, ctx): - return await h(self, message, ctx) - return await self.on_unhandled_message(message, ctx) # type: ignore + await h(self, message, ctx) + return + + if is_rpc_request(ctx.topic_id.type): + raise CantHandleException(f"No RPC handler found for message type {key_type}") + + await self.on_unhandled_message(message, ctx) async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: """Called when a message is received that does not have a matching message handler. diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 4feab3051dda..b46fbacdcbb0 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -6,7 +6,7 @@ import threading import uuid import warnings -from asyncio import CancelledError, Future, Task +from asyncio import CancelledError, Task from collections.abc import Sequence from dataclasses import dataclass from enum import Enum @@ -15,7 +15,7 @@ from opentelemetry.trace import TracerProvider from typing_extensions import deprecated -from autogen_core._serialization import MessageSerializer, SerializationRegistry +from autogen_core._types import RpcMessageDroppedResponse from ._agent import Agent from ._agent_id import AgentId @@ -24,15 +24,17 @@ from ._agent_runtime import AgentRuntime from ._agent_type import AgentType from ._cancellation_token import CancellationToken +from ._intervention import DropMessage, InterventionFunction from ._message_context import MessageContext from ._message_handler_context import MessageHandlerContext +from ._publish_based_rpc import PublishBasedRpcMixin from ._runtime_impl_helpers import SubscriptionManager, get_impl +from ._serialization import MessageSerializer, SerializationRegistry from ._subscription import Subscription from ._subscription_context import SubscriptionInstantiationContext from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata from ._topic import TopicId -from .base.intervention import DropMessage, InterventionHandler -from .exceptions import MessageDroppedException +from ._well_known_topics import format_error_topic logger = logging.getLogger("autogen_core") event_logger = logging.getLogger("autogen_core.events") @@ -55,30 +57,6 @@ class PublishMessageEnvelope: message_id: str -@dataclass(kw_only=True) -class SendMessageEnvelope: - """A message envelope for sending a message to a specific agent that can handle - the message of the type T.""" - - message: Any - sender: AgentId | None - recipient: AgentId - future: Future[Any] - cancellation_token: CancellationToken - metadata: EnvelopeMetadata | None = None - - -@dataclass(kw_only=True) -class ResponseMessageEnvelope: - """A message envelope for sending a response to a message.""" - - message: Any - future: Future[Any] - sender: AgentId - recipient: AgentId | None - metadata: EnvelopeMetadata | None = None - - P = ParamSpec("P") T = TypeVar("T", bound=Agent) @@ -164,15 +142,15 @@ def _warn_if_none(value: Any, handler_name: str) -> None: ) -class SingleThreadedAgentRuntime(AgentRuntime): +class SingleThreadedAgentRuntime(PublishBasedRpcMixin, AgentRuntime): def __init__( self, *, - intervention_handlers: List[InterventionHandler] | None = None, + intervention_handlers: List[InterventionFunction] | None = None, tracer_provider: TracerProvider | None = None, ) -> None: self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime")) - self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] + self._message_queue: List[PublishMessageEnvelope] = [] # (namespace, type) -> List[AgentId] self._agent_factories: Dict[ str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] @@ -188,7 +166,7 @@ def __init__( @property def unprocessed_messages( self, - ) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]: + ) -> Sequence[PublishMessageEnvelope]: return self._message_queue @property @@ -199,56 +177,6 @@ def outstanding_tasks(self) -> int: def _known_agent_names(self) -> Set[str]: return set(self._agent_factories.keys()) - # Returns the response of the message - async def send_message( - self, - message: Any, - recipient: AgentId, - *, - sender: AgentId | None = None, - cancellation_token: CancellationToken | None = None, - ) -> Any: - if cancellation_token is None: - cancellation_token = CancellationToken() - - # event_logger.info( - # MessageEvent( - # payload=message, - # sender=sender, - # receiver=recipient, - # kind=MessageKind.DIRECT, - # delivery_stage=DeliveryStage.SEND, - # ) - # ) - - with self._tracer_helper.trace_block( - "create", - recipient, - parent=None, - extraAttributes={"message_type": type(message).__name__}, - ): - future = asyncio.get_event_loop().create_future() - if recipient.type not in self._known_agent_names: - future.set_exception(Exception("Recipient not found")) - - content = message.__dict__ if hasattr(message, "__dict__") else message - logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") - - self._message_queue.append( - SendMessageEnvelope( - message=message, - recipient=recipient, - future=future, - cancellation_token=cancellation_token, - sender=sender, - metadata=get_telemetry_envelope_metadata(), - ) - ) - - cancellation_token.link_future(future) - - return await future - async def publish_message( self, message: Any, @@ -305,62 +233,6 @@ async def load_state(self, state: Mapping[str, Any]) -> None: if agent_id.type in self._known_agent_names: await (await self._get_agent(agent_id)).load_state(state[str(agent_id)]) - async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: - with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata): - recipient = message_envelope.recipient - # todo: check if recipient is in the known namespaces - # assert recipient in self._agents - - try: - # TODO use id - sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown" - logger.info( - f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}" - ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=recipient, - # kind=MessageKind.DIRECT, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) - recipient_agent = await self._get_agent(recipient) - message_context = MessageContext( - sender=message_envelope.sender, - topic_id=None, - is_rpc=True, - cancellation_token=message_envelope.cancellation_token, - # Will be fixed when send API removed - message_id="NOT_DEFINED_TODO_FIX", - ) - with MessageHandlerContext.populate_context(recipient_agent.id): - response = await recipient_agent.on_message( - message_envelope.message, - ctx=message_context, - ) - except CancelledError as e: - if not message_envelope.future.cancelled(): - message_envelope.future.set_exception(e) - self._outstanding_tasks.decrement() - return - except BaseException as e: - message_envelope.future.set_exception(e) - self._outstanding_tasks.decrement() - return - - self._message_queue.append( - ResponseMessageEnvelope( - message=response, - future=message_envelope.future, - sender=message_envelope.recipient, - recipient=message_envelope.sender, - metadata=get_telemetry_envelope_metadata(), - ) - ) - self._outstanding_tasks.decrement() - async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata): try: @@ -390,7 +262,6 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No message_context = MessageContext( sender=message_envelope.sender, topic_id=message_envelope.topic_id, - is_rpc=False, cancellation_token=message_envelope.cancellation_token, message_id=message_envelope.message_id, ) @@ -417,28 +288,15 @@ async def _on_message(agent: Agent, message_context: MessageContext) -> Any: self._outstanding_tasks.decrement() # TODO if responses are given for a publish - async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: - with self._tracer_helper.trace_block("ack", message_envelope.recipient, parent=message_envelope.metadata): - content = ( - message_envelope.message.__dict__ - if hasattr(message_envelope.message, "__dict__") - else message_envelope.message - ) - logger.info( - f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}" - ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=message_envelope.recipient, - # kind=MessageKind.RESPOND, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) - self._outstanding_tasks.decrement() - if not message_envelope.future.cancelled(): - message_envelope.future.set_result(message_envelope.message) + async def _send_error(self, exception: Any, for_message_id: str, recipient: AgentId) -> None: + topic = format_error_topic(recipient.type, for_message_id) + + # Errors don't have an originating sender + await self.publish_message( + message=exception, + topic_id=TopicId(topic, recipient.key), + sender=None, + ) async def process_next(self) -> None: """Process the next message in the queue.""" @@ -449,71 +307,48 @@ async def process_next(self) -> None: return message_envelope = self._message_queue.pop(0) - match message_envelope: - case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - with self._tracer_helper.trace_block( - "intercept", handler.__class__.__name__, parent=message_envelope.metadata - ): - try: - temp_message = await handler.on_send(message, sender=sender, recipient=recipient) - _warn_if_none(temp_message, "on_send") - except BaseException as e: - future.set_exception(e) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - future.set_exception(MessageDroppedException()) - return - - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_send(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - case PublishMessageEnvelope( - message=message, - sender=sender, - ): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - with self._tracer_helper.trace_block( - "intercept", handler.__class__.__name__, parent=message_envelope.metadata - ): - try: - temp_message = await handler.on_publish(message, sender=sender) - _warn_if_none(temp_message, "on_publish") - except BaseException as e: - # TODO: we should raise the intervention exception to the publisher. - logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - # TODO log message dropped - return - - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_publish(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - try: - temp_message = await handler.on_response(message, sender=sender, recipient=recipient) - _warn_if_none(temp_message, "on_response") - except BaseException as e: - # TODO: should we raise the exception to sender of the response instead? - future.set_exception(e) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - future.set_exception(MessageDroppedException()) - return - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_response(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) + message = message_envelope.message + + if self._intervention_handlers is not None: + message_context = MessageContext( + sender=message_envelope.sender, + topic_id=message_envelope.topic_id, + cancellation_token=message_envelope.cancellation_token, + message_id=message_envelope.message_id, + ) + for handler in self._intervention_handlers: + with self._tracer_helper.trace_block( + "intercept", handler.__class__.__name__, parent=message_envelope.metadata + ): + try: + temp_message = handler(message, message_context) + if inspect.isawaitable(temp_message): + temp_message = await temp_message + + _warn_if_none(temp_message, "on_publish") + except BaseException as e: + logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) + return + + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + # TODO log message dropped + # Send message dropped to sender + + # If it's None, then we don't know who to send the message to + if message_envelope.sender is not None: + await self._send_error( + RpcMessageDroppedResponse(message_id=message_envelope.message_id), + message_envelope.message_id, + message_envelope.sender, + ) + + return + + message_envelope.message = temp_message + self._outstanding_tasks.increment() + task = asyncio.create_task(self._process_publish(message_envelope)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) # Yield control to the message loop to allow other tasks to run await asyncio.sleep(0) diff --git a/python/packages/autogen-core/src/autogen_core/_subscription_context.py b/python/packages/autogen-core/src/autogen_core/_subscription_context.py index 1cfd3fd882ed..29b1e1629798 100644 --- a/python/packages/autogen-core/src/autogen_core/_subscription_context.py +++ b/python/packages/autogen-core/src/autogen_core/_subscription_context.py @@ -2,7 +2,7 @@ from contextvars import ContextVar from typing import Any, ClassVar, Generator -from autogen_core._agent_type import AgentType +from ._agent_type import AgentType class SubscriptionInstantiationContext: diff --git a/python/packages/autogen-core/src/autogen_core/_types.py b/python/packages/autogen-core/src/autogen_core/_types.py index 5e3850ffae8b..d261909b98e6 100644 --- a/python/packages/autogen-core/src/autogen_core/_types.py +++ b/python/packages/autogen-core/src/autogen_core/_types.py @@ -10,3 +10,29 @@ class FunctionCall: arguments: str # Function to call name: str + + +# TODO: Make this xlang friendly +@dataclass +class RpcNoneResponse: + pass + + +@dataclass +class RpcMessageDroppedResponse: + message_id: str + + +@dataclass +class CantHandleMessageResponse: + message_id: str + + +@dataclass +class CancelRpc: + pass + + +@dataclass +class CancelledRpc: + pass diff --git a/python/packages/autogen-core/src/autogen_core/_well_known_topics.py b/python/packages/autogen-core/src/autogen_core/_well_known_topics.py new file mode 100644 index 000000000000..e6f17a8d88fc --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/_well_known_topics.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import Optional + + +def format_rpc_request_topic(rpc_recipient_agent_type: str, rpc_sender_agent_type: str) -> str: + return f"{rpc_recipient_agent_type}:rpc_request={rpc_sender_agent_type}" + + +def format_rpc_cancel_topic(rpc_recipient_agent_type: str, request_id: str) -> str: + return f"{rpc_recipient_agent_type}:rpc_cancel={request_id}" + + +def format_rpc_response_topic(rpc_sender_agent_type: str, request_id: str) -> str: + return f"{rpc_sender_agent_type}:rpc_response={request_id}" + + +# If is an rpc response, return the request id +def is_rpc_response(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_response= + for segment in topic_segments: + if segment.startswith("rpc_response="): + return segment[len("rpc_response=") :] + return None + + +# If is an rpc response, return the request id +def is_rpc_cancel(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_cancel= + for segment in topic_segments: + if segment.startswith("rpc_cancel="): + return segment[len("rpc_cancel=") :] + return None + + +# If is an rpc response, return the requestor agent type +def is_rpc_request(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_request= + for segment in topic_segments: + if segment.startswith("rpc_request="): + return segment[len("rpc_request=") :] + return None + + +# {AgentType}:error={RequestId} - error message that corresponds to a request +def is_error_message(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_response= + for segment in topic_segments: + if segment.startswith("error="): + return segment[len("error=") :] + return None + + +def format_error_topic(error_recipient_agent_type: str, request_id: str) -> str: + return f"{error_recipient_agent_type}:error={request_id}" diff --git a/python/packages/autogen-core/src/autogen_core/base/intervention.py b/python/packages/autogen-core/src/autogen_core/base/intervention.py index 5fe337b8776d..a6356a010091 100644 --- a/python/packages/autogen-core/src/autogen_core/base/intervention.py +++ b/python/packages/autogen-core/src/autogen_core/base/intervention.py @@ -1,4 +1,4 @@ -from typing import Any, Awaitable, Callable, Protocol, final +from typing import Any, Protocol from .._agent_id import AgentId @@ -9,12 +9,7 @@ "DefaultInterventionHandler", ] - -@final -class DropMessage: ... - - -InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]] +from .._intervention import DropMessage, InterventionFunction class InterventionHandler(Protocol): diff --git a/python/packages/autogen-core/src/autogen_core/logging.py b/python/packages/autogen-core/src/autogen_core/logging.py index 5e3870203e57..34a3012fe423 100644 --- a/python/packages/autogen-core/src/autogen_core/logging.py +++ b/python/packages/autogen-core/src/autogen_core/logging.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, cast -from autogen_core import AgentId +from ._agent_id import AgentId class LLMCallEvent: diff --git a/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py b/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py index 08d8f4b25376..3792d287281b 100644 --- a/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py +++ b/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from typing import List -from .. import FunctionCall, MessageContext, RoutedAgent, message_handler +from .. import FunctionCall, MessageContext, RoutedAgent +from .._routed_agent import rpc from ..models import FunctionExecutionResult from ..tools import Tool @@ -16,7 +17,7 @@ @dataclass -class ToolException(BaseException): +class ToolException: call_id: str content: str @@ -58,8 +59,10 @@ def __init__( def tools(self) -> List[Tool]: return self._tools - @message_handler - async def handle_function_call(self, message: FunctionCall, ctx: MessageContext) -> FunctionExecutionResult: + @rpc + async def handle_function_call( + self, message: FunctionCall, ctx: MessageContext + ) -> FunctionExecutionResult | ToolNotFoundException | InvalidToolArgumentsException | ToolExecutionException: """Handles a `FunctionCall` message by executing the requested tool with the provided arguments. Args: @@ -76,16 +79,16 @@ async def handle_function_call(self, message: FunctionCall, ctx: MessageContext) """ tool = next((tool for tool in self._tools if tool.name == message.name), None) if tool is None: - raise ToolNotFoundException(call_id=message.id, content=f"Error: Tool not found: {message.name}") + return ToolNotFoundException(call_id=message.id, content=f"Error: Tool not found: {message.name}") else: try: arguments = json.loads(message.arguments) result = await tool.run_json(args=arguments, cancellation_token=ctx.cancellation_token) result_as_str = tool.return_value_as_string(result) - except json.JSONDecodeError as e: - raise InvalidToolArgumentsException( + except json.JSONDecodeError: + return InvalidToolArgumentsException( call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}" - ) from e + ) except Exception as e: - raise ToolExecutionException(call_id=message.id, content=f"Error: {e}") from e + return ToolExecutionException(call_id=message.id, content=f"Error: {e}") return FunctionExecutionResult(content=result_as_str, call_id=message.id) diff --git a/python/packages/autogen-core/test.py b/python/packages/autogen-core/test.py new file mode 100644 index 000000000000..d8c5672d1feb --- /dev/null +++ b/python/packages/autogen-core/test.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass + +from autogen_core.base import MessageContext +from autogen_core.base._agent_id import AgentId +from autogen_core.components import RoutedAgent +from autogen_core.components._routed_agent import rpc + +from autogen_core.application import SingleThreadedAgentRuntime +import asyncio + +@dataclass +class Message: + content: str + +class MyAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("My agent") + + @rpc + async def handle_message(self, message: Message, ctx: MessageContext) -> Message: + print(f"Received message: {message.content}") + return Message(content=f"I got: {message.content}") + +async def main(): + runtime = SingleThreadedAgentRuntime() + + await MyAgent.register(runtime, "my_agent", MyAgent) + + runtime.start() + print(await runtime.send_message( + Message("I'm sending you this"), recipient=AgentId("my_agent", "default") + )) + await runtime.stop_when_idle() + +asyncio.run(main()) diff --git a/python/packages/autogen-core/tests/test_cancellation.py b/python/packages/autogen-core/tests/test_cancellation.py index 34a5d7f962c4..4d803813cef2 100644 --- a/python/packages/autogen-core/tests/test_cancellation.py +++ b/python/packages/autogen-core/tests/test_cancellation.py @@ -9,8 +9,8 @@ MessageContext, RoutedAgent, SingleThreadedAgentRuntime, - message_handler, ) +from autogen_core._routed_agent import rpc @dataclass @@ -28,7 +28,7 @@ def __init__(self) -> None: self.called = False self.cancelled = False - @message_handler + @rpc async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.called = True sleep = asyncio.ensure_future(asyncio.sleep(100)) @@ -48,7 +48,7 @@ def __init__(self, nested_agent: AgentId) -> None: self.cancelled = False self._nested_agent = nested_agent - @message_handler + @rpc async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.called = True response = self.send_message(message, self._nested_agent, cancellation_token=ctx.cancellation_token) @@ -74,9 +74,9 @@ async def test_cancellation_with_token() -> None: while len(runtime.unprocessed_messages) == 0: await asyncio.sleep(0.01) - await runtime.process_next() + runtime.start() - token.cancel() + await token.cancel() with pytest.raises(asyncio.CancelledError): await response @@ -85,6 +85,7 @@ async def test_cancellation_with_token() -> None: long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LongRunningAgent) assert long_running_agent.called assert long_running_agent.cancelled + await runtime.stop() @pytest.mark.asyncio @@ -107,8 +108,9 @@ async def test_nested_cancellation_only_outer_called() -> None: while len(runtime.unprocessed_messages) == 0: await asyncio.sleep(0.01) - await runtime.process_next() - token.cancel() + runtime.start() + + await token.cancel() with pytest.raises(asyncio.CancelledError): await response @@ -120,6 +122,7 @@ async def test_nested_cancellation_only_outer_called() -> None: long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent) assert long_running_agent.called is False assert long_running_agent.cancelled is False + await runtime.stop() @pytest.mark.asyncio @@ -143,10 +146,9 @@ async def test_nested_cancellation_inner_called() -> None: while len(runtime.unprocessed_messages) == 0: await asyncio.sleep(0.01) - await runtime.process_next() - # allow the inner agent to process - await runtime.process_next() - token.cancel() + runtime.start() + + await token.cancel() with pytest.raises(asyncio.CancelledError): await response @@ -158,3 +160,4 @@ async def test_nested_cancellation_inner_called() -> None: long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent) assert long_running_agent.called assert long_running_agent.cancelled + await runtime.stop() diff --git a/python/packages/autogen-core/tests/test_intervention.py b/python/packages/autogen-core/tests/test_intervention.py index a046201feff3..f2a58dc72904 100644 --- a/python/packages/autogen-core/tests/test_intervention.py +++ b/python/packages/autogen-core/tests/test_intervention.py @@ -1,17 +1,19 @@ +from typing import Any + import pytest -from autogen_core import AgentId, SingleThreadedAgentRuntime -from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage +from autogen_core import AgentId, DropMessage, MessageContext, SingleThreadedAgentRuntime +from autogen_core._well_known_topics import is_rpc_request, is_rpc_response from autogen_core.exceptions import MessageDroppedException from autogen_test_utils import LoopbackAgent, MessageType @pytest.mark.asyncio async def test_intervention_count_messages() -> None: - class DebugInterventionHandler(DefaultInterventionHandler): + class DebugInterventionHandler: def __init__(self) -> None: self.num_messages = 0 - async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType: + async def __call__(self, message: MessageType, message_context: MessageContext) -> MessageType: self.num_messages += 1 return message @@ -21,24 +23,23 @@ async def on_send(self, message: MessageType, *, sender: AgentId | None, recipie loopback = AgentId("name", key="default") runtime.start() - _response = await runtime.send_message(MessageType(), recipient=loopback) + _response = await runtime.send_message(MessageType(), recipient=loopback, timeout=120) await runtime.stop() - assert handler.num_messages == 1 + # 2 since request and response + assert handler.num_messages == 2 loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) assert loopback_agent.num_calls == 1 @pytest.mark.asyncio -async def test_intervention_drop_send() -> None: - class DropSendInterventionHandler(DefaultInterventionHandler): - async def on_send( - self, message: MessageType, *, sender: AgentId | None, recipient: AgentId - ) -> MessageType | type[DropMessage]: +async def test_intervention_drop_rpc_request() -> None: + async def handler(message: Any, message_context: MessageContext) -> Any | type[DropMessage]: + if is_rpc_request(message_context.topic_id.type): return DropMessage + return message - handler = DropSendInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) await LoopbackAgent.register(runtime, "name", LoopbackAgent) @@ -46,7 +47,7 @@ async def on_send( runtime.start() with pytest.raises(MessageDroppedException): - _response = await runtime.send_message(MessageType(), recipient=loopback) + _response = await runtime.send_message(MessageType(), recipient=loopback, timeout=120) await runtime.stop() @@ -55,74 +56,21 @@ async def on_send( @pytest.mark.asyncio -async def test_intervention_drop_response() -> None: - class DropResponseInterventionHandler(DefaultInterventionHandler): - async def on_response( - self, message: MessageType, *, sender: AgentId, recipient: AgentId | None - ) -> MessageType | type[DropMessage]: +async def test_intervention_drop_rpc_esponse() -> None: + async def handler(message: Any, message_context: MessageContext) -> Any | type[DropMessage]: + # Only drop the response and not the request! + if is_rpc_response(message_context.topic_id.type): return DropMessage - handler = DropResponseInterventionHandler() - runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - - await LoopbackAgent.register(runtime, "name", LoopbackAgent) - loopback = AgentId("name", key="default") - runtime.start() - - with pytest.raises(MessageDroppedException): - _response = await runtime.send_message(MessageType(), recipient=loopback) + return message - await runtime.stop() - - -@pytest.mark.asyncio -async def test_intervention_raise_exception_on_send() -> None: - class InterventionException(Exception): - pass - - class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore - async def on_send( - self, message: MessageType, *, sender: AgentId | None, recipient: AgentId - ) -> MessageType | type[DropMessage]: # type: ignore - raise InterventionException - - handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) await LoopbackAgent.register(runtime, "name", LoopbackAgent) loopback = AgentId("name", key="default") runtime.start() - with pytest.raises(InterventionException): - _response = await runtime.send_message(MessageType(), recipient=loopback) - - await runtime.stop() - - long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) - assert long_running_agent.num_calls == 0 - - -@pytest.mark.asyncio -async def test_intervention_raise_exception_on_respond() -> None: - class InterventionException(Exception): - pass - - class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore - async def on_response( - self, message: MessageType, *, sender: AgentId, recipient: AgentId | None - ) -> MessageType | type[DropMessage]: # type: ignore - raise InterventionException - - handler = ExceptionInterventionHandler() - runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - - await LoopbackAgent.register(runtime, "name", LoopbackAgent) - loopback = AgentId("name", key="default") - runtime.start() - with pytest.raises(InterventionException): - _response = await runtime.send_message(MessageType(), recipient=loopback) + with pytest.raises(MessageDroppedException): + _response = await runtime.send_message(MessageType(), recipient=loopback, timeout=120) await runtime.stop() - - long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) - assert long_running_agent.num_calls == 1 diff --git a/python/packages/autogen-core/tests/test_routed_agent.py b/python/packages/autogen-core/tests/test_routed_agent.py index 440c839fa551..ec30b575c078 100644 --- a/python/packages/autogen-core/tests/test_routed_agent.py +++ b/python/packages/autogen-core/tests/test_routed_agent.py @@ -14,6 +14,8 @@ message_handler, rpc, ) +from autogen_core._well_known_topics import is_rpc_request +from autogen_core.exceptions import CantHandleException from autogen_test_utils import LoopbackAgent @@ -31,12 +33,12 @@ def __init__(self) -> None: self.num_calls_rpc = 0 self.num_calls_broadcast = 0 - @message_handler(match=lambda _, ctx: ctx.is_rpc) + @message_handler(match=lambda _, ctx: is_rpc_request(ctx.topic_id.type) is not None) async def on_rpc_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.num_calls_rpc += 1 return message - @message_handler(match=lambda _, ctx: not ctx.is_rpc) + @message_handler(match=lambda _, ctx: is_rpc_request(ctx.topic_id.type) is None) async def on_broadcast_message(self, message: MessageType, ctx: MessageContext) -> None: self.num_calls_broadcast += 1 @@ -70,7 +72,7 @@ async def test_message_handler_router() -> None: # Send an RPC message. runtime.start() - await runtime.send_message(MessageType(), recipient=agent_id) + await runtime.send_message(MessageType(), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=CounterAgent) assert agent.num_calls_broadcast == 1 @@ -113,14 +115,14 @@ async def test_routed_agent_message_matching() -> None: assert agent.handler_two_called is False runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + await runtime.send_message(TestMessage("one"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) assert agent.handler_one_called is True assert agent.handler_two_called is False runtime.start() - await runtime.send_message(TestMessage("two"), recipient=agent_id) + await runtime.send_message(TestMessage("two"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) assert agent.handler_one_called is True @@ -166,7 +168,8 @@ async def test_event() -> None: # Send an RPC message, expect no change. runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + with pytest.raises(CantHandleException): + await runtime.send_message(TestMessage("one"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent) assert agent.num_calls[0] == 1 @@ -198,7 +201,7 @@ async def test_rpc() -> None: # Send an RPC message. runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + await runtime.send_message(TestMessage("one"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) assert agent.num_calls[0] == 1 @@ -206,7 +209,7 @@ async def test_rpc() -> None: # Send another RPC message. runtime.start() - await runtime.send_message(TestMessage("two"), recipient=agent_id) + await runtime.send_message(TestMessage("two"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) assert agent.num_calls[0] == 1 diff --git a/python/packages/autogen-core/tests/test_tool_agent.py b/python/packages/autogen-core/tests/test_tool_agent.py index d0d6dec8915b..98ef8e7e30c5 100644 --- a/python/packages/autogen-core/tests/test_tool_agent.py +++ b/python/packages/autogen-core/tests/test_tool_agent.py @@ -63,23 +63,27 @@ async def test_tool_agent() -> None: assert result == FunctionExecutionResult(call_id="1", content="pass") # Test raise function - with pytest.raises(ToolExecutionException): - await runtime.send_message(FunctionCall(id="2", arguments=json.dumps({"input": "raise"}), name="raise"), agent) + response = await runtime.send_message( + FunctionCall(id="2", arguments=json.dumps({"input": "raise"}), name="raise"), agent + ) + assert isinstance(response, ToolExecutionException) # Test invalid tool name - with pytest.raises(ToolNotFoundException): - await runtime.send_message(FunctionCall(id="3", arguments=json.dumps({"input": "pass"}), name="invalid"), agent) + response = await runtime.send_message( + FunctionCall(id="3", arguments=json.dumps({"input": "pass"}), name="invalid"), agent + ) + assert isinstance(response, ToolNotFoundException) # Test invalid arguments - with pytest.raises(InvalidToolArgumentsException): - await runtime.send_message(FunctionCall(id="3", arguments="invalid json /xd", name="pass"), agent) + response = await runtime.send_message(FunctionCall(id="3", arguments="invalid json /xd", name="pass"), agent) + assert isinstance(response, InvalidToolArgumentsException) # Test sleep and cancel. token = CancellationToken() result_future = runtime.send_message( FunctionCall(id="3", arguments=json.dumps({"input": "sleep"}), name="sleep"), agent, cancellation_token=token ) - token.cancel() + await token.cancel() with pytest.raises(asyncio.CancelledError): await result_future diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index d2bf41ce0d1f..4a66e0872b84 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -52,6 +52,7 @@ SerializationRegistry, ) from autogen_core._telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata +from autogen_core._publish_based_rpc import PublishBasedRpcMixin from google.protobuf import any_pb2 from opentelemetry.trace import TracerProvider from typing_extensions import Self, deprecated @@ -179,7 +180,7 @@ async def recv(self) -> agent_worker_pb2.Message: return await self._recv_queue.get() -class GrpcWorkerAgentRuntime(AgentRuntime): +class GrpcWorkerAgentRuntime(PublishBasedRpcMixin, AgentRuntime): def __init__( self, host_address: str, @@ -239,16 +240,6 @@ async def _run_read_loop(self) -> None: match oneofcase: case "registerAgentTypeRequest" | "addSubscriptionRequest": logger.warning(f"Cant handle {oneofcase}, skipping.") - case "request": - task = asyncio.create_task(self._process_request(message.request)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case "response": - task = asyncio.create_task(self._process_response(message.response)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) case "cloudEvent": # The proto typing doesnt resolve this one cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore @@ -333,51 +324,6 @@ async def _send_message( with self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata): await self._host_connection.send(runtime_message) - async def send_message( - self, - message: Any, - recipient: AgentId, - *, - sender: AgentId | None = None, - cancellation_token: CancellationToken | None = None, - ) -> Any: - if not self._running: - raise ValueError("Runtime must be running when sending message.") - if self._host_connection is None: - raise RuntimeError("Host connection is not set.") - data_type = self._serialization_registry.type_name(message) - with self._trace_helper.trace_block( - "create", recipient, parent=None, extraAttributes={"message_type": data_type} - ): - # create a new future for the result - future = asyncio.get_event_loop().create_future() - request_id = await self._get_new_request_id() - self._pending_requests[request_id] = future - serialized_message = self._serialization_registry.serialize( - message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE - ) - telemetry_metadata = get_telemetry_grpc_metadata() - runtime_message = agent_worker_pb2.Message( - request=agent_worker_pb2.RpcRequest( - request_id=request_id, - target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key), - source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None, - metadata=telemetry_metadata, - payload=agent_worker_pb2.Payload( - data_type=data_type, - data=serialized_message, - data_content_type=JSON_DATA_CONTENT_TYPE, - ), - ) - ) - - # TODO: Find a way to handle timeouts/errors - task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - return await future - async def publish_message( self, message: Any, @@ -477,99 +423,6 @@ async def _get_new_request_id(self) -> str: self._next_request_id += 1 return str(self._next_request_id) - async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None: - assert self._host_connection is not None - recipient = AgentId(request.target.type, request.target.key) - sender: AgentId | None = None - if request.HasField("source"): - sender = AgentId(request.source.type, request.source.key) - logging.info(f"Processing request from {sender} to {recipient}") - else: - logging.info(f"Processing request from unknown source to {recipient}") - - # Deserialize the message. - message = self._serialization_registry.deserialize( - request.payload.data, - type_name=request.payload.data_type, - data_content_type=request.payload.data_content_type, - ) - - # Get the receiving agent and prepare the message context. - rec_agent = await self._get_agent(recipient) - message_context = MessageContext( - sender=sender, - topic_id=None, - is_rpc=True, - cancellation_token=CancellationToken(), - message_id=request.request_id, - ) - - # Call the receiving agent. - try: - with MessageHandlerContext.populate_context(rec_agent.id): - with self._trace_helper.trace_block( - "process", - rec_agent.id, - parent=request.metadata, - attributes={"request_id": request.request_id}, - extraAttributes={"message_type": request.payload.data_type}, - ): - result = await rec_agent.on_message(message, ctx=message_context) - except BaseException as e: - response_message = agent_worker_pb2.Message( - response=agent_worker_pb2.RpcResponse( - request_id=request.request_id, - error=str(e), - metadata=get_telemetry_grpc_metadata(), - ), - ) - # Send the error response. - await self._host_connection.send(response_message) - return - - # Serialize the result. - result_type = self._serialization_registry.type_name(result) - serialized_result = self._serialization_registry.serialize( - result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE - ) - - # Create the response message. - response_message = agent_worker_pb2.Message( - response=agent_worker_pb2.RpcResponse( - request_id=request.request_id, - payload=agent_worker_pb2.Payload( - data_type=result_type, - data=serialized_result, - data_content_type=JSON_DATA_CONTENT_TYPE, - ), - metadata=get_telemetry_grpc_metadata(), - ) - ) - - # Send the response. - await self._host_connection.send(response_message) - - async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None: - with self._trace_helper.trace_block( - "ack", - None, - parent=response.metadata, - attributes={"request_id": response.request_id}, - extraAttributes={"message_type": response.payload.data_type}, - ): - # Deserialize the result. - result = self._serialization_registry.deserialize( - response.payload.data, - type_name=response.payload.data_type, - data_content_type=response.payload.data_content_type, - ) - # Get the future and set the result. - future = self._pending_requests.pop(response.request_id) - if len(response.error) > 0: - future.set_exception(Exception(response.error)) - else: - future.set_result(result) - async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: event_attributes = event.attributes sender: AgentId | None = None @@ -601,16 +454,6 @@ async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: else: raise ValueError(f"Unsupported message content type: {message_content_type}") - # TODO: dont read these values in the runtime - topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else "" - is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST - is_marked_rpc_type = ( - _constants.MESSAGE_KIND_ATTR in event_attributes - and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST - ) - if is_rpc and not is_marked_rpc_type: - warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2) - # Send the message to each recipient. responses: List[Awaitable[Any]] = [] for agent_id in recipients: @@ -619,7 +462,6 @@ async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: message_context = MessageContext( sender=sender, topic_id=topic_id, - is_rpc=is_rpc, cancellation_token=CancellationToken(), message_id=event.id, ) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index 0bb8ae0a8a20..f5f00d7efc59 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -100,18 +100,6 @@ async def _receive_messages( logger.info(f"Received message from client {client_id}: {message}") oneofcase = message.WhichOneof("message") match oneofcase: - case "request": - request: agent_worker_pb2.RpcRequest = message.request - task = asyncio.create_task(self._process_request(request, client_id)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case "response": - response: agent_worker_pb2.RpcResponse = message.response - task = asyncio.create_task(self._process_response(response, client_id)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) case "cloudEvent": # The proto typing doesnt resolve this one event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore @@ -138,43 +126,6 @@ async def _receive_messages( case None: logger.warning("Received empty message") - async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None: - # Deliver the message to a client given the target agent type. - async with self._agent_type_to_client_id_lock: - target_client_id = self._agent_type_to_client_id.get(request.target.type) - if target_client_id is None: - logger.error(f"Agent {request.target.type} not found, failed to deliver message.") - return - target_send_queue = self._send_queues.get(target_client_id) - if target_send_queue is None: - logger.error(f"Client {target_client_id} not found, failed to deliver message.") - return - await target_send_queue.put(agent_worker_pb2.Message(request=request)) - - # Create a future to wait for the response from the target. - future = asyncio.get_event_loop().create_future() - self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future - - # Create a task to wait for the response and send it back to the client. - send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id)) - self._background_tasks.add(send_response_task) - send_response_task.add_done_callback(self._raise_on_exception) - send_response_task.add_done_callback(self._background_tasks.discard) - - async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None: - response = await future - message = agent_worker_pb2.Message(response=response) - send_queue = self._send_queues.get(client_id) - if send_queue is None: - logger.error(f"Client {client_id} not found, failed to send response message.") - return - await send_queue.put(message) - - async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None: - # Setting the result of the future will send the response back to the original sender. - future = self._pending_responses[client_id].pop(response.request_id) - future.set_result(response) - async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: topic_id = TopicId(type=event.type, source=event.source) recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py index b4794f1eaba6..85f47310cdee 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py @@ -16,7 +16,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xad\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x33\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB\x1e\xaa\x02\x1bMicrosoft.AutoGen.Contractsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xdd\x02\n\x07Message\x12\x33\n\ncloudEvent\x18\x01 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x02 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x03 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x04 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x05 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB\x1e\xaa\x02\x1bMicrosoft.AutoGen.Contractsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -24,52 +24,30 @@ if _descriptor._USE_C_DESCRIPTORS == False: _globals['DESCRIPTOR']._options = None _globals['DESCRIPTOR']._serialized_options = b'\252\002\033Microsoft.AutoGen.Contracts' - _globals['_RPCREQUEST_METADATAENTRY']._options = None - _globals['_RPCREQUEST_METADATAENTRY']._serialized_options = b'8\001' - _globals['_RPCRESPONSE_METADATAENTRY']._options = None - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_options = b'8\001' - _globals['_EVENT_METADATAENTRY']._options = None - _globals['_EVENT_METADATAENTRY']._serialized_options = b'8\001' - _globals['_TOPICID']._serialized_start=75 - _globals['_TOPICID']._serialized_end=114 - _globals['_AGENTID']._serialized_start=116 - _globals['_AGENTID']._serialized_end=152 - _globals['_PAYLOAD']._serialized_start=154 - _globals['_PAYLOAD']._serialized_end=223 - _globals['_RPCREQUEST']._serialized_start=226 - _globals['_RPCREQUEST']._serialized_end=491 - _globals['_RPCREQUEST_METADATAENTRY']._serialized_start=433 - _globals['_RPCREQUEST_METADATAENTRY']._serialized_end=480 - _globals['_RPCRESPONSE']._serialized_start=494 - _globals['_RPCRESPONSE']._serialized_end=678 - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=433 - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=480 - _globals['_EVENT']._serialized_start=681 - _globals['_EVENT']._serialized_end=909 - _globals['_EVENT_METADATAENTRY']._serialized_start=433 - _globals['_EVENT_METADATAENTRY']._serialized_end=480 - _globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=911 - _globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=971 - _globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=973 - _globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=1067 - _globals['_TYPESUBSCRIPTION']._serialized_start=1069 - _globals['_TYPESUBSCRIPTION']._serialized_end=1127 - _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=1129 - _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=1200 - _globals['_SUBSCRIPTION']._serialized_start=1203 - _globals['_SUBSCRIPTION']._serialized_end=1353 - _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1355 - _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1443 - _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1445 - _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1537 - _globals['_AGENTSTATE']._serialized_start=1540 - _globals['_AGENTSTATE']._serialized_end=1697 - _globals['_GETSTATERESPONSE']._serialized_start=1699 - _globals['_GETSTATERESPONSE']._serialized_end=1805 - _globals['_SAVESTATERESPONSE']._serialized_start=1807 - _globals['_SAVESTATERESPONSE']._serialized_end=1873 - _globals['_MESSAGE']._serialized_start=1876 - _globals['_MESSAGE']._serialized_end=2305 - _globals['_AGENTRPC']._serialized_start=2308 - _globals['_AGENTRPC']._serialized_end=2486 + _globals['_AGENTID']._serialized_start=75 + _globals['_AGENTID']._serialized_end=111 + _globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=113 + _globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=173 + _globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=175 + _globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=269 + _globals['_TYPESUBSCRIPTION']._serialized_start=271 + _globals['_TYPESUBSCRIPTION']._serialized_end=329 + _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=331 + _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=402 + _globals['_SUBSCRIPTION']._serialized_start=405 + _globals['_SUBSCRIPTION']._serialized_end=555 + _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=557 + _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=645 + _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=647 + _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=739 + _globals['_AGENTSTATE']._serialized_start=742 + _globals['_AGENTSTATE']._serialized_end=899 + _globals['_GETSTATERESPONSE']._serialized_start=901 + _globals['_GETSTATERESPONSE']._serialized_end=1007 + _globals['_SAVESTATERESPONSE']._serialized_start=1009 + _globals['_SAVESTATERESPONSE']._serialized_end=1075 + _globals['_MESSAGE']._serialized_start=1078 + _globals['_MESSAGE']._serialized_end=1427 + _globals['_AGENTRPC']._serialized_start=1430 + _globals['_AGENTRPC']._serialized_end=1608 # @@protoc_insertion_point(module_scope) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi index 79e384ab948b..7c9baa5e9ca7 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi @@ -5,33 +5,13 @@ isort:skip_file import builtins import cloudevent_pb2 -import collections.abc import google.protobuf.any_pb2 import google.protobuf.descriptor -import google.protobuf.internal.containers import google.protobuf.message import typing DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -@typing.final -class TopicId(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TYPE_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - type: builtins.str - source: builtins.str - def __init__( - self, - *, - type: builtins.str = ..., - source: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["source", b"source", "type", b"type"]) -> None: ... - -global___TopicId = TopicId - @typing.final class AgentId(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -50,170 +30,6 @@ class AgentId(google.protobuf.message.Message): global___AgentId = AgentId -@typing.final -class Payload(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - DATA_TYPE_FIELD_NUMBER: builtins.int - DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int - data_type: builtins.str - data_content_type: builtins.str - data: builtins.bytes - def __init__( - self, - *, - data_type: builtins.str = ..., - data_content_type: builtins.str = ..., - data: builtins.bytes = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ... - -global___Payload = Payload - -@typing.final -class RpcRequest(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - REQUEST_ID_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - TARGET_FIELD_NUMBER: builtins.int - METHOD_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - request_id: builtins.str - method: builtins.str - @property - def source(self) -> global___AgentId: ... - @property - def target(self) -> global___AgentId: ... - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - request_id: builtins.str = ..., - source: global___AgentId | None = ..., - target: global___AgentId | None = ..., - method: builtins.str = ..., - payload: global___Payload | None = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ... - -global___RpcRequest = RpcRequest - -@typing.final -class RpcResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - REQUEST_ID_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - ERROR_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - request_id: builtins.str - error: builtins.str - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - request_id: builtins.str = ..., - payload: global___Payload | None = ..., - error: builtins.str = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ... - -global___RpcResponse = RpcResponse - -@typing.final -class Event(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - TOPIC_TYPE_FIELD_NUMBER: builtins.int - TOPIC_SOURCE_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - topic_type: builtins.str - topic_source: builtins.str - @property - def source(self) -> global___AgentId: ... - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - topic_type: builtins.str = ..., - topic_source: builtins.str = ..., - source: global___AgentId | None = ..., - payload: global___Payload | None = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "payload", b"payload", "source", b"source", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ... - -global___Event = Event - @typing.final class RegisterAgentTypeRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -435,18 +251,12 @@ global___SaveStateResponse = SaveStateResponse class Message(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - REQUEST_FIELD_NUMBER: builtins.int - RESPONSE_FIELD_NUMBER: builtins.int CLOUDEVENT_FIELD_NUMBER: builtins.int REGISTERAGENTTYPEREQUEST_FIELD_NUMBER: builtins.int REGISTERAGENTTYPERESPONSE_FIELD_NUMBER: builtins.int ADDSUBSCRIPTIONREQUEST_FIELD_NUMBER: builtins.int ADDSUBSCRIPTIONRESPONSE_FIELD_NUMBER: builtins.int @property - def request(self) -> global___RpcRequest: ... - @property - def response(self) -> global___RpcResponse: ... - @property def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... @property def registerAgentTypeRequest(self) -> global___RegisterAgentTypeRequest: ... @@ -459,16 +269,14 @@ class Message(google.protobuf.message.Message): def __init__( self, *, - request: global___RpcRequest | None = ..., - response: global___RpcResponse | None = ..., cloudEvent: cloudevent_pb2.CloudEvent | None = ..., registerAgentTypeRequest: global___RegisterAgentTypeRequest | None = ..., registerAgentTypeResponse: global___RegisterAgentTypeResponse | None = ..., addSubscriptionRequest: global___AddSubscriptionRequest | None = ..., addSubscriptionResponse: global___AddSubscriptionResponse | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ... + def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ... global___Message = Message diff --git a/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py b/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py index b9a7d4ceeb41..16df4711fa56 100644 --- a/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py +++ b/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py @@ -216,27 +216,22 @@ async def test_web_surfer_oai() -> None: ) ), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="Please scroll down.", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="Please scroll up.", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="When was it founded?", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="What's this page about?", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.stop_when_idle()