Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust: Take nested functions into account when resolving variables #18482

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 109 additions & 44 deletions rust/ql/lib/codeql/rust/elements/internal/VariableImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -397,20 +397,23 @@ module Impl {
)
}

private newtype TVariableOrAccessCand =
TVariableOrAccessCandVariable(Variable v) or
TVariableOrAccessCandVariableAccessCand(VariableAccessCand va)
private newtype TDefOrAccessCand =
TDefOrAccessCandNestedFunction(Function f, BlockExprScope scope) {
f = scope.getStmtList().getAStatement()
} or
TDefOrAccessCandVariable(Variable v) or
TDefOrAccessCandVariableAccessCand(VariableAccessCand va)

/**
* A variable declaration or variable access candidate.
* A nested function declaration, variable declaration, or variable (or function)
* access candidate.
*
* In order to determine whether a candidate is an actual variable access,
* we rank declarations and candidates by their position in source code.
* In order to determine whether a candidate is an actual variable/function access,
* we rank declarations and candidates by their position in the AST.
*
* The ranking must take variable names into account, but also variable scopes;
* below a comment `rank(scope, name, i)` means that the declaration/access on
* the given line has rank `i` amongst all declarations/accesses inside variable
* scope `scope`, for variable name `name`:
* The ranking must take names into account, but also variable scopes; below a comment
* `rank(scope, name, i)` means that the declaration/access on the given line has rank
* `i` amongst all declarations/accesses inside variable scope `scope`, for name `name`:
*
* ```rust
* fn f() { // scope0
Expand All @@ -430,8 +433,8 @@ module Impl {
* }
* ```
*
* Variable declarations are only ranked in the scope that they bind into, while
* accesses candidates propagate outwards through scopes, as they may access
* Function/variable declarations are only ranked in the scope that they bind into,
* while accesses candidates propagate outwards through scopes, as they may access
* declarations from outer scopes.
*
* For an access candidate with ranks `{ rank(scope_i, name, rnk_i) | i in I }` and
Expand All @@ -448,41 +451,80 @@ module Impl {
* i.e., its the nearest declaration before the access in the same (or outer) scope
* as the access.
*/
private class VariableOrAccessCand extends TVariableOrAccessCand {
Variable asVariable() { this = TVariableOrAccessCandVariable(result) }
abstract private class DefOrAccessCand extends TDefOrAccessCand {
abstract string toString();

VariableAccessCand asVariableAccessCand() {
this = TVariableOrAccessCandVariableAccessCand(result)
}
abstract Location getLocation();

string toString() {
result = this.asVariable().toString() or result = this.asVariableAccessCand().toString()
}
pragma[nomagic]
abstract predicate rankBy(string name, VariableScope scope, int ord, int kind);
}

Location getLocation() {
result = this.asVariable().getLocation() or result = this.asVariableAccessCand().getLocation()
}
abstract private class NestedFunctionOrVariable extends DefOrAccessCand { }

pragma[nomagic]
predicate rankBy(string name, VariableScope scope, int ord, int kind) {
variableDeclInScope(this.asVariable(), scope, name, ord) and
private class DefOrAccessCandNestedFunction extends NestedFunctionOrVariable,
TDefOrAccessCandNestedFunction
{
private Function f;
private BlockExprScope scope_;

DefOrAccessCandNestedFunction() { this = TDefOrAccessCandNestedFunction(f, scope_) }

override string toString() { result = f.toString() }

override Location getLocation() { result = f.getLocation() }

override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
// nested functions behave as if they are defined at the beginning of the scope
name = f.getName().getText() and
scope = scope_ and
ord = 0 and
kind = 0
or
variableAccessCandInScope(this.asVariableAccessCand(), scope, name, _, ord) and
}
}

private class DefOrAccessCandVariable extends NestedFunctionOrVariable, TDefOrAccessCandVariable {
private Variable v;

DefOrAccessCandVariable() { this = TDefOrAccessCandVariable(v) }

override string toString() { result = v.toString() }

override Location getLocation() { result = v.getLocation() }

override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
variableDeclInScope(v, scope, name, ord) and
kind = 1
}
}

