diff --git a/bindings/ceylon/ceylon/llm/agent.py b/bindings/ceylon/ceylon/llm/agent.py
index fec2983..4fe573e 100644
--- a/bindings/ceylon/ceylon/llm/agent.py
+++ b/bindings/ceylon/ceylon/llm/agent.py
@@ -5,12 +5,19 @@
 import asyncio
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Dict, Any
+from typing import Dict, Any, Optional, List, Sequence
 
 from pydantic import BaseModel
 
 from ceylon.llm.models import Model, ModelSettings, ModelMessage
-from ceylon.llm.models.support.messages import MessageRole, TextPart
+from ceylon.llm.models.support.messages import (
+    MessageRole,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    ModelMessagePart
+)
+from ceylon.llm.models.support.tools import ToolDefinition
 from ceylon.processor.agent import ProcessWorker
 from ceylon.processor.data import ProcessRequest
 
@@ -30,6 +37,8 @@ class LLMConfig(BaseModel):
     retry_attempts: int = 3
     retry_delay: float = 1.0
     timeout: float = 30.0
+    tools: Optional[Sequence[ToolDefinition]] = None
+    parallel_tool_calls: Optional[int] = None
 
     class Config:
         arbitrary_types_allowed = True
@@ -58,33 +67,175 @@ def __init__(
         self.response_cache: Dict[str, LLMResponse] = {}
         self.processing_lock = asyncio.Lock()
 
-        # Initialize model context with settings
+        # Initialize model context with settings and tools
         self.model_context = self.llm_model.create_context(
             settings=ModelSettings(
                 temperature=config.temperature,
-                max_tokens=config.max_tokens
-            )
+                max_tokens=config.max_tokens,
+                parallel_tool_calls=config.parallel_tool_calls
+            ),
+            tools=config.tools or []
         )
 
-    async def _processor(self, request: ProcessRequest, time: int):
+    async def _process_tool_calls(
+        self,
+        message_parts: List[ModelMessagePart]
+    ) -> List[ModelMessagePart]:
+        """Process any tool calls in the message parts and return updated parts."""
+        processed_parts = []
+
+        for part in message_parts:
+            if isinstance(part, ToolCallPart):
+                try:
+                    # Find the corresponding tool
+                    tool = next(
+                        (t for t in self.config.tools or []
+                         if t.name == part.tool_name),
+                        None
+                    )
+
+                    if tool:
+                        # Execute the tool
+                        result = await tool.function(**part.args)
+
+                        # Add the tool return
+                        processed_parts.append(
+                            ToolReturnPart(
+                                tool_name=part.tool_name,
+                                content=result
+                            )
+                        )
+                    else:
+                        # Tool not found - add error message
+                        processed_parts.append(
+                            TextPart(
+                                text=f"Error: Tool '{part.tool_name}' not found"
+                            )
+                        )
+                except Exception as e:
+                    # Handle tool execution error
+                    processed_parts.append(
+                        TextPart(
+                            text=f"Error executing tool '{part.tool_name}': {str(e)}"
+                        )
+                    )
+            else:
+                processed_parts.append(part)
+
+        return processed_parts
+
+    async def _process_conversation(
+        self,
+        messages: List[ModelMessage]
+    ) -> List[ModelMessage]:
+        """Process a conversation, handling tool calls as needed."""
+        processed_messages = []
+
+        for message in messages:
+            if message.role == MessageRole.ASSISTANT:
+                # Process any tool calls in assistant messages
+                processed_parts = await self._process_tool_calls(message.parts)
+                processed_messages.append(
+                    ModelMessage(
+                        role=message.role,
+                        parts=processed_parts
+                    )
+                )
+            else:
+                processed_messages.append(message)
+
+        return processed_messages
+
+    def _parse_request_data(self, data: Any) -> str:
+        """Parse the request data into a string format."""
+        if isinstance(data, str):
+            return data
+        elif isinstance(data, dict):
+            return data.get("request", str(data))
+        else:
+            return str(data)
+
+    async def _processor(self, request: ProcessRequest, time: int) -> tuple[str, Dict[str, Any]]:
+        """Process a request using the LLM model."""
+        # Initialize conversation with system prompt
         message_list = [
             ModelMessage(
                 role=MessageRole.SYSTEM,
-                parts=[
-                    TextPart(text=self.config.system_prompt)
-                ]
-            ),
+                parts=[TextPart(text=self.config.system_prompt)]
+            )
+        ]
+
+        # Add user message
+        user_text = self._parse_request_data(request.data)
+        message_list.append(
             ModelMessage(
                 role=MessageRole.USER,
-                parts=[
-                    TextPart(text=request.data)
-                ]
+                parts=[TextPart(text=user_text)]
             )
-        ]
+        )
+
+        # Track the complete conversation
+        complete_conversation = message_list.copy()
+        final_response = None
+        metadata = {}
+
+        for attempt in range(self.config.retry_attempts):
+            try:
+                # Get model response
+                response, usage = await self.llm_model.request(
+                    message_list,
+                    self.model_context
+                )
+
+                # Add model response to conversation
+                assistant_message = ModelMessage(
+                    role=MessageRole.ASSISTANT,
+                    parts=response.parts
+                )
+                complete_conversation.append(assistant_message)
+
+                # Process any tool calls
+                complete_conversation = await self._process_conversation(
+                    complete_conversation
+                )
+
+                # Extract final text response
+                final_text_parts = [
+                    part.text for part in response.parts
+                    if isinstance(part, TextPart)
+                ]
+                final_response = " ".join(final_text_parts)
+
+                # Update metadata
+                metadata.update({
+                    "usage": {
+                        "requests": usage.requests,
+                        "request_tokens": usage.request_tokens,
+                        "response_tokens": usage.response_tokens,
+                        "total_tokens": usage.total_tokens
+                    },
+                    "attempt": attempt + 1,
+                    "tools_used": [
+                        part.tool_name for part in response.parts
+                        if isinstance(part, ToolCallPart)
+                    ]
+                })
+
+                # If we got a response, break the retry loop
+                if final_response:
+                    break
+
+            except Exception as e:
+                if attempt == self.config.retry_attempts - 1:
+                    raise
+                await asyncio.sleep(self.config.retry_delay)
+
+        if not final_response:
+            raise ValueError("No valid response generated")
 
-        return await self.llm_model.request(message_list, self.model_context)
+        return final_response, metadata
 
     async def stop(self) -> None:
         if self.llm_model:
             await self.llm_model.close()
-        await super().stop()
+        await super().stop()
\ No newline at end of file
diff --git a/bindings/ceylon/ceylon/llm/models/__init__.py b/bindings/ceylon/ceylon/llm/models/__init__.py
index cd45527..8515553 100644
--- a/bindings/ceylon/ceylon/llm/models/__init__.py
+++ b/bindings/ceylon/ceylon/llm/models/__init__.py
@@ -16,6 +16,27 @@
 from ceylon.llm.models.support.settings import ModelSettings
 from ceylon.llm.models.support.tools import ToolDefinition
 from ceylon.llm.models.support.usage import Usage, UsageLimits
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from types import TracebackType
+from typing import AsyncIterator, Optional, Sequence, Type, Any
+import re
+import json
+
+from ceylon.llm.models.support.http import AsyncHTTPClient, cached_async_http_client
+from ceylon.llm.models.support.messages import (
+    ModelMessage,
+    ModelResponse,
+    StreamedResponse,
+    MessageRole,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    ModelMessagePart
+)
+from ceylon.llm.models.support.settings import ModelSettings
+from ceylon.llm.models.support.tools import ToolDefinition
+from ceylon.llm.models.support.usage import Usage, UsageLimits
 
 
 @dataclass
@@ -27,7 +48,13 @@ class ModelContext:
 
 
 class Model(ABC):
-    """Base class for all language model implementations"""
+    """Base class for all language model implementations with tool support"""
+
+    # Regex pattern for extracting tool calls - can be overridden by subclasses
+    TOOL_CALL_PATTERN = re.compile(
+        r'<tool_call>(?P<tool_json>.*?)</tool_call>',
+        re.DOTALL
+    )
 
     def __init__(
             self,
@@ -147,11 +174,108 @@ def _check_usage_limits(self, usage: Usage, limits: UsageLimits) -> None:
             raise UsageLimitExceeded(
                 f"Request limit {limits.request_limit} exceeded"
             )
-        if limits.total_tokens and usage.total_tokens >= limits.total_tokens:
+        if limits.request_tokens_limit and usage.request_tokens >= limits.request_tokens_limit:
+            raise UsageLimitExceeded(
+                f"Request tokens limit {limits.request_tokens_limit} exceeded"
+            )
+        if limits.response_tokens_limit and usage.response_tokens >= limits.response_tokens_limit:
             raise UsageLimitExceeded(
-                f"Total token limit {limits.total_tokens} exceeded"
+                f"Response tokens limit {limits.response_tokens_limit} exceeded"
             )
 
+    def _format_tool_descriptions(self, tools: Sequence[ToolDefinition]) -> str:
+        """Format tool descriptions for system message.
+
+        Args:
+            tools: Sequence of tool definitions
+
+        Returns:
+            Formatted tool descriptions string
+        """
+        if not tools:
+            return ""
+
+        tool_descriptions = []
+        for tool in tools:
+            desc = f"- {tool.name}: {tool.description}\n"
+            desc += f"  Parameters: {json.dumps(tool.parameters_json_schema)}"
+            tool_descriptions.append(desc)
+
+        return (
+            "You have access to the following tools:\n\n"
+            f"{chr(10).join(tool_descriptions)}\n\n"
+            "To use a tool, respond with XML tags like this:\n"
+            "<tool_call>{\"tool_name\": \"tool_name\", \"args\": {\"arg1\": \"value1\"}}</tool_call>\n"
+            "Wait for the tool result before continuing."
+        )
+
+    def _parse_tool_call(self, match: re.Match) -> Optional[ToolCallPart]:
+        """Parse a tool call match into a ToolCallPart.
+
+        Args:
+            match: Regex match object containing tool call JSON
+
+        Returns:
+            ToolCallPart if valid, None if invalid
+        """
+        try:
+            tool_data = json.loads(match.group('tool_json'))
+            if isinstance(tool_data, dict) and 'tool_name' in tool_data and 'args' in tool_data:
+                return ToolCallPart(
+                    tool_name=tool_data['tool_name'],
+                    args=tool_data['args']
+                )
+        except (json.JSONDecodeError, KeyError):
+            pass
+        return None
+
+    def _parse_response(self, text: str) -> list[ModelMessagePart]:
+        """Parse response text into message parts.
+
+        Args:
+            text: Raw response text from model
+
+        Returns:
+            List of ModelMessagePart objects
+        """
+        parts = []
+        current_text = []
+        last_end = 0
+
+        # Find all tool calls in the response
+        for match in self.TOOL_CALL_PATTERN.finditer(text):
+            # Add any text before the tool call
+            if match.start() > last_end:
+                prefix_text = text[last_end:match.start()].strip()
+                if prefix_text:
+                    current_text.append(prefix_text)
+
+            # Parse and add the tool call
+            tool_call = self._parse_tool_call(match)
+            if tool_call:
+                # If we have accumulated text, add it first
+                if current_text:
+                    parts.append(TextPart(text=' '.join(current_text)))
+                    current_text = []
+                parts.append(tool_call)
+            else:
+                # If tool call parsing fails, treat it as regular text
+                current_text.append(match.group(0))
+
+            last_end = match.end()
+
+        # Add any remaining text after the last tool call
+        if last_end < len(text):
+            remaining = text[last_end:].strip()
+            if remaining:
+                current_text.append(remaining)
+
+        # Add any accumulated text as final part
+        if current_text:
+            parts.append(TextPart(text=' '.join(current_text)))
+
+        return parts
+
 
 class UsageLimitExceeded(Exception):
     """Raised when usage limits are exceeded"""
@@ -178,10 +302,12 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5,
     The default timeouts match those of OpenAI, see .
""" + def factory() -> httpx.AsyncClient: return httpx.AsyncClient( headers={"User-Agent": get_user_agent()}, timeout=timeout, base_url=base_url ) + return factory diff --git a/bindings/ceylon/ceylon/llm/models/ollama.py b/bindings/ceylon/ceylon/llm/models/ollama.py index 1b55a36..5bb0e20 100644 --- a/bindings/ceylon/ceylon/llm/models/ollama.py +++ b/bindings/ceylon/ceylon/llm/models/ollama.py @@ -10,7 +10,8 @@ ModelResponse, StreamedResponse, TextPart, - ToolCallPart + ToolCallPart, + ToolReturnPart ) from ceylon.llm.models.support.usage import Usage @@ -26,7 +27,7 @@ def __init__( **kwargs: Any ): """Initialize the Ollama model. - + Args: model_name: Name of the Ollama model to use base_url: Base URL for the Ollama API @@ -41,52 +42,65 @@ async def close(self) -> None: """Close the HTTP client""" await self.client().aclose() - def _format_messages(self, messages: Sequence[ModelMessage]) -> str: + def _format_messages(self, messages: Sequence[ModelMessage], context: ModelContext) -> str: """Format messages for Ollama API. - Ollama's generate endpoint expects a prompt string, so we format - the messages into a conversation format. + Args: + messages: Messages to format + context: Context containing tools and settings + + Returns: + Formatted prompt string """ formatted_parts = [] + # Add tool definitions if available + if context.tools: + formatted_parts.append(f"System: {self._format_tool_descriptions(context.tools)}") + + # Format conversation history for message in messages: - if message.role == MessageRole.SYSTEM: - formatted_parts.append(f"System: {self._get_text_content(message)}") - elif message.role == MessageRole.USER: - formatted_parts.append(f"User: {self._get_text_content(message)}") - elif message.role == MessageRole.ASSISTANT: - formatted_parts.append(f"Assistant: {self._get_text_content(message)}") - elif message.role == MessageRole.TOOL: - # Include tool name in response if available - tool_name = self._get_tool_name(message) - formatted_parts.append(f"Tool ({tool_name}): {self._get_text_content(message)}") - - # Add a final Assistant: prompt to indicate it's the model's turn + prefix = { + MessageRole.SYSTEM: "System", + MessageRole.USER: "User", + MessageRole.ASSISTANT: "Assistant", + MessageRole.TOOL: "Tool" + }.get(message.role, "Unknown") + + content = [] + for part in message.parts: + if isinstance(part, TextPart): + content.append(part.text) + elif isinstance(part, ToolCallPart): + tool_call = { + "tool_name": part.tool_name, + "args": part.args + } + content.append( + f"{json.dumps(tool_call)}" + ) + elif isinstance(part, ToolReturnPart): + content.append( + f"Result from {part.tool_name}: {json.dumps(part.content)}" + ) + + formatted_parts.append(f"{prefix}: {' '.join(content)}") + + # Add final prompt for assistant formatted_parts.append("Assistant:") return "\n\n".join(formatted_parts) - def _get_text_content(self, message: ModelMessage) -> str: - """Extract text content from message parts""" - parts = [] - for part in message.parts: - if isinstance(part, TextPart): - parts.append(part.text) - return " ".join(parts) - - def _get_tool_name(self, message: ModelMessage) -> str: - """Get tool name from message if present""" - for part in message.parts: - if isinstance(part, ToolCallPart): - return part.tool_name - return "unknown" - - def _prepare_request_data(self, messages: Sequence[ModelMessage], context: ModelContext, stream: bool = False) -> \ - dict[str, Any]: + def _prepare_request_data( + self, + messages: Sequence[ModelMessage], + context: ModelContext, + 
+        stream: bool = False
+    ) -> dict[str, Any]:
         """Prepare the request data for Ollama API"""
         data = {
             "model": self.model_name,
-            "prompt": self._format_messages(messages),
+            "prompt": self._format_messages(messages, context),
             "stream": stream
         }
 
@@ -113,7 +127,7 @@ async def request(
 
         Args:
             messages: Sequence of messages to send
-            context: Context containing settings
+            context: Context containing settings and tools
 
         Returns:
             Tuple of (model response, usage statistics)
@@ -130,8 +144,10 @@
         response.raise_for_status()
         result = response.json()
 
-        # Extract response and usage
+        # Parse response and create usage stats
         response_text = result.get("response", "")
+        response_parts = self._parse_response(response_text)
+
         usage = Usage(
             request_tokens=result.get("prompt_eval_count", 0),
             response_tokens=result.get("eval_count", 0),
@@ -142,7 +158,7 @@
             requests=1
         )
 
-        return ModelResponse(parts=[TextPart(text=response_text)]), usage
+        return ModelResponse(parts=response_parts), usage
 
     async def request_stream(
             self,
@@ -153,7 +169,7 @@
 
         Args:
             messages: Sequence of messages to send
-            context: Context containing settings
+            context: Context containing settings and tools
 
         Yields:
             StreamedResponse objects containing response chunks
@@ -162,7 +178,10 @@
 
         data = self._prepare_request_data(messages, context, stream=True)
 
-        # Make streaming request
+        # Track state across chunks
+        current_text = []
+        total_usage = Usage(requests=1)
+
         async with self.client().stream(
             "POST",
             "/api/generate",
@@ -170,9 +189,6 @@
         ) as response:
             response.raise_for_status()
 
-            # Track usage across chunks
-            total_usage = Usage(requests=1)
-
             async for line in response.aiter_lines():
                 if not line.strip():
                     continue
@@ -182,8 +198,41 @@
                 except json.JSONDecodeError:
                     continue
 
-                # Extract chunk text and update usage
                 if "response" in chunk:
+                    chunk_text = chunk["response"]
+                    current_text.append(chunk_text)
+
+                    # Try to parse complete tool calls
+                    accumulated_text = ''.join(current_text)
+                    for match in self.TOOL_CALL_PATTERN.finditer(accumulated_text):
+                        tool_call = self._parse_tool_call(match)
+                        if tool_call:
+                            # Yield any text before the tool call
+                            prefix = accumulated_text[:match.start()].strip()
+                            if prefix:
+                                yield StreamedResponse(
+                                    delta=TextPart(text=prefix),
+                                    usage=None
+                                )
+
+                            # Yield the tool call
+                            yield StreamedResponse(
+                                delta=tool_call,
+                                usage=None
+                            )
+
+                            # Keep any remaining text
+                            current_text = [accumulated_text[match.end():]]
+                            break
+                    else:
+                        # If no tool call found and we've accumulated enough text
+                        if len(accumulated_text) > 100:
+                            yield StreamedResponse(
+                                delta=TextPart(text=accumulated_text),
+                                usage=None
+                            )
+                            current_text = []
+
                     # Update usage stats
                     chunk_usage = Usage(
                         request_tokens=chunk.get("prompt_eval_count", 0),
@@ -195,14 +244,14 @@
                     )
                     total_usage.add(chunk_usage)
 
-                    # Yield response chunk
-                    yield StreamedResponse(
-                        delta=TextPart(text=chunk["response"]),
-                        usage=chunk_usage
-                    )
-
-                # Handle done message
                 if chunk.get("done", False):
+                    # Yield any remaining text
+                    remaining_text = ''.join(current_text).strip()
+                    if remaining_text:
+                        yield StreamedResponse(
+                            delta=TextPart(text=remaining_text),
+                            usage=total_usage
+                        )
                     break
 
     @classmethod
@@ -218,4 +267,4 @@ async def list_models(cls, base_url: str = "http://localhost:11434") -> list[dict[str, Any]]:
         async with httpx.AsyncClient(base_url=base_url) as client:
             response = await client.get("/api/tags")
             response.raise_for_status()
-            return response.json().get("models", [])
+            return response.json().get("models", [])
\ No newline at end of file
diff --git a/bindings/ceylon/examples/llm/task_llm_app.py b/bindings/ceylon/examples/llm/task_llm_app.py
index 2be90dd..c8ff697 100644
--- a/bindings/ceylon/examples/llm/task_llm_app.py
+++ b/bindings/ceylon/examples/llm/task_llm_app.py
@@ -2,50 +2,85 @@
 
 from ceylon.llm.agent import LLMConfig, LLMAgent
 from ceylon.llm.models.ollama import OllamaModel
-from ceylon.processor.agent import ProcessRequest, ProcessResponse
-from ceylon.processor.playground import ProcessPlayGround
+from ceylon.llm.models.support.tools import ToolDefinition
 from ceylon.task.data import Task, TaskResult
 from ceylon.task.playground import TaskProcessingPlayground
 
 
+async def calculate(expression: str) -> float:
+    """Simple calculator function that evaluates mathematical expressions."""
+    print(f"Calculating: {expression}")
+    try:
+        # Evaluate the expression safely
+        allowed_chars = set("0123456789+-*/(). ")
+        if not all(c in allowed_chars for c in expression):
+            raise ValueError("Invalid characters in expression")
+        return eval(expression)
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+
 async def main():
-    # Create playground and worker
-    playground = TaskProcessingPlayground()
-    llm_model = OllamaModel(
-        model_name="deepseek-r1:8b",
-        base_url="http://localhost:11434"
-    )
+    # Define calculator tool
+    tools = [
+        ToolDefinition(
+            name="calculator",
+            description="Performs basic mathematical calculations",
+            parameters_json_schema={
+                "type": "object",
+                "properties": {
+                    "expression": {
+                        "type": "string",
+                        "description": "Mathematical expression to evaluate (e.g., '2 + 2')"
+                    }
+                },
+                "required": ["expression"]
+            },
+            function=calculate
+        )
+    ]
 
-    # Configure LLM agent
-    llm_config = LLMConfig(
-        system_prompt=(
-            "You are an expert content writer specializing in technology topics. "
-        ),
-        temperature=0.7,
-        max_tokens=1000,
-        retry_attempts=1
-    )
+    # Create playground
+    playground = TaskProcessingPlayground()
 
-    # Create LLM agent
+    # Create LLM agent with calculator tool
     llm_agent = LLMAgent(
-        name="writer_1",
-        llm_model=llm_model,
-        config=llm_config,
-        role="writer"
+        name="math_assistant",
+        llm_model=OllamaModel(
+            model_name="deepseek-r1:8b",
+            base_url="http://localhost:11434"
+        ),
+        config=LLMConfig(
+            system_prompt=(
+                "You are a math assistant. Use the calculator tool to perform calculations. "
+                "When asked to perform calculations, always use the calculator tool for accuracy."
+            ),
+            temperature=0.7,
+            max_tokens=1000,
+            tools=tools
+        ),
+        role="math_assistant"
     )
 
     # Start the system
     async with playground.play(workers=[llm_agent]) as active_playground:
-        active_playground: TaskProcessingPlayground = active_playground
-        # Send some test requests
-        response: TaskResult = await active_playground.add_and_execute_task(
-            Task(
-                name="Process Data 1",
-                processor="writer",
-                input_data={"request": "A Simple title for a blog post about AI"}
+        # Test some calculations
+        test_questions = [
+            "What is 123 * 456?",
+            "Calculate 15.5 + 27.3",
+            "What is (100 - 20) / 2?"
+        ]
+
+        for question in test_questions:
+            response: TaskResult = await active_playground.add_and_execute_task(
+                Task(
+                    name="Calculate",
+                    processor="math_assistant",
+                    input_data={"request": question}
+                )
             )
-        )
 
-        print(f"Response received: {response.output}")
+            print(f"\nQuestion: {question}")
+            print(f"Response: {response.output}")
 
 
 if __name__ == "__main__":
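
For reference, a minimal, self-contained sketch of the <tool_call> text convention that the patch teaches the model (via _format_tool_descriptions) and then extracts again (via TOOL_CALL_PATTERN / _parse_response). It does not import ceylon; ToolCall and split_reply are illustrative stand-ins for ToolCallPart and the parsing logic, not part of the patch.

import json
import re
from dataclasses import dataclass

# Same pattern the patch uses to find embedded tool calls in raw model text.
TOOL_CALL_PATTERN = re.compile(r"<tool_call>(?P<tool_json>.*?)</tool_call>", re.DOTALL)


@dataclass
class ToolCall:
    tool_name: str
    args: dict


def split_reply(text: str) -> list:
    """Split a raw model reply into plain-text segments and tool calls."""
    parts, last_end = [], 0
    for match in TOOL_CALL_PATTERN.finditer(text):
        prefix = text[last_end:match.start()].strip()
        if prefix:
            parts.append(prefix)
        try:
            data = json.loads(match.group("tool_json"))
            parts.append(ToolCall(data["tool_name"], data["args"]))
        except (json.JSONDecodeError, KeyError):
            parts.append(match.group(0))  # malformed call is kept as plain text
        last_end = match.end()
    tail = text[last_end:].strip()
    if tail:
        parts.append(tail)
    return parts


reply = (
    'Let me check that for you. '
    '<tool_call>{"tool_name": "calculator", "args": {"expression": "123 * 456"}}</tool_call>'
)
print(split_reply(reply))
# ['Let me check that for you.', ToolCall(tool_name='calculator', args={'expression': '123 * 456'})]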