diff --git a/server/selection_algorithm.go b/server/selection_algorithm.go index 8ea304002..1a83f1536 100644 --- a/server/selection_algorithm.go +++ b/server/selection_algorithm.go @@ -58,7 +58,7 @@ func (sa ProbabilitySelectionAlgorithm) filterByPerfScore(ctx context.Context, a } func (sa ProbabilitySelectionAlgorithm) filterByMaxPrice(ctx context.Context, addrs []ethcommon.Address, maxPrice *big.Rat, prices map[ethcommon.Address]*big.Rat) []ethcommon.Address { - if maxPrice == nil || len(prices) == 0 { + if maxPrice == nil || maxPrice.Sign() == 0 || len(prices) == 0 { // Max price filter not defined, return all Orchestrators return addrs } diff --git a/server/selection_algorithm_test.go b/server/selection_algorithm_test.go index 61b966b42..5a3d8d0a4 100644 --- a/server/selection_algorithm_test.go +++ b/server/selection_algorithm_test.go @@ -6,6 +6,7 @@ import ( "testing" ethcommon "github.com/ethereum/go-ethereum/common" + "github.com/livepeer/go-livepeer/core" "github.com/stretchr/testify/require" ) @@ -25,6 +26,7 @@ func TestFilter(t *testing.T) { { name: "Some Orchestrators pass the filter", orchMinPerfScore: 0.7, + maxPrice: 0, orchPerfScores: map[string]float64{ "0x0000000000000000000000000000000000000001": 0.6, "0x0000000000000000000000000000000000000002": 0.8, @@ -44,6 +46,7 @@ func TestFilter(t *testing.T) { { name: "No orchestrator Scores defined", orchMinPerfScore: 0.7, + maxPrice: 0, orchPerfScores: nil, orchestrators: []string{ "0x0000000000000000000000000000000000000001", @@ -57,6 +60,7 @@ func TestFilter(t *testing.T) { { name: "No min score defined", orchMinPerfScore: 0, + maxPrice: 0, orchPerfScores: map[string]float64{ "0x0000000000000000000000000000000000000001": 0.6, "0x0000000000000000000000000000000000000002": 0.8, @@ -73,6 +77,7 @@ func TestFilter(t *testing.T) { { name: "No Orchestrators pass the filter", orchMinPerfScore: 0.99, + maxPrice: 0, orchPerfScores: map[string]float64{ "0x0000000000000000000000000000000000000001": 0.6, "0x0000000000000000000000000000000000000002": 0.8, @@ -225,9 +230,80 @@ func TestFilter(t *testing.T) { prices[addr] = new(big.Rat).SetFloat64(price) } } - if tt.maxPrice > 0 { - maxPrice = new(big.Rat).SetFloat64(tt.maxPrice) + + maxPrice = new(big.Rat).SetFloat64(tt.maxPrice) + + sa := &ProbabilitySelectionAlgorithm{ + MinPerfScore: tt.orchMinPerfScore, + IgnoreMaxPriceIfNeeded: tt.ignoreMaxPriceIfNeeded, + } + + res := sa.filter(context.Background(), addrs, maxPrice, prices, perfScores) + + var exp []ethcommon.Address + for _, o := range tt.want { + exp = append(exp, ethcommon.HexToAddress(o)) + } + require.Equal(t, exp, res) + }) + } +} + +func TestNoMaxPriceSet(t *testing.T) { + tests := []struct { + name string + orchMinPerfScore float64 + maxPrice float64 + prices map[string]float64 + orchPerfScores map[string]float64 + orchestrators []string + want []string + ignoreMaxPriceIfNeeded bool + }{ + { + name: "Exact match with max price", + orchMinPerfScore: 0.7, + maxPrice: 0, + prices: map[string]float64{ + "0x0000000000000000000000000000000000000001": 500, + "0x0000000000000000000000000000000000000002": 1000, + "0x0000000000000000000000000000000000000003": 1500, + }, + orchPerfScores: map[string]float64{ + "0x0000000000000000000000000000000000000001": 0.8, + "0x0000000000000000000000000000000000000002": 0.8, + "0x0000000000000000000000000000000000000003": 0.8, + }, + orchestrators: []string{ + "0x0000000000000000000000000000000000000001", + "0x0000000000000000000000000000000000000002", + "0x0000000000000000000000000000000000000003", + }, + want: []string{ + "0x0000000000000000000000000000000000000001", + "0x0000000000000000000000000000000000000002", + "0x0000000000000000000000000000000000000003", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var addrs []ethcommon.Address + var maxPrice *big.Rat + prices := map[ethcommon.Address]*big.Rat{} + perfScores := map[ethcommon.Address]float64{} + for _, o := range tt.orchestrators { + addr := ethcommon.HexToAddress(o) + addrs = append(addrs, addr) + perfScores[addr] = tt.orchPerfScores[o] + if price, ok := tt.prices[o]; ok { + prices[addr] = new(big.Rat).SetFloat64(price) + } } + + fixedPrice := core.NewFixedPrice(big.NewRat(int64(tt.maxPrice), 1)) + maxPrice = fixedPrice.Value() + sa := &ProbabilitySelectionAlgorithm{ MinPerfScore: tt.orchMinPerfScore, IgnoreMaxPriceIfNeeded: tt.ignoreMaxPriceIfNeeded,