Skip to content

Commit

Permalink
Adaptive Windowing for Multi-Armed Bandits
Browse files Browse the repository at this point in the history
 ### Changes:
 * Added adaptive windowing mechanism to detect and handle concept drift in MAB models.
 * Introduced ActionsManager class to handle action memory and updates with configurable window sizes.
 * Refactored Model class hierarchy to support model resetting and memory management.
 * Added support for infinite and fixed-size windows with change detection via delta parameter.
 * Enhanced test coverage for adaptive windowing functionality across MAB variants.
  • Loading branch information
Shahar-Bar committed Dec 31, 2024
1 parent 64913ef commit 85aba1a
Show file tree
Hide file tree
Showing 14 changed files with 2,199 additions and 589 deletions.
625 changes: 625 additions & 0 deletions pybandits/actions_manager.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pybandits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
)
ACTION_IDS_PREFIX = "action_ids_"
ACTIONS = "actions"


class _classproperty(property):
Expand Down
79 changes: 21 additions & 58 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,34 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Dict, List, Optional, Set, Union
from typing import Generic, List, Optional, Set, Union

from numpy import array
from numpy.random import choice
from numpy.typing import ArrayLike
from pydantic.generics import GenericModel

from pybandits.actions_manager import CmabActionsManager
from pybandits.base import ActionId, BinaryReward, CmabPredictions
from pybandits.mab import BaseMab
from pybandits.model import BayesianLogisticRegression, BayesianLogisticRegressionCC
from pybandits.pydantic_version_compatibility import field_validator, validate_call
from pybandits.model import BayesianLogisticRegression, BayesianLogisticRegressionCC, CmabModelType
from pybandits.pydantic_version_compatibility import validate_call
from pybandits.strategy import (
BestActionIdentificationBandit,
ClassicBandit,
CostControlBandit,
StrategyType,
)


class BaseCmabBernoulli(BaseMab):
class BaseCmabBernoulli(BaseMab, GenericModel, Generic[StrategyType, CmabModelType]):
"""
Base model for a Contextual Multi-Armed Bandit for Bernoulli bandits with Thompson Sampling.
Parameters
----------
actions: Dict[ActionId, BayesianLogisticRegression]
The list of possible actions, and their associated Model.
actions_manager: CmabActionsManager[CmabModelType]
The actions manager for handling the actions update and memory.
strategy: Strategy
The strategy used to select actions.
predict_with_proba: bool
Expand All @@ -54,26 +56,9 @@ class BaseCmabBernoulli(BaseMab):
bandit strategy.
"""

actions: Dict[ActionId, BayesianLogisticRegression]
actions_manager: CmabActionsManager[CmabModelType]
predict_with_proba: bool
predict_actions_randomly: bool

@field_validator("actions", mode="after")
@classmethod
def check_bayesian_logistic_regression_models(cls, v):
action_models = list(v.values())
first_action = action_models[0]
first_action_type = type(first_action)
for action in action_models[1:]:
if not isinstance(action, first_action_type):
raise AttributeError("All actions should follow the same type.")
if not len(action.betas) == len(first_action.betas):
raise AttributeError("All actions should have the same number of betas.")
if not action.update_method == first_action.update_method:
raise AttributeError("All actions should have the same update method.")
if not action.update_kwargs == first_action.update_kwargs:
raise AttributeError("All actions should have the same update kwargs.")
return v
predict_actions_randomly: bool = False

@validate_call(config=dict(arbitrary_types_allowed=True))
def predict(
Expand Down Expand Up @@ -169,26 +154,13 @@ def update(
If strategy is MultiObjectiveBandit, rewards should be a list of list, e.g. (with n_objectives=2):
rewards = [[1, 1], [1, 0], [1, 1], [1, 0], [1, 1], ...]
"""
self._validate_update_params(actions=actions, rewards=rewards)
if len(context) != len(rewards):
raise AttributeError(f"Shape mismatch: actions and rewards should have the same length {len(actions)}.")

# cast inputs to numpy arrays to facilitate their manipulation
context, actions, rewards = array(context), array(actions), array(rewards)

for a in set(actions):
# get context and rewards of the samples associated to action a
context_of_a = context[actions == a]
rewards_of_a = rewards[actions == a].tolist()

# update model associated to action a
self.actions[a].update(context=context_of_a, rewards=rewards_of_a)
super().update(actions=actions, rewards=rewards, context=context)

# always set predict_actions_randomly after update
self.predict_actions_randomly = False


class CmabBernoulli(BaseCmabBernoulli):
class CmabBernoulli(BaseCmabBernoulli[ClassicBandit, BayesianLogisticRegression]):
"""
Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling.
Expand All @@ -197,8 +169,8 @@ class CmabBernoulli(BaseCmabBernoulli):
Parameters
----------
actions: Dict[ActionId, BayesianLogisticRegression]
The list of possible actions, and their associated Model.
actions_manager: CmabActionsManager[BayesianLogisticRegression]
The actions manager for handling the actions update and memory.
strategy: ClassicBandit
The strategy used to select actions.
predict_with_proba: bool
Expand All @@ -208,13 +180,10 @@ class CmabBernoulli(BaseCmabBernoulli):
bandit strategy.
"""

actions: Dict[ActionId, BayesianLogisticRegression]
strategy: ClassicBandit
predict_with_proba: bool = False
predict_actions_randomly: bool = False


class CmabBernoulliBAI(BaseCmabBernoulli):
class CmabBernoulliBAI(BaseCmabBernoulli[BestActionIdentificationBandit, BayesianLogisticRegression]):
"""
Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling, and Best Action Identification strategy.
Expand All @@ -223,8 +192,8 @@ class CmabBernoulliBAI(BaseCmabBernoulli):
Parameters
----------
actions: Dict[ActionId, BayesianLogisticRegression]
The list of possible actions, and their associated Model.
actions_manager: CmabActionsManager[BayesianLogisticRegression]
The actions manager for handling the actions update and memory.
strategy: BestActionIdentificationBandit
The strategy used to select actions.
predict_with_proba: bool
Expand All @@ -234,13 +203,10 @@ class CmabBernoulliBAI(BaseCmabBernoulli):
bandit strategy.
"""

actions: Dict[ActionId, BayesianLogisticRegression]
strategy: BestActionIdentificationBandit
predict_with_proba: bool = False
predict_actions_randomly: bool = False


class CmabBernoulliCC(BaseCmabBernoulli):
class CmabBernoulliCC(BaseCmabBernoulli[CostControlBandit, BayesianLogisticRegressionCC]):
"""
Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling, and Cost Control strategy.
Expand All @@ -257,8 +223,8 @@ class CmabBernoulliCC(BaseCmabBernoulli):
Parameters
----------
actions: Dict[ActionId, BayesianLogisticRegressionCC]
The list of possible actions, and their associated Model.
actions_manager: CmabActionsManager[BayesianLogisticRegressionCC]
The actions manager for handling the actions update and memory.
strategy: CostControlBandit
The strategy used to select actions.
predict_with_proba: bool
Expand All @@ -268,7 +234,4 @@ class CmabBernoulliCC(BaseCmabBernoulli):
bandit strategy.
"""

actions: Dict[ActionId, BayesianLogisticRegressionCC]
strategy: CostControlBandit
predict_with_proba: bool = True
predict_actions_randomly: bool = False
Loading

0 comments on commit 85aba1a

Please sign in to comment.