diff --git a/bindings/ceylon/ceylon/llm/agent.py b/bindings/ceylon/ceylon/llm/agent.py
index fec2983..4fe573e 100644
--- a/bindings/ceylon/ceylon/llm/agent.py
+++ b/bindings/ceylon/ceylon/llm/agent.py
@@ -5,12 +5,19 @@
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
-from typing import Dict, Any
+from typing import Dict, Any, Optional, List, Sequence
from pydantic import BaseModel
from ceylon.llm.models import Model, ModelSettings, ModelMessage
-from ceylon.llm.models.support.messages import MessageRole, TextPart
+from ceylon.llm.models.support.messages import (
+ MessageRole,
+ TextPart,
+ ToolCallPart,
+ ToolReturnPart,
+ ModelMessagePart
+)
+from ceylon.llm.models.support.tools import ToolDefinition
from ceylon.processor.agent import ProcessWorker
from ceylon.processor.data import ProcessRequest
@@ -30,6 +37,8 @@ class LLMConfig(BaseModel):
retry_attempts: int = 3
retry_delay: float = 1.0
timeout: float = 30.0
+ tools: Optional[Sequence[ToolDefinition]] = None
+ parallel_tool_calls: Optional[int] = None
class Config:
arbitrary_types_allowed = True
@@ -58,33 +67,175 @@ def __init__(
self.response_cache: Dict[str, LLMResponse] = {}
self.processing_lock = asyncio.Lock()
- # Initialize model context with settings
+ # Initialize model context with settings and tools
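+ # Tools registered on the context are advertised to the model (e.g. in the Ollama prompt format)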
self.model_context = self.llm_model.create_context(
settings=ModelSettings(
temperature=config.temperature,
- max_tokens=config.max_tokens
- )
+ max_tokens=config.max_tokens,
+ parallel_tool_calls=config.parallel_tool_calls
+ ),
+ tools=config.tools or []
)
- async def _processor(self, request: ProcessRequest, time: int):
+ async def _process_tool_calls(
+ self,
+ message_parts: List[ModelMessagePart]
+ ) -> List[ModelMessagePart]:
+ """Process any tool calls in the message parts and return updated parts."""
+ processed_parts = []
+
+ for part in message_parts:
+ if isinstance(part, ToolCallPart):
+ try:
+ # Find the corresponding tool
+ tool = next(
+ (t for t in self.config.tools or []
+ if t.name == part.tool_name),
+ None
+ )
+
+ if tool:
+ # Execute the tool
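+ # (tool.function is awaited, so it is expected to be an async callable, like the calculator in the example app)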
+ result = await tool.function(**part.args)
+
+ # Add the tool return
+ processed_parts.append(
+ ToolReturnPart(
+ tool_name=part.tool_name,
+ content=result
+ )
+ )
+ else:
+ # Tool not found - add error message
+ processed_parts.append(
+ TextPart(
+ text=f"Error: Tool '{part.tool_name}' not found"
+ )
+ )
+ except Exception as e:
+ # Handle tool execution error
+ processed_parts.append(
+ TextPart(
+ text=f"Error executing tool '{part.tool_name}': {str(e)}"
+ )
+ )
+ else:
+ processed_parts.append(part)
+
+ return processed_parts
+
+ async def _process_conversation(
+ self,
+ messages: List[ModelMessage]
+ ) -> List[ModelMessage]:
+ """Process a conversation, handling tool calls as needed."""
+ processed_messages = []
+
+ for message in messages:
+ if message.role == MessageRole.ASSISTANT:
+ # Process any tool calls in assistant messages
+ processed_parts = await self._process_tool_calls(message.parts)
+ processed_messages.append(
+ ModelMessage(
+ role=message.role,
+ parts=processed_parts
+ )
+ )
+ else:
+ processed_messages.append(message)
+
+ return processed_messages
+
+ def _parse_request_data(self, data: Any) -> str:
+ """Parse the request data into a string format."""
+ if isinstance(data, str):
+ return data
+ elif isinstance(data, dict):
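+ # Dict payloads are expected to carry the prompt under the "request" key (as in the example app)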
+ return data.get("request", str(data))
+ else:
+ return str(data)
+
+ async def _processor(self, request: ProcessRequest, time: int) -> tuple[str, Dict[str, Any]]:
+ """Process a request using the LLM model."""
+ # Initialize conversation with system prompt
message_list = [
ModelMessage(
role=MessageRole.SYSTEM,
- parts=[
- TextPart(text=self.config.system_prompt)
- ]
- ),
+ parts=[TextPart(text=self.config.system_prompt)]
+ )
+ ]
+
+ # Add user message
+ user_text = self._parse_request_data(request.data)
+ message_list.append(
ModelMessage(
role=MessageRole.USER,
- parts=[
- TextPart(text=request.data)
- ]
+ parts=[TextPart(text=user_text)]
)
- ]
+ )
+
+ # Track the complete conversation
+ complete_conversation = message_list.copy()
+ final_response = None
+ metadata = {}
+
+ for attempt in range(self.config.retry_attempts):
+ try:
+ # Get model response
+ response, usage = await self.llm_model.request(
+ message_list,
+ self.model_context
+ )
+
+ # Add model response to conversation
+ assistant_message = ModelMessage(
+ role=MessageRole.ASSISTANT,
+ parts=response.parts
+ )
+ complete_conversation.append(assistant_message)
+
+ # Process any tool calls
+ complete_conversation = await self._process_conversation(
+ complete_conversation
+ )
+
+ # Extract final text response
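+ # (tool calls are not part of the returned text; they are surfaced via metadata["tools_used"] below)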
+ final_text_parts = [
+ part.text for part in response.parts
+ if isinstance(part, TextPart)
+ ]
+ final_response = " ".join(final_text_parts)
+
+ # Update metadata
+ metadata.update({
+ "usage": {
+ "requests": usage.requests,
+ "request_tokens": usage.request_tokens,
+ "response_tokens": usage.response_tokens,
+ "total_tokens": usage.total_tokens
+ },
+ "attempt": attempt + 1,
+ "tools_used": [
+ part.tool_name for part in response.parts
+ if isinstance(part, ToolCallPart)
+ ]
+ })
+
+ # If we got a response, break the retry loop
+ if final_response:
+ break
+
+ except Exception as e:
+ if attempt == self.config.retry_attempts - 1:
+ raise
+ await asyncio.sleep(self.config.retry_delay)
+
+ if not final_response:
+ raise ValueError("No valid response generated")
- return await self.llm_model.request(message_list, self.model_context)
+ return final_response, metadata
async def stop(self) -> None:
if self.llm_model:
await self.llm_model.close()
- await super().stop()
+ await super().stop()
\ No newline at end of file
diff --git a/bindings/ceylon/ceylon/llm/models/__init__.py b/bindings/ceylon/ceylon/llm/models/__init__.py
index cd45527..8515553 100644
--- a/bindings/ceylon/ceylon/llm/models/__init__.py
+++ b/bindings/ceylon/ceylon/llm/models/__init__.py
@@ -16,6 +16,27 @@
from ceylon.llm.models.support.settings import ModelSettings
from ceylon.llm.models.support.tools import ToolDefinition
from ceylon.llm.models.support.usage import Usage, UsageLimits
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from types import TracebackType
+from typing import AsyncIterator, Optional, Sequence, Type, Any
+import re
+import json
+
+from ceylon.llm.models.support.http import AsyncHTTPClient, cached_async_http_client
+from ceylon.llm.models.support.messages import (
+ ModelMessage,
+ ModelResponse,
+ StreamedResponse,
+ MessageRole,
+ TextPart,
+ ToolCallPart,
+ ToolReturnPart,
+ ModelMessagePart
+)
+from ceylon.llm.models.support.settings import ModelSettings
+from ceylon.llm.models.support.tools import ToolDefinition
+from ceylon.llm.models.support.usage import Usage, UsageLimits
@dataclass
@@ -27,7 +48,13 @@ class ModelContext:
class Model(ABC):
- """Base class for all language model implementations"""
+ """Base class for all language model implementations with tool support"""
+
+ # Regex pattern for extracting tool calls - can be overridden by subclasses
+ TOOL_CALL_PATTERN = re.compile(
+ r'<tool_call>(?P<tool_json>.*?)</tool_call>',
+ re.DOTALL
+ )
def __init__(
self,
@@ -147,11 +174,108 @@ def _check_usage_limits(self, usage: Usage, limits: UsageLimits) -> None:
raise UsageLimitExceeded(
f"Request limit {limits.request_limit} exceeded"
)
- if limits.total_tokens and usage.total_tokens >= limits.total_tokens:
+ if limits.request_tokens_limit and usage.request_tokens >= limits.request_tokens_limit:
+ raise UsageLimitExceeded(
+ f"Request tokens limit {limits.request_tokens_limit} exceeded"
+ )
+ if limits.response_tokens_limit and usage.response_tokens >= limits.response_tokens_limit:
raise UsageLimitExceeded(
- f"Total token limit {limits.total_tokens} exceeded"
+ f"Response tokens limit {limits.response_tokens_limit} exceeded"
)
+ def _format_tool_descriptions(self, tools: Sequence[ToolDefinition]) -> str:
+ """Format tool descriptions for system message.
+
+ Args:
+ tools: Sequence of tool definitions
+
+ Returns:
+ Formatted tool descriptions string
+ """
+ if not tools:
+ return ""
+
+ tool_descriptions = []
+ for tool in tools:
+ desc = f"- {tool.name}: {tool.description}\n"
+ desc += f" Parameters: {json.dumps(tool.parameters_json_schema)}"
+ tool_descriptions.append(desc)
+
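+ # Models are expected to reply with <tool_call>{...}</tool_call> blocks,
+ # which TOOL_CALL_PATTERN extracts back out of the response text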
+ return (
+ "You have access to the following tools:\n\n"
+ f"{chr(10).join(tool_descriptions)}\n\n"
+ "To use a tool, respond with XML tags like this:\n"
+ "{\"tool_name\": \"tool_name\", \"args\": {\"arg1\": \"value1\"}}\n"
+ "Wait for the tool result before continuing."
+ )
+
+ def _parse_tool_call(self, match: re.Match) -> Optional[ToolCallPart]:
+ """Parse a tool call match into a ToolCallPart.
+
+ Args:
+ match: Regex match object containing tool call JSON
+
+ Returns:
+ ToolCallPart if valid, None if invalid
+ """
+ try:
+ tool_data = json.loads(match.group('tool_json'))
+ if isinstance(tool_data, dict) and 'tool_name' in tool_data and 'args' in tool_data:
+ return ToolCallPart(
+ tool_name=tool_data['tool_name'],
+ args=tool_data['args']
+ )
+ except (json.JSONDecodeError, KeyError):
+ pass
+ return None
+
+ def _parse_response(self, text: str) -> list[ModelMessagePart]:
+ """Parse response text into message parts.
+
+ Args:
+ text: Raw response text from model
+
+ Returns:
+ List of ModelMessagePart objects
+ """
+ parts = []
+ current_text = []
+ last_end = 0
+
+ # Find all tool calls in the response
+ for match in self.TOOL_CALL_PATTERN.finditer(text):
+ # Add any text before the tool call
+ if match.start() > last_end:
+ prefix_text = text[last_end:match.start()].strip()
+ if prefix_text:
+ current_text.append(prefix_text)
+
+ # Parse and add the tool call
+ tool_call = self._parse_tool_call(match)
+ if tool_call:
+ # If we have accumulated text, add it first
+ if current_text:
+ parts.append(TextPart(text=' '.join(current_text)))
+ current_text = []
+ parts.append(tool_call)
+ else:
+ # If tool call parsing fails, treat it as regular text
+ current_text.append(match.group(0))
+
+ last_end = match.end()
+
+ # Add any remaining text after the last tool call
+ if last_end < len(text):
+ remaining = text[last_end:].strip()
+ if remaining:
+ current_text.append(remaining)
+
+ # Add any accumulated text as final part
+ if current_text:
+ parts.append(TextPart(text=' '.join(current_text)))
+
+ return parts
+
class UsageLimitExceeded(Exception):
"""Raised when usage limits are exceeded"""
@@ -178,10 +302,12 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5,
The default timeouts match those of OpenAI,
see .
"""
+
def factory() -> httpx.AsyncClient:
return httpx.AsyncClient(
headers={"User-Agent": get_user_agent()},
timeout=timeout,
base_url=base_url
)
+
return factory
diff --git a/bindings/ceylon/ceylon/llm/models/ollama.py b/bindings/ceylon/ceylon/llm/models/ollama.py
index 1b55a36..5bb0e20 100644
--- a/bindings/ceylon/ceylon/llm/models/ollama.py
+++ b/bindings/ceylon/ceylon/llm/models/ollama.py
@@ -10,7 +10,8 @@
ModelResponse,
StreamedResponse,
TextPart,
- ToolCallPart
+ ToolCallPart,
+ ToolReturnPart
)
from ceylon.llm.models.support.usage import Usage
@@ -26,7 +27,7 @@ def __init__(
**kwargs: Any
):
"""Initialize the Ollama model.
-
+
Args:
model_name: Name of the Ollama model to use
base_url: Base URL for the Ollama API
@@ -41,52 +42,65 @@ async def close(self) -> None:
"""Close the HTTP client"""
await self.client().aclose()
- def _format_messages(self, messages: Sequence[ModelMessage]) -> str:
+ def _format_messages(self, messages: Sequence[ModelMessage], context: ModelContext) -> str:
"""Format messages for Ollama API.
- Ollama's generate endpoint expects a prompt string, so we format
- the messages into a conversation format.
+ Args:
+ messages: Messages to format
+ context: Context containing tools and settings
+
+ Returns:
+ Formatted prompt string
"""
formatted_parts = []
+ # Add tool definitions if available
+ if context.tools:
+ formatted_parts.append(f"System: {self._format_tool_descriptions(context.tools)}")
+
+ # Format conversation history
for message in messages:
- if message.role == MessageRole.SYSTEM:
- formatted_parts.append(f"System: {self._get_text_content(message)}")
- elif message.role == MessageRole.USER:
- formatted_parts.append(f"User: {self._get_text_content(message)}")
- elif message.role == MessageRole.ASSISTANT:
- formatted_parts.append(f"Assistant: {self._get_text_content(message)}")
- elif message.role == MessageRole.TOOL:
- # Include tool name in response if available
- tool_name = self._get_tool_name(message)
- formatted_parts.append(f"Tool ({tool_name}): {self._get_text_content(message)}")
-
- # Add a final Assistant: prompt to indicate it's the model's turn
+ prefix = {
+ MessageRole.SYSTEM: "System",
+ MessageRole.USER: "User",
+ MessageRole.ASSISTANT: "Assistant",
+ MessageRole.TOOL: "Tool"
+ }.get(message.role, "Unknown")
+
+ content = []
+ for part in message.parts:
+ if isinstance(part, TextPart):
+ content.append(part.text)
+ elif isinstance(part, ToolCallPart):
+ tool_call = {
+ "tool_name": part.tool_name,
+ "args": part.args
+ }
+ content.append(
+ f"{json.dumps(tool_call)}"
+ )
+ elif isinstance(part, ToolReturnPart):
+ content.append(
+ f"Result from {part.tool_name}: {json.dumps(part.content)}"
+ )
+
+ formatted_parts.append(f"{prefix}: {' '.join(content)}")
+
+ # Add final prompt for assistant
formatted_parts.append("Assistant:")
return "\n\n".join(formatted_parts)
- def _get_text_content(self, message: ModelMessage) -> str:
- """Extract text content from message parts"""
- parts = []
- for part in message.parts:
- if isinstance(part, TextPart):
- parts.append(part.text)
- return " ".join(parts)
-
- def _get_tool_name(self, message: ModelMessage) -> str:
- """Get tool name from message if present"""
- for part in message.parts:
- if isinstance(part, ToolCallPart):
- return part.tool_name
- return "unknown"
-
- def _prepare_request_data(self, messages: Sequence[ModelMessage], context: ModelContext, stream: bool = False) -> \
- dict[str, Any]:
+ def _prepare_request_data(
+ self,
+ messages: Sequence[ModelMessage],
+ context: ModelContext,
+ stream: bool = False
+ ) -> dict[str, Any]:
"""Prepare the request data for Ollama API"""
data = {
"model": self.model_name,
- "prompt": self._format_messages(messages),
+ "prompt": self._format_messages(messages, context),
"stream": stream
}
@@ -113,7 +127,7 @@ async def request(
Args:
messages: Sequence of messages to send
- context: Context containing settings
+ context: Context containing settings and tools
Returns:
Tuple of (model response, usage statistics)
@@ -130,8 +144,10 @@ async def request(
response.raise_for_status()
result = response.json()
- # Extract response and usage
+ # Parse response and create usage stats
response_text = result.get("response", "")
+ response_parts = self._parse_response(response_text)
+
usage = Usage(
request_tokens=result.get("prompt_eval_count", 0),
response_tokens=result.get("eval_count", 0),
@@ -142,7 +158,7 @@ async def request(
requests=1
)
- return ModelResponse(parts=[TextPart(text=response_text)]), usage
+ return ModelResponse(parts=response_parts), usage
async def request_stream(
self,
@@ -153,7 +169,7 @@ async def request_stream(
Args:
messages: Sequence of messages to send
- context: Context containing settings
+ context: Context containing settings and tools
Yields:
StreamedResponse objects containing response chunks
@@ -162,7 +178,10 @@ async def request_stream(
data = self._prepare_request_data(messages, context, stream=True)
- # Make streaming request
+ # Track state across chunks
+ current_text = []
+ total_usage = Usage(requests=1)
+
async with self.client().stream(
"POST",
"/api/generate",
@@ -170,9 +189,6 @@ async def request_stream(
) as response:
response.raise_for_status()
- # Track usage across chunks
- total_usage = Usage(requests=1)
-
async for line in response.aiter_lines():
if not line.strip():
continue
@@ -182,8 +198,41 @@ async def request_stream(
except json.JSONDecodeError:
continue
- # Extract chunk text and update usage
if "response" in chunk:
+ chunk_text = chunk["response"]
+ current_text.append(chunk_text)
+
+ # Try to parse complete tool calls
+ accumulated_text = ''.join(current_text)
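+ # (a <tool_call> block may be split across streamed chunks, so match against the full buffer)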
+ for match in self.TOOL_CALL_PATTERN.finditer(accumulated_text):
+ tool_call = self._parse_tool_call(match)
+ if tool_call:
+ # Yield any text before the tool call
+ prefix = accumulated_text[:match.start()].strip()
+ if prefix:
+ yield StreamedResponse(
+ delta=TextPart(text=prefix),
+ usage=None
+ )
+
+ # Yield the tool call
+ yield StreamedResponse(
+ delta=tool_call,
+ usage=None
+ )
+
+ # Keep any remaining text
+ current_text = [accumulated_text[match.end():]]
+ break
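+ # Only the first complete tool call is handled per chunk; the rest of the buffer is re-scanned when the next chunk arrives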
+ else:
+ # If no tool call found and we've accumulated enough text
+ if len(accumulated_text) > 100:
+ yield StreamedResponse(
+ delta=TextPart(text=accumulated_text),
+ usage=None
+ )
+ current_text = []
+
# Update usage stats
chunk_usage = Usage(
request_tokens=chunk.get("prompt_eval_count", 0),
@@ -195,14 +244,14 @@ async def request_stream(
)
total_usage.add(chunk_usage)
- # Yield response chunk
- yield StreamedResponse(
- delta=TextPart(text=chunk["response"]),
- usage=chunk_usage
- )
-
- # Handle done message
if chunk.get("done", False):
+ # Yield any remaining text
+ remaining_text = ''.join(current_text).strip()
+ if remaining_text:
+ yield StreamedResponse(
+ delta=TextPart(text=remaining_text),
+ usage=total_usage
+ )
break
@classmethod
@@ -218,4 +267,4 @@ async def list_models(cls, base_url: str = "http://localhost:11434") -> list[dic
async with httpx.AsyncClient(base_url=base_url) as client:
response = await client.get("/api/tags")
response.raise_for_status()
- return response.json().get("models", [])
+ return response.json().get("models", [])
\ No newline at end of file
diff --git a/bindings/ceylon/examples/llm/task_llm_app.py b/bindings/ceylon/examples/llm/task_llm_app.py
index 2be90dd..c8ff697 100644
--- a/bindings/ceylon/examples/llm/task_llm_app.py
+++ b/bindings/ceylon/examples/llm/task_llm_app.py
@@ -2,50 +2,85 @@
from ceylon.llm.agent import LLMConfig, LLMAgent
from ceylon.llm.models.ollama import OllamaModel
-from ceylon.processor.agent import ProcessRequest, ProcessResponse
-from ceylon.processor.playground import ProcessPlayGround
+from ceylon.llm.models.support.tools import ToolDefinition
from ceylon.task.data import Task, TaskResult
from ceylon.task.playground import TaskProcessingPlayground
+async def calculate(expression: str) -> float | str:
+ """Simple calculator function that evaluates mathematical expressions."""
+ print(f"Calculating: {expression}")
+ try:
+ # Only evaluate expressions restricted to basic arithmetic characters
+ allowed_chars = set("0123456789+-*/(). ")
+ if not all(c in allowed_chars for c in expression):
+ raise ValueError("Invalid characters in expression")
+ return eval(expression)
+ except Exception as e:
+ return f"Error: {str(e)}"
+
+
async def main():
- # Create playground and worker
- playground = TaskProcessingPlayground()
- llm_model = OllamaModel(
- model_name="deepseek-r1:8b",
- base_url="http://localhost:11434"
- )
+ # Define calculator tool
+ tools = [
+ ToolDefinition(
+ name="calculator",
+ description="Performs basic mathematical calculations",
+ parameters_json_schema={
+ "type": "object",
+ "properties": {
+ "expression": {
+ "type": "string",
+ "description": "Mathematical expression to evaluate (e.g., '2 + 2')"
+ }
+ },
+ "required": ["expression"]
+ },
+ function=calculate
+ )
+ ]
- # Configure LLM agent
- llm_config = LLMConfig(
- system_prompt=(
- "You are an expert content writer specializing in technology topics. "
- ),
- temperature=0.7,
- max_tokens=1000,
- retry_attempts=1
- )
+ # Create playground
+ playground = TaskProcessingPlayground()
- # Create LLM agent
+ # Create LLM agent with calculator tool
llm_agent = LLMAgent(
- name="writer_1",
- llm_model=llm_model,
- config=llm_config,
- role="writer"
+ name="math_assistant",
+ llm_model=OllamaModel(
+ model_name="deepseek-r1:8b",
+ base_url="http://localhost:11434"
+ ),
+ config=LLMConfig(
+ system_prompt=(
+ "You are a math assistant. Use the calculator tool to perform calculations. "
+ "When asked to perform calculations, always use the calculator tool for accuracy."
+ ),
+ temperature=0.7,
+ max_tokens=1000,
+ tools=tools
+ ),
+ role="math_assistant"
)
# Start the system
async with playground.play(workers=[llm_agent]) as active_playground:
- active_playground: TaskProcessingPlayground = active_playground
- # Send some test requests
- response: TaskResult = await active_playground.add_and_execute_task(
- Task(
- name="Process Data 1",
- processor="writer",
- input_data={"request": "A Simple title for a blog post about AI"}
+ # Test some calculations
+ test_questions = [
+ "What is 123 * 456?",
+ "Calculate 15.5 + 27.3",
+ "What is (100 - 20) / 2?"
+ ]
+
+ for question in test_questions:
+ response: TaskResult = await active_playground.add_and_execute_task(
+ Task(
+ name="Calculate",
+ processor="math_assistant",
+ input_data={"request": question}
+ )
)
- )
- print(f"Response received: {response.output}")
+ print(f"\nQuestion: {question}")
+ print(f"Response: {response.output}")
if __name__ == "__main__":