Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Declarative BaseChat Agents #5055

Merged
merged 18 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Sequence,
)

from autogen_core import CancellationToken, FunctionCall
from autogen_core import CancellationToken, Component, ComponentModel, FunctionCall
from autogen_core.memory import Memory
from autogen_core.model_context import (
ChatCompletionContext,
Expand All @@ -28,6 +28,8 @@
UserMessage,
)
from autogen_core.tools import FunctionTool, Tool
from pydantic import BaseModel
from typing_extensions import Self

from .. import EVENT_LOGGER_NAME
from ..base import Handoff as HandoffBase
Expand All @@ -49,7 +51,21 @@
event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class AssistantAgent(BaseChatAgent):
class AssistantAgentConfig(BaseModel):
"""The declarative configuration for the assistant agent."""

name: str
model_client: ComponentModel
# tools: List[Any] | None = None # TBD
handoffs: List[HandoffBase | str] | None = None
model_context: ComponentModel | None = None
victordibia marked this conversation as resolved.
Show resolved Hide resolved
description: str
system_message: str | None = None
reflect_on_tool_use: bool
tool_call_summary_format: str


class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
"""An agent that provides assistance with tool use.

The :meth:`on_messages` returns a :class:`~autogen_agentchat.base.Response`
Expand Down Expand Up @@ -229,6 +245,9 @@ async def main() -> None:
See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning#beta-limitations>`_ for more details.
"""

component_config_schema = AssistantAgentConfig
component_provider_override = "autogen_agentchat.agents.AssistantAgent"

