Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix codegen issues in preparation for CSE in MTK #1425

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,16 @@

SymbolicIndexingInterface.symbolic_type(::Type{<:CallWithMetadata}) = ScalarSymbolic()

# HACK:
# A `DestructuredArgs` with `create_bindings = false` doesn't create a `Let` block, and
# instead adds the assignments to the rewrites dictionary. This is problematic, because
# if the `DestructuredArgs` contains a `CallWithMetadata` the key in the `Dict` will be
# a `CallWithMetadata` which won't match against the operation of the called symbolic.
# This is the _only_ hook we have and relies on the `DestructuredArgs` being converted
# into a list of `Assignment`s before being addded to the `Dict` inside `toexpr(::Let, st)`.

Check warning on line 300 in src/variable.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"addded" should be "added".
# The callable symbolic is unwrapped so it matches the operation of the called version.
SymbolicUtils.Code.Assignment(f::CallWithMetadata, x) = SymbolicUtils.Code.Assignment(f.f, x)

function Base.show(io::IO, c::CallWithMetadata)
show(io, c.f)
print(io, "⋆")
Expand Down
17 changes: 16 additions & 1 deletion test/build_function.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Symbolics, SparseArrays, LinearAlgebra, Test
using ReferenceTests
using Symbolics: value
using SymbolicUtils.Code: DestructuredArgs, Func, NameState
using SymbolicUtils.Code: DestructuredArgs, Func, NameState, Let, cse
@variables a b c1 c2 c3 d e g
oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true)
@test all(isnan, eval(oop)([-1, Inf]))
Expand Down Expand Up @@ -301,3 +301,18 @@ end
@test buf ≈ ones(2)
end
end

@testset "cse with arrayops" begin
@variables x[1:3] y f(..)
t = x .+ y
t = t .* f(t)
res = cse(value(t))
@test res isa Let
@test !isempty(res.pairs)
end

@testset "`CallWithMetadata` in `DestructuredArgs` with `create_bindings = false`" begin
@variables x f(..)
fn = build_function(f(x), DestructuredArgs([f]; create_bindings = false), x; expression = Val{false})
@test fn([isodd], 3)
end
Loading