private class DefOrAccessCandVariableAccessCand extends DefOrAccessCand,
TDefOrAccessCandVariableAccessCand
{
private VariableAccessCand va;

DefOrAccessCandVariableAccessCand() { this = TDefOrAccessCandVariableAccessCand(va) }

override string toString() { result = va.toString() }

override Location getLocation() { result = va.getLocation() }

override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
variableAccessCandInScope(va, scope, name, _, ord) and
kind = 2
}
}

private module DenseRankInput implements DenseRankInputSig2 {
class C1 = VariableScope;

class C2 = string;

class Ranked = VariableOrAccessCand;
class Ranked = DefOrAccessCand;

int getRank(VariableScope scope, string name, VariableOrAccessCand v) {
int getRank(VariableScope scope, string name, DefOrAccessCand v) {
v =
rank[result](VariableOrAccessCand v0, int ord, int kind |
rank[result](DefOrAccessCand v0, int ord, int kind |
v0.rankBy(name, scope, ord, kind)
|
v0 order by ord, kind
Expand All @@ -494,7 +536,7 @@ module Impl {
* Gets the rank of `v` amongst all other declarations or access candidates
* to a variable named `name` in the variable scope `scope`.
*/
private int rankVariableOrAccess(VariableScope scope, string name, VariableOrAccessCand v) {
private int rankVariableOrAccess(VariableScope scope, string name, DefOrAccessCand v) {
v = DenseRank2<DenseRankInput>::denseRank(scope, name, result + 1)
}

Expand All @@ -512,25 +554,38 @@ module Impl {
* the declaration at rank 0 can only reach the access at rank 1, while the declaration
* at rank 2 can only reach the access at rank 3.
*/
private predicate variableReachesRank(VariableScope scope, string name, Variable v, int rnk) {
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariable(v))
private predicate variableReachesRank(
VariableScope scope, string name, NestedFunctionOrVariable v, int rnk
) {
rnk = rankVariableOrAccess(scope, name, v)
or
variableReachesRank(scope, name, v, rnk - 1) and
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariableAccessCand(_))
rnk = rankVariableOrAccess(scope, name, TDefOrAccessCandVariableAccessCand(_))
}

private predicate variableReachesCand(
VariableScope scope, string name, Variable v, VariableAccessCand cand, int nestLevel
VariableScope scope, string name, NestedFunctionOrVariable v, VariableAccessCand cand,
int nestLevel
) {
exists(int rnk |
variableReachesRank(scope, name, v, rnk) and
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariableAccessCand(cand)) and
rnk = rankVariableOrAccess(scope, name, TDefOrAccessCandVariableAccessCand(cand)) and
variableAccessCandInScope(cand, scope, name, nestLevel, _)
)
}

pragma[nomagic]
predicate access(string name, NestedFunctionOrVariable v, VariableAccessCand cand) {
v =
min(NestedFunctionOrVariable v0, int nestLevel |
variableReachesCand(_, name, v0, cand, nestLevel)
|
v0 order by nestLevel
)
}

/** A variable access. */
class VariableAccess extends PathExprBaseImpl::PathExprBase instanceof VariableAccessCand {
class VariableAccess extends PathExprBaseImpl::PathExprBase {
private string name;
private Variable v;

Expand Down Expand Up @@ -574,6 +629,16 @@ module Impl {
}
}

/** A nested function access. */
class NestedFunctionAccess extends PathExprBaseImpl::PathExprBase {
private Function f;

NestedFunctionAccess() { nestedFunctionAccess(_, f, this) }

/** Gets the function being accessed. */
Function getFunction() { result = f }
}

cached
private module Cached {
cached
Expand All @@ -582,12 +647,12 @@ module Impl {

cached
predicate variableAccess(string name, Variable v, VariableAccessCand cand) {
v =
min(Variable v0, int nestLevel |
variableReachesCand(_, name, v0, cand, nestLevel)
|
v0 order by nestLevel
)
access(name, TDefOrAccessCandVariable(v), cand)
}

cached
predicate nestedFunctionAccess(string name, Function f, VariableAccessCand cand) {
access(name, TDefOrAccessCandNestedFunction(f, _), cand)
}
}

Expand Down
91 changes: 54 additions & 37 deletions rust/ql/test/library-tests/dataflow/global/inline-flow.expected
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,28 @@ edges
| main.rs:41:26:44:5 | { ... } | main.rs:30:17:30:22 | ...: i64 | provenance | |
| main.rs:41:26:44:5 | { ... } | main.rs:41:13:44:6 | pass_through(...) | provenance | |
| main.rs:43:9:43:18 | source(...) | main.rs:41:26:44:5 | { ... } | provenance | |
| main.rs:56:23:56:28 | ...: i64 | main.rs:57:14:57:14 | n | provenance | |
| main.rs:59:31:65:5 | { ... } | main.rs:77:13:77:25 | mn.get_data(...) | provenance | |
| main.rs:63:13:63:21 | source(...) | main.rs:59:31:65:5 | { ... } | provenance | |
| main.rs:66:28:66:33 | ...: i64 | main.rs:66:43:72:5 | { ... } | provenance | |
| main.rs:77:9:77:9 | a | main.rs:78:10:78:10 | a | provenance | |
| main.rs:77:13:77:25 | mn.get_data(...) | main.rs:77:9:77:9 | a | provenance | |
| main.rs:83:9:83:9 | a | main.rs:84:16:84:16 | a | provenance | |
| main.rs:83:13:83:21 | source(...) | main.rs:83:9:83:9 | a | provenance | |
| main.rs:84:16:84:16 | a | main.rs:56:23:56:28 | ...: i64 | provenance | |
| main.rs:89:9:89:9 | a | main.rs:90:29:90:29 | a | provenance | |
| main.rs:89:13:89:21 | source(...) | main.rs:89:9:89:9 | a | provenance | |
| main.rs:90:9:90:9 | b | main.rs:91:10:91:10 | b | provenance | |
| main.rs:90:13:90:30 | mn.data_through(...) | main.rs:90:9:90:9 | b | provenance | |
| main.rs:90:29:90:29 | a | main.rs:66:28:66:33 | ...: i64 | provenance | |
| main.rs:90:29:90:29 | a | main.rs:90:13:90:30 | mn.data_through(...) | provenance | |
| main.rs:49:9:49:9 | a | main.rs:55:26:55:26 | a | provenance | |
| main.rs:49:13:49:22 | source(...) | main.rs:49:9:49:9 | a | provenance | |
| main.rs:51:21:51:26 | ...: i64 | main.rs:51:36:53:5 | { ... } | provenance | |
| main.rs:55:9:55:9 | b | main.rs:56:10:56:10 | b | provenance | |
| main.rs:55:13:55:27 | pass_through(...) | main.rs:55:9:55:9 | b | provenance | |
| main.rs:55:26:55:26 | a | main.rs:51:21:51:26 | ...: i64 | provenance | |
| main.rs:55:26:55:26 | a | main.rs:55:13:55:27 | pass_through(...) | provenance | |
| main.rs:67:23:67:28 | ...: i64 | main.rs:68:14:68:14 | n | provenance | |
| main.rs:70:31:76:5 | { ... } | main.rs:88:13:88:25 | mn.get_data(...) | provenance | |
| main.rs:74:13:74:21 | source(...) | main.rs:70:31:76:5 | { ... } | provenance | |
| main.rs:77:28:77:33 | ...: i64 | main.rs:77:43:83:5 | { ... } | provenance | |
| main.rs:88:9:88:9 | a | main.rs:89:10:89:10 | a | provenance | |
| main.rs:88:13:88:25 | mn.get_data(...) | main.rs:88:9:88:9 | a | provenance | |
| main.rs:94:9:94:9 | a | main.rs:95:16:95:16 | a | provenance | |
| main.rs:94:13:94:21 | source(...) | main.rs:94:9:94:9 | a | provenance | |
| main.rs:95:16:95:16 | a | main.rs:67:23:67:28 | ...: i64 | provenance | |
| main.rs:100:9:100:9 | a | main.rs:101:29:101:29 | a | provenance | |
| main.rs:100:13:100:21 | source(...) | main.rs:100:9:100:9 | a | provenance | |
| main.rs:101:9:101:9 | b | main.rs:102:10:102:10 | b | provenance | |
| main.rs:101:13:101:30 | mn.data_through(...) | main.rs:101:9:101:9 | b | provenance | |
| main.rs:101:29:101:29 | a | main.rs:77:28:77:33 | ...: i64 | provenance | |
| main.rs:101:29:101:29 | a | main.rs:101:13:101:30 | mn.data_through(...) | provenance | |
nodes
| main.rs:12:28:14:1 | { ... } | semmle.label | { ... } |
| main.rs:13:5:13:13 | source(...) | semmle.label | source(...) |
Expand All @@ -59,34 +66,44 @@ nodes
| main.rs:41:26:44:5 | { ... } | semmle.label | { ... } |
| main.rs:43:9:43:18 | source(...) | semmle.label | source(...) |
| main.rs:45:10:45:10 | a | semmle.label | a |
| main.rs:56:23:56:28 | ...: i64 | semmle.label | ...: i64 |
| main.rs:57:14:57:14 | n | semmle.label | n |
| main.rs:59:31:65:5 | { ... } | semmle.label | { ... } |
| main.rs:63:13:63:21 | source(...) | semmle.label | source(...) |
| main.rs:66:28:66:33 | ...: i64 | semmle.label | ...: i64 |
| main.rs:66:43:72:5 | { ... } | semmle.label | { ... } |
| main.rs:77:9:77:9 | a | semmle.label | a |
| main.rs:77:13:77:25 | mn.get_data(...) | semmle.label | mn.get_data(...) |
| main.rs:78:10:78:10 | a | semmle.label | a |
| main.rs:83:9:83:9 | a | semmle.label | a |
| main.rs:83:13:83:21 | source(...) | semmle.label | source(...) |
| main.rs:84:16:84:16 | a | semmle.label | a |
| main.rs:89:9:89:9 | a | semmle.label | a |
| main.rs:89:13:89:21 | source(...) | semmle.label | source(...) |
| main.rs:90:9:90:9 | b | semmle.label | b |
| main.rs:90:13:90:30 | mn.data_through(...) | semmle.label | mn.data_through(...) |
| main.rs:90:29:90:29 | a | semmle.label | a |
| main.rs:91:10:91:10 | b | semmle.label | b |
| main.rs:49:9:49:9 | a | semmle.label | a |
| main.rs:49:13:49:22 | source(...) | semmle.label | source(...) |
| main.rs:51:21:51:26 | ...: i64 | semmle.label | ...: i64 |
| main.rs:51:36:53:5 | { ... } | semmle.label | { ... } |
| main.rs:55:9:55:9 | b | semmle.label | b |
| main.rs:55:13:55:27 | pass_through(...) | semmle.label | pass_through(...) |
| main.rs:55:26:55:26 | a | semmle.label | a |
| main.rs:56:10:56:10 | b | semmle.label | b |
| main.rs:67:23:67:28 | ...: i64 | semmle.label | ...: i64 |
| main.rs:68:14:68:14 | n | semmle.label | n |
| main.rs:70:31:76:5 | { ... } | semmle.label | { ... } |
| main.rs:74:13:74:21 | source(...) | semmle.label | source(...) |
| main.rs:77:28:77:33 | ...: i64 | semmle.label | ...: i64 |
| main.rs:77:43:83:5 | { ... } | semmle.label | { ... } |
| main.rs:88:9:88:9 | a | semmle.label | a |
| main.rs:88:13:88:25 | mn.get_data(...) | semmle.label | mn.get_data(...) |
| main.rs:89:10:89:10 | a | semmle.label | a |
| main.rs:94:9:94:9 | a | semmle.label | a |
| main.rs:94:13:94:21 | source(...) | semmle.label | source(...) |
| main.rs:95:16:95:16 | a | semmle.label | a |
| main.rs:100:9:100:9 | a | semmle.label | a |
| main.rs:100:13:100:21 | source(...) | semmle.label | source(...) |
| main.rs:101:9:101:9 | b | semmle.label | b |
| main.rs:101:13:101:30 | mn.data_through(...) | semmle.label | mn.data_through(...) |
| main.rs:101:29:101:29 | a | semmle.label | a |
| main.rs:102:10:102:10 | b | semmle.label | b |
subpaths
| main.rs:36:26:36:26 | a | main.rs:30:17:30:22 | ...: i64 | main.rs:30:32:32:1 | { ... } | main.rs:36:13:36:27 | pass_through(...) |
| main.rs:41:26:44:5 | { ... } | main.rs:30:17:30:22 | ...: i64 | main.rs:30:32:32:1 | { ... } | main.rs:41:13:44:6 | pass_through(...) |
| main.rs:90:29:90:29 | a | main.rs:66:28:66:33 | ...: i64 | main.rs:66:43:72:5 | { ... } | main.rs:90:13:90:30 | mn.data_through(...) |
| main.rs:55:26:55:26 | a | main.rs:51:21:51:26 | ...: i64 | main.rs:51:36:53:5 | { ... } | main.rs:55:13:55:27 | pass_through(...) |
| main.rs:101:29:101:29 | a | main.rs:77:28:77:33 | ...: i64 | main.rs:77:43:83:5 | { ... } | main.rs:101:13:101:30 | mn.data_through(...) |
testFailures
#select
| main.rs:18:10:18:10 | a | main.rs:13:5:13:13 | source(...) | main.rs:18:10:18:10 | a | $@ | main.rs:13:5:13:13 | source(...) | source(...) |
| main.rs:22:10:22:10 | n | main.rs:26:13:26:21 | source(...) | main.rs:22:10:22:10 | n | $@ | main.rs:26:13:26:21 | source(...) | source(...) |
| main.rs:37:10:37:10 | b | main.rs:35:13:35:21 | source(...) | main.rs:37:10:37:10 | b | $@ | main.rs:35:13:35:21 | source(...) | source(...) |
| main.rs:45:10:45:10 | a | main.rs:43:9:43:18 | source(...) | main.rs:45:10:45:10 | a | $@ | main.rs:43:9:43:18 | source(...) | source(...) |
| main.rs:57:14:57:14 | n | main.rs:83:13:83:21 | source(...) | main.rs:57:14:57:14 | n | $@ | main.rs:83:13:83:21 | source(...) | source(...) |
| main.rs:78:10:78:10 | a | main.rs:63:13:63:21 | source(...) | main.rs:78:10:78:10 | a | $@ | main.rs:63:13:63:21 | source(...) | source(...) |
| main.rs:91:10:91:10 | b | main.rs:89:13:89:21 | source(...) | main.rs:91:10:91:10 | b | $@ | main.rs:89:13:89:21 | source(...) | source(...) |
| main.rs:56:10:56:10 | b | main.rs:49:13:49:22 | source(...) | main.rs:56:10:56:10 | b | $@ | main.rs:49:13:49:22 | source(...) | source(...) |
| main.rs:68:14:68:14 | n | main.rs:94:13:94:21 | source(...) | main.rs:68:14:68:14 | n | $@ | main.rs:94:13:94:21 | source(...) | source(...) |
| main.rs:89:10:89:10 | a | main.rs:74:13:74:21 | source(...) | main.rs:89:10:89:10 | a | $@ | main.rs:74:13:74:21 | source(...) | source(...) |
| main.rs:102:10:102:10 | b | main.rs:100:13:100:21 | source(...) | main.rs:102:10:102:10 | b | $@ | main.rs:100:13:100:21 | source(...) | source(...) |
12 changes: 12 additions & 0 deletions rust/ql/test/library-tests/dataflow/global/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ fn block_expression_as_argument() {
sink(a); // $ hasValueFlow=14
}

fn data_through_nested_function() {
let a = source(15);

fn pass_through(i: i64) -> i64 {
i
}

let b = pass_through(a);
sink(b); // $ hasValueFlow=15
}

// -----------------------------------------------------------------------------
// Data flow in, out, and through method.

Expand Down Expand Up @@ -127,6 +138,7 @@ fn main() {
data_out_of_call();
data_in_to_call();
data_through_call();
data_through_nested_function();

data_out_of_method();
data_in_to_method_call();
Expand Down
Loading