Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[simd] Simplify SIMD.reduce_op functions #3083

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 50 additions & 58 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2131,13 +2131,45 @@ struct SIMD[type: DType, size: Int](
# Reduce operations
# ===------------------------------------------------------------------=== #

alias _T = SIMD[type, _]

# TODO: remove this overload once non-capturing functions can be converted to capturing ones.
@always_inline
fn reduce[
func: fn[type: DType, width: Int] (
SIMD[type, width], SIMD[type, width]
) capturing -> SIMD[type, width],
func: fn[width: Int] (
Self._T[width],
Self._T[width],
) -> Self._T[width],
size_out: Int = 1,
](self) -> SIMD[type, size_out]:
](self) -> Self._T[size_out]:
"""Reduces the vector using a provided reduce operator.

Parameters:
func: The reduce function to apply to elements in this SIMD.
size_out: The width of the reduction.

Constraints:
`size_out` must not exceed the width of the vector.

Returns:
A new vector of width `size_out` (a scalar by default) which is the reduction of all vector elements.
"""

@always_inline
@parameter
fn body[w: Int](lhs: Self._T[w], rhs: Self._T[w]) -> Self._T[w]:
return func(lhs, rhs)

return self.reduce[body, size_out]()

@always_inline
fn reduce[
func: fn[width: Int] (
Self._T[width],
Self._T[width],
) capturing -> Self._T[width],
size_out: Int = 1,
](self) -> Self._T[size_out]:
"""Reduces the vector using a provided reduce operator.

Parameters:
Expand All @@ -2154,15 +2186,15 @@ struct SIMD[type: DType, size: Int](

@parameter
if size == size_out:
return rebind[SIMD[type, size_out]](self)
return rebind[Self._T[size_out]](self)
else:
var lhs: Self._SIMDHalfType
var rhs: Self._SIMDHalfType
lhs, rhs = self.split()
return func(lhs, rhs).reduce[func, size_out]()

@always_inline("nodebug")
fn reduce_max[size_out: Int = 1](self) -> SIMD[type, size_out]:
fn reduce_max[size_out: Int = 1](self) -> Self._T[size_out]:
"""Reduces the vector using the `max` operator.

Parameters:
Expand All @@ -2185,14 +2217,12 @@ struct SIMD[type: DType, size: Int](

@always_inline
@parameter
fn max_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
fn body[
width: Int
](v1: Self._T[width], v2: Self._T[width]) -> Self._T[width]:
return max(v1, v2)

return self.reduce[max_reduce_body, size_out]()
return self.reduce[body, size_out]()

@parameter
if type.is_floating_point():
Expand Down Expand Up @@ -2243,14 +2273,12 @@ struct SIMD[type: DType, size: Int](

@always_inline
@parameter
fn min_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
fn body[
width: Int
](v1: Self._T[width], v2: Self._T[width]) -> Self._T[width]:
return min(v1, v2)

return self.reduce[min_reduce_body, size_out]()
return self.reduce[body, size_out]()

@parameter
if type.is_floating_point():
Expand Down Expand Up @@ -2291,15 +2319,7 @@ struct SIMD[type: DType, size: Int](
The sum of all vector elements.

"""

@always_inline
@parameter
fn add_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]:
return v1 + v2

return self.reduce[add_reduce_body, size_out]()
return self.reduce[Self._T.__add__, size_out]()

@always_inline
fn reduce_mul[size_out: Int = 1](self) -> SIMD[type, size_out]:
Expand All @@ -2315,15 +2335,7 @@ struct SIMD[type: DType, size: Int](
Returns:
The product of all vector elements.
"""

@always_inline
@parameter
fn mul_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[type, width]:
return v1 * v2

return self.reduce[mul_reduce_body, size_out]()
return self.reduce[Self._T.__mul__, size_out]()

@always_inline
fn reduce_and[size_out: Int = 1](self) -> SIMD[type, size_out]:
Expand All @@ -2349,17 +2361,7 @@ struct SIMD[type: DType, size: Int](

@parameter
if size_out > 1:

@always_inline
@parameter
fn and_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
return v1 & v2

return self.reduce[and_reduce_body, size_out]()
return self.reduce[Self._T.__and__, size_out]()

@parameter
if size == 1:
Expand Down Expand Up @@ -2395,17 +2397,7 @@ struct SIMD[type: DType, size: Int](

@parameter
if size_out > 1:

@always_inline
@parameter
fn or_reduce_body[
type: DType, width: Int
](v1: SIMD[type, width], v2: SIMD[type, width]) -> SIMD[
type, width
]:
return v1 | v2

return self.reduce[or_reduce_body, size_out]()
return self.reduce[Self._T.__or__, size_out]()

@parameter
if size == 1:
Expand Down