Skip to content

Commit

Permalink
Add LLM Task Manger
Browse files Browse the repository at this point in the history
  • Loading branch information
dewmal committed Feb 2, 2025
1 parent be6dff6 commit afec1bb
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 21 deletions.
11 changes: 9 additions & 2 deletions bindings/ceylon/ceylon/base/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, name="playground", port=8888):
self.llm_agents: Dict[str, AgentConnectedStatus] = {}
self._connected_event = None
self._stop_event = None
self._all_tasks_completed_events: Dict[str, asyncio.Event] = {}
self._completed_tasks: Dict[str, TaskOutput] = {}
self._task_results: Dict[str, Any] = {}

Expand All @@ -55,6 +56,11 @@ def get_completed_tasks(self) -> Dict[str, TaskOutput]:
"""Get all completed tasks"""
return self._completed_tasks.copy()

async def wait_and_get_completed_tasks(self) -> Dict[str, TaskOutput]:
for event in self._all_tasks_completed_events.values():
await event.wait()
return self.get_completed_tasks()

def get_task_results(self) -> Dict[str, Any]:
"""Get all task results"""
return self._task_results.copy()
Expand Down Expand Up @@ -109,11 +115,12 @@ async def play(self, workers: Optional[List[BaseAgent]] = None):
for task_id, output in self._completed_tasks.items():
if output.completed:
duration = output.end_time - output.start_time if output.end_time and output.start_time else None
logger.info(f"Task {task_id} ({output.name}) completed in {duration:.2f}s" if duration else f"Task {task_id} ({output.name}) completed")
logger.info(
f"Task {task_id} ({output.name}) completed in {duration:.2f}s" if duration else f"Task {task_id} ({output.name}) completed")
else:
logger.warning(f"Task {task_id} ({output.name}) failed: {output.error}")

# Cleanup
self._connected_event = None
self._stop_event = None
await self.stop()
await self.stop()
16 changes: 8 additions & 8 deletions bindings/ceylon/ceylon/llm/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ async def execute_task(self, task: TaskMessage) -> None:
Execute an LLM task with retry logic and error handling
"""
try:
print(f"\n{'='*80}")
print(f"Task: {task.name}")
print(f"Description: {task.description}")
print(f"{'='*80}\n")
logger.info(f"\n{'='*80}")
logger.info(f"Task: {task.name}")
logger.info(f"Description: {task.description}")
logger.info(f"{'='*80}\n")

async with self.processing_lock:
response = await self._execute_with_retry(task)
Expand All @@ -83,10 +83,10 @@ async def execute_task(self, task: TaskMessage) -> None:
self.response_cache[task.task_id] = response

# Print the response
print("\nGenerated Content:")
print(f"{'-'*80}")
print(response.content)
print(f"{'-'*80}\n")
logger.info("\nGenerated Content:")
logger.info(f"{'-'*80}")
logger.info(response.content)
logger.info(f"{'-'*80}\n")

# Update task with completion info
task.completed = True
Expand Down
4 changes: 4 additions & 0 deletions bindings/ceylon/ceylon/task/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Licensed under the Apache License, Version 2.0 (See LICENSE.md or http://www.apache.org/licenses/LICENSE-2.0).
#
#
import asyncio

from loguru import logger
from ceylon import on, on_connect
from ceylon.base.playground import BasePlayGround, TaskOutput
Expand Down Expand Up @@ -34,6 +36,7 @@ async def handle_task_completion(self, task: TaskMessage, time: int):
if hasattr(task, 'result'):
self.add_task_result(task.task_id, task.result)

self._all_tasks_completed_events[task.task_id].set()
# Broadcast status update
await self.broadcast_message(TaskStatusUpdate(
task_id=task.task_id,
Expand Down Expand Up @@ -97,6 +100,7 @@ async def assign_task_groups(self, groups):
assignments = await self.task_manager.assign_task_groups(groups)
for assignment in assignments:
await self.broadcast_message(assignment)
self._all_tasks_completed_events[assignment.task_id] = asyncio.Event()

async def print_all_statistics(self):
"""Print statistics for all task groups"""
Expand Down
10 changes: 6 additions & 4 deletions bindings/ceylon/examples/llm/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import uuid
from typing import List, Dict
from datetime import datetime
from typing import List, Dict

from ceylon.llm.agent import LLMAgent, LLMConfig
from ceylon.llm.models.ollama import OllamaModel
from ceylon.task import TaskPlayGround
from ceylon.task.data import TaskMessage, TaskGroup, TaskGroupGoal, GoalStatus
from ceylon.task.data import TaskMessage, TaskGroupGoal, GoalStatus
from ceylon.task.manager import TaskManager
from ceylon.llm.models.ollama import OllamaModel
from ceylon.llm.agent import LLMAgent, LLMConfig


def print_header(text: str):
print(f"\n{'='*80}")
Expand Down
8 changes: 1 addition & 7 deletions bindings/ceylon/examples/llm/simple_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,9 @@ async def main():
async with playground.play(workers=[llm_agent]) as active_playground:
# Assign task group
await active_playground.assign_task_groups([task_group])

# Wait for completion
while True:
await asyncio.sleep(1)
if task.task_id in active_playground.get_completed_tasks():
break

# Get and display results
completed_task = active_playground.get_completed_tasks()[task.task_id]
completed_task = (await active_playground.wait_and_get_completed_tasks())[task.task_id]
if completed_task.completed:
print("\nTask Completed Successfully!")
print(f"Duration: {completed_task.end_time - completed_task.start_time:.2f}s")
Expand Down

0 comments on commit afec1bb

Please sign in to comment.