def __init__(
self,
name: str,
Expand Down Expand Up @@ -462,3 +481,40 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
assistant_agent_state = AssistantAgentState.model_validate(state)
# Load the model context state.
await self._model_context.load_state(assistant_agent_state.llm_context)

def _to_config(self) -> AssistantAgentConfig:
"""Convert the assistant agent to a declarative config."""

# raise an error if tools is not empty until it is implemented
# TBD : Implement serializing tools and remove this check.
if self._tools and len(self._tools) > 0:
raise NotImplementedError("Serializing tools is not implemented yet.")

return AssistantAgentConfig(
name=self.name,
model_client=self._model_client.dump_component(),
# tools=[], # TBD
handoffs=list(self._handoffs.values()),
model_context=self._model_context.dump_component(),
description=self.description,
system_message=self._system_messages[0].content
if self._system_messages and isinstance(self._system_messages[0].content, str)
else None,
reflect_on_tool_use=self._reflect_on_tool_use,
tool_call_summary_format=self._tool_call_summary_format,
)

@classmethod
def _from_config(cls, config: AssistantAgentConfig) -> Self:
"""Create an assistant agent from a declarative config."""
return cls(
name=config.name,
model_client=ChatCompletionClient.load_component(config.model_client),
victordibia marked this conversation as resolved.
Show resolved Hide resolved
# tools=[], # TBD
handoffs=config.handoffs,
model_context=None,
description=config.description,
system_message=config.system_message,
reflect_on_tool_use=config.reflect_on_tool_use,
tool_call_summary_format=config.tool_call_summary_format,
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence

from autogen_core import CancellationToken
from autogen_core import CancellationToken, ComponentBase
from pydantic import BaseModel

from ..base import ChatAgent, Response, TaskResult
from ..messages import (
Expand All @@ -13,7 +14,7 @@
from ..state import BaseState


class BaseChatAgent(ChatAgent, ABC):
class BaseChatAgent(ChatAgent, ABC, ComponentBase[BaseModel]):
"""Base class for a chat agent.

This abstract class provides a base implementation for a :class:`ChatAgent`.
Expand All @@ -35,6 +36,8 @@ class BaseChatAgent(ChatAgent, ABC):
This design principle must be followed when creating a new agent.
"""

component_type = "agent"

def __init__(self, name: str, description: str) -> None:
self._name = name
if self._name.isidentifier() is False:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from inspect import iscoroutinefunction
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast

from autogen_core import CancellationToken
from autogen_core import CancellationToken, Component
from pydantic import BaseModel
from typing_extensions import Self

from ..base import Response
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
Expand All @@ -24,7 +26,15 @@
return await task


class UserProxyAgent(BaseChatAgent):
class UserProxyAgentConfig(BaseModel):
"""Declarative configuration for the UserProxyAgent."""

name: str
description: str = "A human user"
input_func: str | None = None
victordibia marked this conversation as resolved.
Show resolved Hide resolved


class UserProxyAgent(BaseChatAgent, Component[UserProxyAgentConfig]):
"""An agent that can represent a human user through an input function.

This agent can be used to represent a human user in a chat system by providing a custom input function.
Expand Down Expand Up @@ -109,6 +119,10 @@
print(f"BaseException: {e}")
"""

component_type = "agent"
component_provider_override = "autogen_agentchat.agents.UserProxyAgent"
component_config_schema = UserProxyAgentConfig

class InputRequestContext:
def __init__(self) -> None:
raise RuntimeError(
Expand Down Expand Up @@ -218,3 +232,11 @@
async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
"""Reset agent state."""
pass

def _to_config(self) -> UserProxyAgentConfig:
# TODO: Add ability to serialie input_func
return UserProxyAgentConfig(name=self.name, description=self.description, input_func=None)

Check warning on line 238 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py#L238

Added line #L238 was not covered by tests

@classmethod
def _from_config(cls, config: UserProxyAgentConfig) -> Self:
return cls(name=config.name, description=config.description, input_func=None)

Check warning on line 242 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py#L242

Added line #L242 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ async def main() -> None:
"""

component_type = "termination"
# component_config_schema = BaseModel # type: ignore

@property
@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class StopMessageTerminationConfig(BaseModel):
class StopMessageTermination(TerminationCondition, Component[StopMessageTerminationConfig]):
"""Terminate the conversation if a StopMessage is received."""

component_type = "termination"
component_config_schema = StopMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.StopMessageTermination"

Expand Down Expand Up @@ -58,7 +57,6 @@ class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminatio
max_messages: The maximum number of messages allowed in the conversation.
"""

component_type = "termination"
component_config_schema = MaxMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.MaxMessageTermination"

Expand Down Expand Up @@ -104,7 +102,6 @@ class TextMentionTermination(TerminationCondition, Component[TextMentionTerminat
text: The text to look for in the messages.
"""

component_type = "termination"
component_config_schema = TextMentionTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TextMentionTermination"

Expand Down Expand Up @@ -159,7 +156,6 @@ class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminatio
ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided.
"""

component_type = "termination"
component_config_schema = TokenUsageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TokenUsageTermination"

Expand Down Expand Up @@ -234,7 +230,6 @@ class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfi
target (str): The target of the handoff message.
"""

component_type = "termination"
component_config_schema = HandoffTerminationConfig
component_provider_override = "autogen_agentchat.conditions.HandoffTermination"

Expand Down Expand Up @@ -279,7 +274,6 @@ class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfi
timeout_seconds: The maximum duration in seconds before terminating the conversation.
"""

component_type = "termination"
component_config_schema = TimeoutTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TimeoutTermination"

Expand Down Expand Up @@ -339,7 +333,6 @@ class ExternalTermination(TerminationCondition, Component[ExternalTerminationCon

"""

component_type = "termination"
component_config_schema = ExternalTerminationConfig
component_provider_override = "autogen_agentchat.conditions.ExternalTermination"

Expand Down Expand Up @@ -389,7 +382,6 @@ class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminat
TerminatedException: If the termination condition has already been reached.
"""

component_type = "termination"
component_config_schema = SourceMatchTerminationConfig
component_provider_override = "autogen_agentchat.conditions.SourceMatchTermination"

Expand Down
48 changes: 48 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,51 @@ class BadMemory:

assert not isinstance(BadMemory(), Memory)
assert isinstance(ListMemory(), Memory)


@pytest.mark.asyncio
async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_context = BufferedChatCompletionContext(buffer_size=2)
agent = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_context=model_context,
)

agent_config = agent.dump_component()
assert agent_config.provider == "autogen_agentchat.agents.AssistantAgent"

agent2 = AssistantAgent.load_component(agent_config)
assert agent2.name == agent.name

agent3 = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_context=model_context,
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
)
with pytest.raises(NotImplementedError):
agent3.dump_component()
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
TokenUsageTermination,
)
from autogen_core import ComponentLoader, ComponentModel
from autogen_core.model_context import (
BufferedChatCompletionContext,
HeadAndTailChatCompletionContext,
UnboundedChatCompletionContext,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -92,3 +97,35 @@ async def test_termination_declarative() -> None:
# Test loading complex composition
loaded_composite = ComponentLoader.load_component(composite_config)
assert isinstance(loaded_composite, AndTerminationCondition)


@pytest.mark.asyncio
async def test_chat_completion_context_declarative() -> None:
unbounded_context = UnboundedChatCompletionContext()
buffered_context = BufferedChatCompletionContext(buffer_size=5)
head_tail_context = HeadAndTailChatCompletionContext(head_size=3, tail_size=2)

# Test serialization
unbounded_config = unbounded_context.dump_component()
assert unbounded_config.provider == "autogen_core.model_context.UnboundedChatCompletionContext"

buffered_config = buffered_context.dump_component()
assert buffered_config.provider == "autogen_core.model_context.BufferedChatCompletionContext"
assert buffered_config.config["buffer_size"] == 5

head_tail_config = head_tail_context.dump_component()
assert head_tail_config.provider == "autogen_core.model_context.HeadAndTailChatCompletionContext"
assert head_tail_config.config["head_size"] == 3
assert head_tail_config.config["tail_size"] == 2

# Test deserialization
loaded_unbounded = ComponentLoader.load_component(unbounded_config, UnboundedChatCompletionContext)
assert isinstance(loaded_unbounded, UnboundedChatCompletionContext)

loaded_buffered = ComponentLoader.load_component(buffered_config, BufferedChatCompletionContext)

assert isinstance(loaded_buffered, BufferedChatCompletionContext)

loaded_head_tail = ComponentLoader.load_component(head_tail_config, HeadAndTailChatCompletionContext)

assert isinstance(loaded_head_tail, HeadAndTailChatCompletionContext)
victordibia marked this conversation as resolved.
Show resolved Hide resolved
File renamed without changes.
Loading
Loading