Skip to content

Commit

Permalink
Improve standard simplify rule matches in non-commutative contexts (#…
Browse files Browse the repository at this point in the history
…2841)

* Improve standard simplify rule matches in non-commutative contexts

Addresses the rule application limitation aspect as highlighted in
issue #2825; such that a broader set of successful standard replacement
rules are applied to multi-arg/associative expressions in
non-commutative contexts.

* Remove 'clone()' operations on expanded simplify rules

since original rule nodes (including expanded variations) are essentially
readonly objects, cloning of expanded rule LHS' is unnecessary during
canonicalization

* Hoist non-commutative context expanded rule app. in simplify (applyRule)

* Add two simplify non-commutative ctx. test cases
  • Loading branch information
samueltlg authored Nov 22, 2022
1 parent f99020e commit 76c8c62
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
34 changes: 31 additions & 3 deletions src/function/algebra/simplify.js
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,34 @@ export const createSimplify = /* #__PURE__ */ factory(name, dependencies, (
}

if (isAssociative(newRule.l, context)) {
const nonCommutative = !isCommutative(newRule.l, context)
let leftExpandsym
// Gen. the LHS placeholder used in this NC-context specific expansion rules
if (nonCommutative) leftExpandsym = _getExpandPlaceholderSymbol()

const makeNode = createMakeNodeFunction(newRule.l)
const expandsym = _getExpandPlaceholderSymbol()
newRule.expanded = {}
newRule.expanded.l = makeNode([newRule.l.clone(), expandsym])
newRule.expanded.l = makeNode([newRule.l, expandsym])
// Push the expandsym into the deepest possible branch.
// This helps to match the newRule against nodes returned from getSplits() later on.
flatten(newRule.expanded.l, context)
unflattenr(newRule.expanded.l, context)
newRule.expanded.r = makeNode([newRule.r, expandsym])

// In and for a non-commutative context, attempting with yet additional expansion rules makes
// way for more matches cases of multi-arg expressions; such that associative rules (such as
// 'n*n -> n^2') can be applied to exprs. such as 'a * b * b' and 'a * b * b * a'.
if (nonCommutative) {
// 'Non-commutative' 1: LHS (placeholder) only
newRule.expandedNC1 = {}
newRule.expandedNC1.l = makeNode([leftExpandsym, newRule.l])
newRule.expandedNC1.r = makeNode([leftExpandsym, newRule.r])
// 'Non-commutative' 2: farmost LHS and RHS placeholders
newRule.expandedNC2 = {}
newRule.expandedNC2.l = makeNode([leftExpandsym, newRule.expanded.l])
newRule.expandedNC2.r = makeNode([leftExpandsym, newRule.expanded.r])
}
}

return newRule
Expand Down Expand Up @@ -657,6 +676,15 @@ export const createSimplify = /* #__PURE__ */ factory(name, dependencies, (
repl = rule.expanded.r
matches = _ruleMatch(rule.expanded.l, res, mergedContext)[0]
}
// Additional, non-commutative context expansion-rules
if (!matches && rule.expandedNC1) {
repl = rule.expandedNC1.r
matches = _ruleMatch(rule.expandedNC1.l, res, mergedContext)[0]
if (!matches) { // Existence of NC1 implies NC2
repl = rule.expandedNC2.r
matches = _ruleMatch(rule.expandedNC2.l, res, mergedContext)[0]
}
}

if (matches) {
// const before = res.toString({parenthesis: 'all'})
Expand Down Expand Up @@ -880,8 +908,8 @@ export const createSimplify = /* #__PURE__ */ factory(name, dependencies, (
}
res = mergeChildMatches(childMatches)
} else if (node.args.length >= 2 && rule.args.length === 2) { // node is flattened, rule is not
// Associative operators/functions can be split in different ways so we check if the rule matches each
// them and return their union.
// Associative operators/functions can be split in different ways so we check if the rule
// matches for each of them and return their union.
const splits = getSplits(node, context)
let splitMatches = []
for (let i = 0; i < splits.length; i++) {
Expand Down
22 changes: 22 additions & 0 deletions test/unit-tests/function/algebra/simplify.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -440,13 +440,35 @@ describe('simplify', function () {

it('should respect context changes to operator properties', function () {
const optsNCM = { context: { multiply: { commutative: false } } }
const optsNCA = { context: { add: { commutative: false } } }

simplifyAndCompare('x*y+y*x', 'x*y+y*x', {}, optsNCM)
simplifyAndCompare('x*y-y*x', 'x*y-y*x', {}, optsNCM)
simplifyAndCompare('x*5', 'x*5', {}, optsNCM)
simplifyAndCompare('x*y*x^(-1)', 'x*y*x^(-1)', {}, optsNCM)
simplifyAndCompare('x*y/x', 'x*y*x^(-1)', {}, optsNCM)
simplifyAndCompare('x*y*(1/x)', 'x*y*x^(-1)', {}, optsNCM)

// Rules apply to *segments* of operands in NC multi-arg. exprs.
// ('n*n->n^2')
simplifyAndCompare('n*n*3', 'n^2*3', {}, optsNCM)
simplifyAndCompare('3*n*n', '3*n^2', {}, optsNCM)
simplifyAndCompare('3*n*n*3', '3*n^2*3', {}, optsNCM)
simplifyAndCompare('3*n*n*n*3', '3*n^2*n*3', {}, optsNCM)
simplifyAndCompare('3*3*n*n*n*3', '9*n^2*n*3', {}, optsNCM)
simplifyAndCompare('(w*z)*n*n*3', 'w*z*n^2*3', {}, optsNCM)
simplifyAndCompare('2*n*n*3*n*n*4', '2*n^2*3*n^2*4', {}, optsNCM) // 'double wedged', +applied >1x
// ('v*(v*n1+n2) -> v^2*n1+v*n2')
simplifyAndCompare('w*x*(x*y+z)', 'w*(x^2*y+x*z)', {}, optsNCM)
simplifyAndCompare('w*x*(x*y+z)*w', 'w*(x^2*y+x*z)*w', {}, optsNCM)
// 'n+n -> 2*n'
simplifyAndCompare('x+x+3', '2*x+3', {}, optsNCA)
simplifyAndCompare('3+x+x', '3+2*x', {}, optsNCA)
simplifyAndCompare('4+x+x+4', '4+2*x+4', {}, optsNCA)
simplifyAndCompare('4+x+x+5+x+x+6', '4+2*x+5+2*x+6', {}, optsNCA) // 'double wedged', +applied >1x
// 'n+n -> 2*n' & 'n3*n1 + n3*n2 -> n3*(n1+n2)'
simplifyAndCompare('5+x+x+x+x+5', '5+4*x+5', {}, optsNCA)

const optsNAA = { context: { add: { associative: false } } }
simplifyAndCompare(
'x + (-x+y)', 'x + (y-x)', {}, optsNAA, { parenthesis: 'all' })
Expand Down

0 comments on commit 76c8c62

Please sign in to comment.