Skip to content

Commit

Permalink
Declarative BaseChat Agents (#5055)
Browse files Browse the repository at this point in the history
* v1, make assistant agent declarative

* make head tail context declarative

* update and formatting

* update assistant, format updates

* make websurfer declarative

* update formatting

* move declarative docs to advanced section

* remove tools until implemented

* minor updates to termination conditions

* update docs
  • Loading branch information
victordibia authored Jan 17, 2025
1 parent 1f22a7b commit c2a43e8
Show file tree
Hide file tree
Showing 17 changed files with 524 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Sequence,
)

from autogen_core import CancellationToken, FunctionCall
from autogen_core import CancellationToken, Component, ComponentModel, FunctionCall
from autogen_core.memory import Memory
from autogen_core.model_context import (
ChatCompletionContext,
Expand All @@ -28,6 +28,8 @@
UserMessage,
)
from autogen_core.tools import FunctionTool, Tool
from pydantic import BaseModel
from typing_extensions import Self

from .. import EVENT_LOGGER_NAME
from ..base import Handoff as HandoffBase
Expand All @@ -49,7 +51,21 @@
event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class AssistantAgent(BaseChatAgent):
class AssistantAgentConfig(BaseModel):
"""The declarative configuration for the assistant agent."""

name: str
model_client: ComponentModel
# tools: List[Any] | None = None # TBD
handoffs: List[HandoffBase | str] | None = None
model_context: ComponentModel | None = None
description: str
system_message: str | None = None
reflect_on_tool_use: bool
tool_call_summary_format: str


class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
"""An agent that provides assistance with tool use.
The :meth:`on_messages` returns a :class:`~autogen_agentchat.base.Response`
Expand Down Expand Up @@ -229,6 +245,9 @@ async def main() -> None:
See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning#beta-limitations>`_ for more details.
"""

component_config_schema = AssistantAgentConfig
component_provider_override = "autogen_agentchat.agents.AssistantAgent"

