
Commit

Enable tool use
dewmal committed Feb 5, 2025
1 parent d23e4f9 commit b1e8dcf
Showing 4 changed files with 463 additions and 102 deletions.
183 changes: 167 additions & 16 deletions bindings/ceylon/ceylon/llm/agent.py
@@ -5,12 +5,19 @@
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any, Optional, List, Sequence

from pydantic import BaseModel

from ceylon.llm.models import Model, ModelSettings, ModelMessage
from ceylon.llm.models.support.messages import (
    MessageRole,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    ModelMessagePart
)
from ceylon.llm.models.support.tools import ToolDefinition
from ceylon.processor.agent import ProcessWorker
from ceylon.processor.data import ProcessRequest

@@ -30,6 +37,8 @@ class LLMConfig(BaseModel):
    retry_attempts: int = 3
    retry_delay: float = 1.0
    timeout: float = 30.0
    tools: Optional[Sequence[ToolDefinition]] = None
    parallel_tool_calls: Optional[int] = None

    class Config:
        arbitrary_types_allowed = True
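
For context, a minimal sketch of how the new tools and parallel_tool_calls fields might be populated. The get_current_time tool is hypothetical, and ToolDefinition is assumed to take the fields this commit reads from it (name, description, parameters_json_schema, function); other LLMConfig fields are elided except system_prompt, which the agent code below references:

from ceylon.llm.models.support.tools import ToolDefinition

async def get_current_time(timezone: str) -> str:
    # Hypothetical tool body; a real one would resolve the timezone.
    return f"12:00 ({timezone})"

config = LLMConfig(
    system_prompt="You are a helpful assistant.",
    tools=[
        ToolDefinition(
            name="get_current_time",
            description="Get the current time for a timezone",
            parameters_json_schema={
                "type": "object",
                "properties": {"timezone": {"type": "string"}},
                "required": ["timezone"]
            },
            function=get_current_time
        )
    ],
    parallel_tool_calls=1
)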
@@ -58,33 +67,175 @@ def __init__(
        self.response_cache: Dict[str, LLMResponse] = {}
        self.processing_lock = asyncio.Lock()

        # Initialize model context with settings and tools
        self.model_context = self.llm_model.create_context(
            settings=ModelSettings(
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                parallel_tool_calls=config.parallel_tool_calls
            ),
            tools=config.tools or []
        )

    async def _process_tool_calls(
        self,
        message_parts: List[ModelMessagePart]
    ) -> List[ModelMessagePart]:
        """Process any tool calls in the message parts and return updated parts."""
        processed_parts = []

        for part in message_parts:
            if isinstance(part, ToolCallPart):
                try:
                    # Find the corresponding tool
                    tool = next(
                        (t for t in self.config.tools or []
                         if t.name == part.tool_name),
                        None
                    )

                    if tool:
                        # Execute the tool
                        result = await tool.function(**part.args)

                        # Add the tool return
                        processed_parts.append(
                            ToolReturnPart(
                                tool_name=part.tool_name,
                                content=result
                            )
                        )
                    else:
                        # Tool not found - add error message
                        processed_parts.append(
                            TextPart(
                                text=f"Error: Tool '{part.tool_name}' not found"
                            )
                        )
                except Exception as e:
                    # Handle tool execution error
                    processed_parts.append(
                        TextPart(
                            text=f"Error executing tool '{part.tool_name}': {str(e)}"
                        )
                    )
            else:
                processed_parts.append(part)

        return processed_parts

    async def _process_conversation(
        self,
        messages: List[ModelMessage]
    ) -> List[ModelMessage]:
        """Process a conversation, handling tool calls as needed."""
        processed_messages = []

        for message in messages:
            if message.role == MessageRole.ASSISTANT:
                # Process any tool calls in assistant messages
                processed_parts = await self._process_tool_calls(message.parts)
                processed_messages.append(
                    ModelMessage(
                        role=message.role,
                        parts=processed_parts
                    )
                )
            else:
                processed_messages.append(message)

        return processed_messages

    def _parse_request_data(self, data: Any) -> str:
        """Parse the request data into a string format."""
        if isinstance(data, str):
            return data
        elif isinstance(data, dict):
            return data.get("request", str(data))
        else:
            return str(data)
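
For example, the helper accepts a bare string, a dict-style payload, or anything else coercible to str (agent here stands for any instance of this worker):

agent._parse_request_data("ping")                 # -> "ping"
agent._parse_request_data({"request": "ping"})    # -> "ping"
agent._parse_request_data({"other": 1})           # -> "{'other': 1}"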

    async def _processor(self, request: ProcessRequest, time: int) -> tuple[str, Dict[str, Any]]:
        """Process a request using the LLM model."""
        # Initialize conversation with system prompt
        message_list = [
            ModelMessage(
                role=MessageRole.SYSTEM,
                parts=[TextPart(text=self.config.system_prompt)]
            )
        ]

        # Add user message
        user_text = self._parse_request_data(request.data)
        message_list.append(
            ModelMessage(
                role=MessageRole.USER,
                parts=[TextPart(text=user_text)]
            )
        )

        # Track the complete conversation
        complete_conversation = message_list.copy()
        final_response = None
        metadata = {}

        for attempt in range(self.config.retry_attempts):
            try:
                # Get model response
                response, usage = await self.llm_model.request(
                    message_list,
                    self.model_context
                )

                # Add model response to conversation
                assistant_message = ModelMessage(
                    role=MessageRole.ASSISTANT,
                    parts=response.parts
                )
                complete_conversation.append(assistant_message)

                # Process any tool calls
                complete_conversation = await self._process_conversation(
                    complete_conversation
                )

                # Extract final text response
                final_text_parts = [
                    part.text for part in response.parts
                    if isinstance(part, TextPart)
                ]
                final_response = " ".join(final_text_parts)

                # Update metadata
                metadata.update({
                    "usage": {
                        "requests": usage.requests,
                        "request_tokens": usage.request_tokens,
                        "response_tokens": usage.response_tokens,
                        "total_tokens": usage.total_tokens
                    },
                    "attempt": attempt + 1,
                    "tools_used": [
                        part.tool_name for part in response.parts
                        if isinstance(part, ToolCallPart)
                    ]
                })

                # If we got a response, break the retry loop
                if final_response:
                    break

            except Exception as e:
                if attempt == self.config.retry_attempts - 1:
                    raise
                await asyncio.sleep(self.config.retry_delay)

        if not final_response:
            raise ValueError("No valid response generated")

        return final_response, metadata

    async def stop(self) -> None:
        if self.llm_model:
            await self.llm_model.close()
        await super().stop()
132 changes: 129 additions & 3 deletions bindings/ceylon/ceylon/llm/models/__init__.py
@@ -16,6 +16,27 @@
from ceylon.llm.models.support.settings import ModelSettings
from ceylon.llm.models.support.tools import ToolDefinition
from ceylon.llm.models.support.usage import Usage, UsageLimits
from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import TracebackType
from typing import AsyncIterator, Optional, Sequence, Type, Any
import re
import json

from ceylon.llm.models.support.http import AsyncHTTPClient, cached_async_http_client
from ceylon.llm.models.support.messages import (
    ModelMessage,
    ModelResponse,
    StreamedResponse,
    MessageRole,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    ModelMessagePart
)
from ceylon.llm.models.support.settings import ModelSettings
from ceylon.llm.models.support.tools import ToolDefinition
from ceylon.llm.models.support.usage import Usage, UsageLimits


@dataclass
@@ -27,7 +48,13 @@ class ModelContext:


class Model(ABC):
"""Base class for all language model implementations"""
"""Base class for all language model implementations with tool support"""

# Regex pattern for extracting tool calls - can be overridden by subclasses
TOOL_CALL_PATTERN = re.compile(
r'<tool_call>(?P<tool_json>.*?)</tool_call>',
re.DOTALL
)

    def __init__(
        self,
@@ -147,11 +174,108 @@ def _check_usage_limits(self, usage: Usage, limits: UsageLimits) -> None:
            raise UsageLimitExceeded(
                f"Request limit {limits.request_limit} exceeded"
            )
        if limits.request_tokens_limit and usage.request_tokens >= limits.request_tokens_limit:
            raise UsageLimitExceeded(
                f"Request tokens limit {limits.request_tokens_limit} exceeded"
            )
        if limits.response_tokens_limit and usage.response_tokens >= limits.response_tokens_limit:
            raise UsageLimitExceeded(
                f"Response tokens limit {limits.response_tokens_limit} exceeded"
            )

    def _format_tool_descriptions(self, tools: Sequence[ToolDefinition]) -> str:
        """Format tool descriptions for system message.

        Args:
            tools: Sequence of tool definitions

        Returns:
            Formatted tool descriptions string
        """
        if not tools:
            return ""

        tool_descriptions = []
        for tool in tools:
            desc = f"- {tool.name}: {tool.description}\n"
            desc += f"  Parameters: {json.dumps(tool.parameters_json_schema)}"
            tool_descriptions.append(desc)

        return (
            "You have access to the following tools:\n\n"
            f"{chr(10).join(tool_descriptions)}\n\n"
            "To use a tool, respond with XML tags like this:\n"
            "<tool_call>{\"tool_name\": \"tool_name\", \"args\": {\"arg1\": \"value1\"}}</tool_call>\n"
            "Wait for the tool result before continuing."
        )
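
For illustration, with the hypothetical get_current_time tool from earlier, this method would produce a system-message fragment along these lines (reconstructed from the template above):

You have access to the following tools:

- get_current_time: Get the current time for a timezone
  Parameters: {"type": "object", "properties": {"timezone": {"type": "string"}}, "required": ["timezone"]}

To use a tool, respond with XML tags like this:
<tool_call>{"tool_name": "tool_name", "args": {"arg1": "value1"}}</tool_call>
Wait for the tool result before continuing.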

    def _parse_tool_call(self, match: re.Match) -> Optional[ToolCallPart]:
        """Parse a tool call match into a ToolCallPart.

        Args:
            match: Regex match object containing tool call JSON

        Returns:
            ToolCallPart if valid, None if invalid
        """
        try:
            tool_data = json.loads(match.group('tool_json'))
            if isinstance(tool_data, dict) and 'tool_name' in tool_data and 'args' in tool_data:
                return ToolCallPart(
                    tool_name=tool_data['tool_name'],
                    args=tool_data['args']
                )
        except (json.JSONDecodeError, KeyError):
            pass
        return None

    def _parse_response(self, text: str) -> list[ModelMessagePart]:
        """Parse response text into message parts.

        Args:
            text: Raw response text from model

        Returns:
            List of ModelMessagePart objects
        """
        parts = []
        current_text = []
        last_end = 0

        # Find all tool calls in the response
        for match in self.TOOL_CALL_PATTERN.finditer(text):
            # Add any text before the tool call
            if match.start() > last_end:
                prefix_text = text[last_end:match.start()].strip()
                if prefix_text:
                    current_text.append(prefix_text)

            # Parse and add the tool call
            tool_call = self._parse_tool_call(match)
            if tool_call:
                # If we have accumulated text, add it first
                if current_text:
                    parts.append(TextPart(text=' '.join(current_text)))
                    current_text = []
                parts.append(tool_call)
            else:
                # If tool call parsing fails, treat it as regular text
                current_text.append(match.group(0))

            last_end = match.end()

        # Add any remaining text after the last tool call
        if last_end < len(text):
            remaining = text[last_end:].strip()
            if remaining:
                current_text.append(remaining)

        # Add any accumulated text as final part
        if current_text:
            parts.append(TextPart(text=' '.join(current_text)))

        return parts
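
A quick sketch of the expected parsing behaviour, assuming model is an instance of a concrete Model subclass:

raw = (
    'Let me check. '
    '<tool_call>{"tool_name": "get_current_time", "args": {"timezone": "UTC"}}</tool_call>'
)
parts = model._parse_response(raw)
# -> [TextPart(text='Let me check.'),
#     ToolCallPart(tool_name='get_current_time', args={'timezone': 'UTC'})]
# Malformed JSON inside <tool_call> tags is kept as plain text instead.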


class UsageLimitExceeded(Exception):
"""Raised when usage limits are exceeded"""
Expand All @@ -178,10 +302,12 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5,
    The default timeouts match those of OpenAI,
    see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
    """

    def factory() -> httpx.AsyncClient:
        return httpx.AsyncClient(
            headers={"User-Agent": get_user_agent()},
            timeout=timeout,
            base_url=base_url
        )

    return factory