Skip to content

Commit

Permalink
Merge pull request #1398 from AayushSabharwal/as/fast-substitute-array
Browse files Browse the repository at this point in the history
fix: fix `fast_substitute` folding array of symbolics
  • Loading branch information
ChrisRackauckas authored Jan 16, 2025
2 parents eb3b5f6 + 9271d62 commit 3d8ce96
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ StaticArraysCore = "1.4"
SymPy = "2.2"
SymbolicIndexingInterface = "0.3.14"
SymbolicLimits = "0.2.2"
SymbolicUtils = "3.7"
SymbolicUtils = "3.10"
TermInterface = "2"
julia = "1.10"

Expand Down
4 changes: 2 additions & 2 deletions src/solver/preprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ function _filter_poly(expr, var)
subs[i_var] = im
expr = unwrap(expr1 + i_var * expr2)

args = arguments(expr)
args = map(unwrap, arguments(expr))
oper = operation(expr)
return subs, term(oper, args...)
end
Expand Down Expand Up @@ -208,7 +208,7 @@ function _filter_poly(expr, var)
end
end

args = arguments(expr)
args = map(unwrap, arguments(expr))
oper = operation(expr)
expr = term(oper, args...)
return subs, expr
Expand Down
11 changes: 9 additions & 2 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ function fast_substitute(expr, subs; operator = Nothing)
args = let canfold = canfold
map(args) do x
x′ = fast_substitute(x, subs; operator)
canfold[] = canfold[] && !(x′ isa Symbolic)
canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′))
x′
end
end
Expand All @@ -633,7 +633,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing)
args = let canfold = canfold
map(args) do x
x′ = fast_substitute(x, pair; operator)
canfold[] = canfold[] && !(x′ isa Symbolic)
canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′))
x′
end
end
Expand All @@ -645,6 +645,13 @@ function fast_substitute(expr, pair::Pair; operator = Nothing)
metadata(expr))
end

function is_array_of_symbolics(x)
symbolic_type(x) == ArraySymbolic() && return true
symbolic_type(x) == ScalarSymbolic() && return false
x isa AbstractArray &&
any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x)
end

function getparent(x, val=_fail)
maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing)
if maybe_parent !== nothing
Expand Down
2 changes: 1 addition & 1 deletion test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ end
lapu = wrap(lapu)
lapv = wrap(lapv)

f, g = build_function(dtu, u, v, t, expression=Val{false})
f, g = build_function(dtu, u, v, t, expression=Val{false}, nanmath = false)
du = zeros(Num, 8, 8)
f(du, u,v,t)
@test isequal(collect(du), collect(dtu))
Expand Down
10 changes: 9 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,12 @@ end
test_nested_derivative = Dx(Dt(Dt(u)))
result = diff2term(Symbolics.value(test_nested_derivative))
@test typeof(result) === Symbolics.BasicSymbolic{Real}
end
end

@testset "`fast_substitute` inside array symbolics" begin
@variables x y z
@register_symbolic foo(a::AbstractArray, b)
ex = foo([x, y], z)
ex2 = Symbolics.fixpoint_sub(ex, Dict(y => 1.0, z => 2.0))
@test isequal(ex2, foo([x, 1.0], 2.0))
end

0 comments on commit 3d8ce96

Please sign in to comment.