def __init__(
self,
name: str,
Expand Down Expand Up @@ -462,3 +481,40 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
assistant_agent_state = AssistantAgentState.model_validate(state)
# Load the model context state.
await self._model_context.load_state(assistant_agent_state.llm_context)

def _to_config(self) -> AssistantAgentConfig:
"""Convert the assistant agent to a declarative config."""

# raise an error if tools is not empty until it is implemented
# TBD : Implement serializing tools and remove this check.
if self._tools and len(self._tools) > 0:
raise NotImplementedError("Serializing tools is not implemented yet.")

return AssistantAgentConfig(
name=self.name,
model_client=self._model_client.dump_component(),
# tools=[], # TBD
handoffs=list(self._handoffs.values()),
model_context=self._model_context.dump_component(),
description=self.description,
system_message=self._system_messages[0].content
if self._system_messages and isinstance(self._system_messages[0].content, str)
else None,
reflect_on_tool_use=self._reflect_on_tool_use,
tool_call_summary_format=self._tool_call_summary_format,
)

@classmethod
def _from_config(cls, config: AssistantAgentConfig) -> Self:
"""Create an assistant agent from a declarative config."""
return cls(
name=config.name,
model_client=ChatCompletionClient.load_component(config.model_client),
# tools=[], # TBD
handoffs=config.handoffs,
model_context=None,
description=config.description,
system_message=config.system_message,
reflect_on_tool_use=config.reflect_on_tool_use,
tool_call_summary_format=config.tool_call_summary_format,
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence

from autogen_core import CancellationToken
from autogen_core import CancellationToken, ComponentBase
from pydantic import BaseModel

from ..base import ChatAgent, Response, TaskResult
from ..messages import (
Expand All @@ -13,7 +14,7 @@
from ..state import BaseState


class BaseChatAgent(ChatAgent, ABC):
class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
"""Base class for a chat agent.
This abstract class provides a base implementation for a :class:`ChatAgent`.
Expand All @@ -35,6 +36,8 @@ class BaseChatAgent(ChatAgent, ABC):
This design principle must be followed when creating a new agent.
"""

component_type = "agent"

def __init__(self, name: str, description: str) -> None:
self._name = name
if self._name.isidentifier() is False:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from inspect import iscoroutinefunction
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast

from autogen_core import CancellationToken
from autogen_core import CancellationToken, Component
from pydantic import BaseModel
from typing_extensions import Self

from ..base import Response
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
Expand All @@ -24,7 +26,15 @@ async def cancellable_input(prompt: str, cancellation_token: Optional[Cancellati
return await task


class UserProxyAgent(BaseChatAgent):
class UserProxyAgentConfig(BaseModel):
"""Declarative configuration for the UserProxyAgent."""

name: str
description: str = "A human user"
input_func: str | None = None


class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
"""An agent that can represent a human user through an input function.
This agent can be used to represent a human user in a chat system by providing a custom input function.
Expand Down Expand Up @@ -109,6 +119,10 @@ async def cancellable_user_agent():
print(f"BaseException: {e}")
"""

component_type = "agent"
component_provider_override = "autogen_agentchat.agents.UserProxyAgent"
component_config_schema = UserProxyAgentConfig

class InputRequestContext:
def __init__(self) -> None:
raise RuntimeError(
Expand Down Expand Up @@ -218,3 +232,11 @@ async def on_messages_stream(
async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
"""Reset agent state."""
pass

def _to_config(self) -> UserProxyAgentConfig:
# TODO: Add ability to serialie input_func
return UserProxyAgentConfig(name=self.name, description=self.description, input_func=None)

@classmethod
def _from_config(cls, config: UserProxyAgentConfig) -> Self:
return cls(name=config.name, description=config.description, input_func=None)
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ async def main() -> None:
"""

component_type = "termination"
# component_config_schema = BaseModel # type: ignore

@property
@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class StopMessageTerminationConfig(BaseModel):
class StopMessageTermination(TerminationCondition, Component[StopMessageTerminationConfig]):
"""Terminate the conversation if a StopMessage is received."""

component_type = "termination"
component_config_schema = StopMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.StopMessageTermination"

Expand Down Expand Up @@ -58,7 +57,6 @@ class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminatio
max_messages: The maximum number of messages allowed in the conversation.
"""

component_type = "termination"
component_config_schema = MaxMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.MaxMessageTermination"

Expand Down Expand Up @@ -104,7 +102,6 @@ class TextMentionTermination(TerminationCondition, Component[TextMentionTerminat
text: The text to look for in the messages.
"""

component_type = "termination"
component_config_schema = TextMentionTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TextMentionTermination"

Expand Down Expand Up @@ -159,7 +156,6 @@ class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminatio
ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided.
"""

component_type = "termination"
component_config_schema = TokenUsageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TokenUsageTermination"

Expand Down Expand Up @@ -234,7 +230,6 @@ class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfi
target (str): The target of the handoff message.
"""

component_type = "termination"
component_config_schema = HandoffTerminationConfig
component_provider_override = "autogen_agentchat.conditions.HandoffTermination"

Expand Down Expand Up @@ -279,7 +274,6 @@ class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfi
timeout_seconds: The maximum duration in seconds before terminating the conversation.
"""

component_type = "termination"
component_config_schema = TimeoutTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TimeoutTermination"

Expand Down Expand Up @@ -339,7 +333,6 @@ class ExternalTermination(TerminationCondition, Component[ExternalTerminationCon
"""

component_type = "termination"
component_config_schema = ExternalTerminationConfig
component_provider_override = "autogen_agentchat.conditions.ExternalTermination"

Expand Down Expand Up @@ -389,7 +382,6 @@ class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminat
TerminatedException: If the termination condition has already been reached.
"""

component_type = "termination"
component_config_schema = SourceMatchTerminationConfig
component_provider_override = "autogen_agentchat.conditions.SourceMatchTermination"

Expand Down
48 changes: 48 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,51 @@ class BadMemory:

assert not isinstance(BadMemory(), Memory)
assert isinstance(ListMemory(), Memory)


@pytest.mark.asyncio
async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_context = BufferedChatCompletionContext(buffer_size=2)
agent = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_context=model_context,
)

agent_config = agent.dump_component()
assert agent_config.provider == "autogen_agentchat.agents.AssistantAgent"

agent2 = AssistantAgent.load_component(agent_config)
assert agent2.name == agent.name

agent3 = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_context=model_context,
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
)
with pytest.raises(NotImplementedError):
agent3.dump_component()
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
TokenUsageTermination,
)
from autogen_core import ComponentLoader, ComponentModel
from autogen_core.model_context import (
BufferedChatCompletionContext,
HeadAndTailChatCompletionContext,
UnboundedChatCompletionContext,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -92,3 +97,35 @@ async def test_termination_declarative() -> None:
# Test loading complex composition
loaded_composite = ComponentLoader.load_component(composite_config)
assert isinstance(loaded_composite, AndTerminationCondition)


@pytest.mark.asyncio
async def test_chat_completion_context_declarative() -> None:
unbounded_context = UnboundedChatCompletionContext()
buffered_context = BufferedChatCompletionContext(buffer_size=5)
head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2)

# Test serialization
unbounded_config = unbounded_context.dump_component()
assert unbounded_config.provider == "autogen_core.model_context.UnboundedChatCompletionContext"

buffered_config = buffered_context.dump_component()
assert buffered_config.provider == "autogen_core.model_context.BufferedChatCompletionContext"
assert buffered_config.config["buffer_size"] == 5

head_tail_config = head_tail_context.dump_component()
assert head_tail_config.provider == "autogen_core.model_context.HeadAndTailChatCompletionContext"
assert head_tail_config.config["head_size"] == 3
assert head_tail_config.config["tail_size"] == 2

# Test deserialization
loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext)
assert isinstance(loaded_unbounded, UnboundedChatCompletionContext)

loaded_buffered = ComponentLoader.load_component(buffered_config, BufferedChatCompletionContext)

assert isinstance(loaded_buffered, BufferedChatCompletionContext)

loaded_head_tail = ComponentLoader.load_component(head_tail_config, HeadAndTailChatCompletionContext)

assert isinstance(loaded_head_tail, HeadAndTailChatCompletionContext)
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ Sample code and use cases

How to migrate from AutoGen 0.2.x to 0.4.x.
:::

:::{grid-item-card} {fas}`save;pst-color-primary` Serialize Components
:link: ./serialize-components.html

Serialize and deserialize components
:::

:::{grid-item-card} {fas}`brain;pst-color-primary` Memory
:link: ./memory.html

Add memory capabilities to your agents
:::
::::

```{toctree}
Expand All @@ -91,8 +103,7 @@ tutorial/human-in-the-loop
tutorial/termination
tutorial/custom-agents
tutorial/state
tutorial/declarative
tutorial/memory
```

```{toctree}
Expand All @@ -103,6 +114,8 @@ tutorial/memory
selector-group-chat
swarm
magentic-one
memory
serialize-components
```

```{toctree}
Expand Down
Loading

0 comments on commit c2a43e8

Please sign in to comment.