Skip to content

Commit 7fd1e6b

Browse files
committed
Rust: Take additional type parameter constraints into account
1 parent 9ac960c commit 7fd1e6b

File tree

8 files changed

+226
-46
lines changed

8 files changed

+226
-46
lines changed

rust/ql/lib/codeql/rust/elements/internal/AstNodeImpl.qll

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ module Impl {
4848
)
4949
}
5050

51+
/** Gets the immediately enclosing function of this node, if any. */
52+
pragma[nomagic]
53+
Function getEnclosingFunction() {
54+
result = this.getEnclosingCallable()
55+
or
56+
exists(Callable c | c = this.getEnclosingCallable() |
57+
not c instanceof Function and
58+
result = c.getEnclosingFunction()
59+
)
60+
}
61+
5162
/** Gets the CFG scope that encloses this node, if any. */
5263
cached
5364
CfgScope getEnclosingCfgScope() {

rust/ql/lib/codeql/rust/elements/internal/TypeParamImpl.qll

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ module Impl {
3232
* Gets the `index`th type bound of this type parameter, if any.
3333
*
3434
* This includes type bounds directly on this type parameter and bounds from
35-
* any `where` clauses for this type parameter.
35+
* any `where` clauses for this type parameter, but restricted to `where`
36+
* clauses from the item that declares this type parameter.
3637
*/
3738
TypeBound getTypeBound(int index) {
3839
result =
@@ -43,13 +44,36 @@ module Impl {
4344
* Gets a type bound of this type parameter.
4445
*
4546
* This includes type bounds directly on this type parameter and bounds from
46-
* any `where` clauses for this type parameter.
47+
* any `where` clauses for this type parameter, but restricted to `where`
48+
* clauses from the item that declares this type parameter.
4749
*/
4850
TypeBound getATypeBound() { result = this.getTypeBound(_) }
4951

5052
/** Holds if this type parameter has at least one type bound. */
5153
predicate hasTypeBound() { exists(this.getATypeBound()) }
5254

55+
/**
56+
* Gets the `index`th additional type bound of this type parameter,
57+
* which applies to `constrainingItem`, if any.
58+
*
59+
* For example, in
60+
*
61+
* ```rust
62+
* impl<T> SomeType<T> where T: Clone {
63+
* fn foo() where T: Debug { }
64+
* }
65+
* ```
66+
*
67+
* The constraint `Debug` additionally applies to `T` in `foo`.
68+
*/
69+
TypeBound getAdditionalTypeBound(Item constrainingItem, int index) {
70+
result =
71+
rank[index + 1](int i, int j |
72+
|
73+
this.(TypeParamItemNode).getAdditionalTypeBoundAt(constrainingItem, i, j) order by i, j
74+
)
75+
}
76+
5377
override string toAbbreviatedString() { result = this.getName().getText() }
5478

5579
override string toStringImpl() { result = this.getName().getText() }

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,21 +1162,39 @@ private Path getWherePredPath(WherePred wp) { result = wp.getTypeRepr().(PathTyp
11621162
final class TypeParamItemNode extends NamedItemNode, TypeItemNode instanceof TypeParam {
11631163
/** Gets a where predicate for this type parameter, if any */
11641164
pragma[nomagic]
1165-
private WherePred getAWherePred() {
1165+
private WherePred getAWherePred(ItemNode constrainingItem, boolean isAdditional) {
11661166
exists(ItemNode declaringItem |
1167+
this = declaringItem.getTypeParam(_) and
11671168
this = resolvePath(getWherePredPath(result)) and
1168-
result = declaringItem.getADescendant() and
1169-
this = declaringItem.getADescendant()
1169+
result = constrainingItem.getADescendant()
1170+
|
1171+
constrainingItem = declaringItem and
1172+
isAdditional = false
1173+
or
1174+
constrainingItem = declaringItem.getADescendant() and
1175+
isAdditional = true
11701176
)
11711177
}
11721178

11731179
pragma[nomagic]
11741180
TypeBound getTypeBoundAt(int i, int j) {
11751181
exists(TypeBoundList tbl | result = tbl.getBound(j) |
1176-
tbl = super.getTypeBoundList() and i = 0
1182+
tbl = super.getTypeBoundList() and
1183+
i = 0
11771184
or
11781185
exists(WherePred wp |
1179-
wp = this.getAWherePred() and
1186+
wp = this.getAWherePred(_, false) and
1187+
tbl = wp.getTypeBoundList() and
1188+
wp = any(WhereClause wc).getPredicate(i)
1189+
)
1190+
)
1191+
}
1192+
1193+
pragma[nomagic]
1194+
TypeBound getAdditionalTypeBoundAt(Item constrainingItem, int i, int j) {
1195+
exists(TypeBoundList tbl | result = tbl.getBound(j) |
1196+
exists(WherePred wp |
1197+
wp = this.getAWherePred(constrainingItem, true) and
11801198
tbl = wp.getTypeBoundList() and
11811199
wp = any(WhereClause wc).getPredicate(i)
11821200
)
@@ -1197,6 +1215,15 @@ final class TypeParamItemNode extends NamedItemNode, TypeItemNode instanceof Typ
11971215

11981216
ItemNode resolveABound() { result = resolvePath(this.getABoundPath()) }
11991217

1218+
pragma[nomagic]
1219+
ItemNode resolveAdditionalBound(ItemNode constrainingItem) {
1220+
result =
1221+
resolvePath(this.getAdditionalTypeBoundAt(constrainingItem, _, _)
1222+
.getTypeRepr()
1223+
.(PathTypeRepr)
1224+
.getPath())
1225+
}
1226+
12001227
override string getName() { result = TypeParam.super.getName().getText() }
12011228

12021229
override Namespace getNamespace() { result.isType() }

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ class AssocFunctionType extends MkAssocFunctionType {
203203
}
204204

205205
pragma[nomagic]
206-
Trait getALookupTrait(Type t) {
206+
private Trait getALookupTrait(Type t) {
207207
result = t.(TypeParamTypeParameter).getTypeParam().(TypeParamItemNode).resolveABound()
208208
or
209209
result = t.(SelfTypeParameter).getTrait()
@@ -213,23 +213,40 @@ Trait getALookupTrait(Type t) {
213213
result = t.(DynTraitType).getTrait()
214214
}
215215

216-
/**
217-
* Gets the type obtained by substituting in relevant traits in which to do function
218-
* lookup, or `t` itself when no such trait exist.
219-
*/
220216
pragma[nomagic]
221-
Type substituteLookupTraits(Type t) {
217+
private Trait getAdditionalLookupTrait(Type t, Function f) {
218+
result = t.(TypeParamTypeParameter).getTypeParam().(TypeParamItemNode).resolveAdditionalBound(f)
219+
}
220+
221+
bindingset[n, t]
222+
Trait getALookupTrait(AstNode n, Type t) {
223+
result = getALookupTrait(t)
224+
or
225+
result = getAdditionalLookupTrait(t, n.getEnclosingFunction())
226+
}
227+
228+
bindingset[f, t]
229+
private Type substituteLookupTraits0(Function f, Type t) {
222230
not exists(getALookupTrait(t)) and
231+
not exists(getAdditionalLookupTrait(t, f)) and
223232
result = t
224233
or
225234
result = TTrait(getALookupTrait(t))
235+
or
236+
result = TTrait(getAdditionalLookupTrait(t, f))
226237
}
227238

228239
/**
229-
* Gets the `n`th `substituteLookupTraits` type for `t`, per some arbitrary order.
240+
* Gets the type obtained by substituting in relevant traits in which to do function
241+
* lookup, or `t` itself when no such trait exist, in the context of AST node `n`.
230242
*/
243+
bindingset[n, t]
244+
Type substituteLookupTraits(AstNode n, Type t) {
245+
result = substituteLookupTraits0(n.getEnclosingFunction(), t)
246+
}
247+
231248
pragma[nomagic]
232-
Type getNthLookupType(Type t, int n) {
249+
private Type getNthLookupType(Type t, int n) {
233250
not exists(getALookupTrait(t)) and
234251
result = t and
235252
n = 0
@@ -244,24 +261,64 @@ Type getNthLookupType(Type t, int n) {
244261
}
245262

246263
/**
247-
* Gets the index of the last `substituteLookupTraits` type for `t`.
264+
* Gets the `n`th `substituteLookupTraits` type for `t`, per some arbitrary order,
265+
* in the context of AST node `node`.
248266
*/
267+
bindingset[node, t]
268+
Type getNthLookupType(AstNode node, Type t, int n) {
269+
exists(Function f | f = node.getEnclosingFunction() |
270+
if exists(getAdditionalLookupTrait(t, f))
271+
then
272+
result =
273+
TTrait(rank[n + 1](Trait trait, int i |
274+
trait = [getALookupTrait(t), getAdditionalLookupTrait(t, f)] and
275+
i = idOfTypeParameterAstNode(trait)
276+
|
277+
trait order by i
278+
))
279+
else result = getNthLookupType(t, n)
280+
)
281+
}
282+
249283
pragma[nomagic]
250-
int getLastLookupTypeIndex(Type t) { result = max(int n | exists(getNthLookupType(t, n))) }
284+
private int getLastLookupTypeIndex(Type t) { result = max(int n | exists(getNthLookupType(t, n))) }
285+
286+
/**
287+
* Gets the index of the last `substituteLookupTraits` type for `t`,
288+
* in the context of AST node `node`.
289+
*/
290+
bindingset[node, t]
291+
int getLastLookupTypeIndex(AstNode node, Type t) {
292+
if exists(getAdditionalLookupTrait(t, node))
293+
then result = max(int n | exists(getNthLookupType(node, t, n)))
294+
else result = getLastLookupTypeIndex(t)
295+
}
296+
297+
signature class ArgSig {
298+
/** Gets the type of this argument at `path`. */
299+
Type getTypeAt(TypePath path);
300+
301+
/** Gets the enclosing function of this argument. */
302+
Function getEnclosingFunction();
303+
304+
/** Gets a textual representation of this argument. */
305+
string toString();
306+
307+
/** Gets the location of this argument. */
308+
Location getLocation();
309+
}
251310

252311
/**
253312
* A wrapper around `IsInstantiationOf` which ensures to substitute in lookup
254313
* traits when checking whether argument types are instantiations of function
255314
* types.
256315
*/
257-
module ArgIsInstantiationOf<
258-
HasTypeTreeSig Arg, IsInstantiationOfInputSig<Arg, AssocFunctionType> Input>
259-
{
316+
module ArgIsInstantiationOf<ArgSig Arg, IsInstantiationOfInputSig<Arg, AssocFunctionType> Input> {
260317
final private class ArgFinal = Arg;
261318

262319
private class ArgSubst extends ArgFinal {
263320
Type getTypeAt(TypePath path) {
264-
result = substituteLookupTraits(super.getTypeAt(path)) and
321+
result = substituteLookupTraits0(this.getEnclosingFunction(), super.getTypeAt(path)) and
265322
not result = TNeverType() and
266323
not result = TUnknownType()
267324
}
@@ -318,6 +375,8 @@ signature module ArgsAreInstantiationsOfInputSig {
318375

319376
Location getLocation();
320377

378+
Function getEnclosingFunction();
379+
321380
Type getArgType(FunctionPosition pos, TypePath path);
322381

323382
predicate hasTargetCand(ImplOrTraitItemNode i, Function f);
@@ -366,6 +425,8 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
366425

367426
FunctionPosition getPos() { result = pos }
368427

428+
Function getEnclosingFunction() { result = call.getEnclosingFunction() }
429+
369430
Location getLocation() { result = call.getLocation() }
370431

371432
Type getTypeAt(TypePath path) { result = call.getArgType(pos, path) }

0 commit comments

Comments
 (0)