Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions rust/ql/lib/codeql/rust/elements/internal/TypeParamImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ module Impl {
* Gets the `index`th type bound of this type parameter, if any.
*
* This includes type bounds directly on this type parameter and bounds from
* any `where` clauses for this type parameter.
* any `where` clauses for this type parameter, but restricted to `where`
* clauses from the item that declares this type parameter.
*/
TypeBound getTypeBound(int index) {
result =
Expand All @@ -43,13 +44,36 @@ module Impl {
* Gets a type bound of this type parameter.
*
* This includes type bounds directly on this type parameter and bounds from
* any `where` clauses for this type parameter.
* any `where` clauses for this type parameter, but restricted to `where`
* clauses from the item that declares this type parameter.
*/
TypeBound getATypeBound() { result = this.getTypeBound(_) }

/** Holds if this type parameter has at least one type bound. */
predicate hasTypeBound() { exists(this.getATypeBound()) }

/**
* Gets the `index`th additional type bound of this type parameter,
* which applies to `constrainingItem`, if any.
*
* For example, in
*
* ```rust
* impl<T> SomeType<T> where T: Clone {
* fn foo() where T: Debug { }
* }
* ```
*
* The constraint `Debug` additionally applies to `T` in `foo`.
*/
TypeBound getAdditionalTypeBound(Item constrainingItem, int index) {
result =
rank[index + 1](int i, int j |
|
this.(TypeParamItemNode).getAdditionalTypeBoundAt(constrainingItem, i, j) order by i, j
)
}

override string toAbbreviatedString() { result = this.getName().getText() }

override string toStringImpl() { result = this.getName().getText() }
Expand Down
37 changes: 32 additions & 5 deletions rust/ql/lib/codeql/rust/internal/PathResolution.qll
Original file line number Diff line number Diff line change
Expand Up @@ -1162,21 +1162,39 @@ private Path getWherePredPath(WherePred wp) { result = wp.getTypeRepr().(PathTyp
final class TypeParamItemNode extends NamedItemNode, TypeItemNode instanceof TypeParam {
/** Gets a where predicate for this type parameter, if any */
pragma[nomagic]
private WherePred getAWherePred() {
private WherePred getAWherePred(ItemNode constrainingItem, boolean isAdditional) {
exists(ItemNode declaringItem |
this = declaringItem.getTypeParam(_) and
this = resolvePath(getWherePredPath(result)) and
result = declaringItem.getADescendant() and
this = declaringItem.getADescendant()
result = constrainingItem.getADescendant()
|
constrainingItem = declaringItem and
isAdditional = false
or
constrainingItem = declaringItem.getADescendant() and
isAdditional = true
)
}

pragma[nomagic]
TypeBound getTypeBoundAt(int i, int j) {
exists(TypeBoundList tbl | result = tbl.getBound(j) |
tbl = super.getTypeBoundList() and i = 0
tbl = super.getTypeBoundList() and
i = 0
or
exists(WherePred wp |
wp = this.getAWherePred() and
wp = this.getAWherePred(_, false) and
tbl = wp.getTypeBoundList() and
wp = any(WhereClause wc).getPredicate(i)
)
)
}

pragma[nomagic]
TypeBound getAdditionalTypeBoundAt(Item constrainingItem, int i, int j) {
exists(TypeBoundList tbl | result = tbl.getBound(j) |
exists(WherePred wp |
wp = this.getAWherePred(constrainingItem, true) and
tbl = wp.getTypeBoundList() and
wp = any(WhereClause wc).getPredicate(i)
)
Expand All @@ -1197,6 +1215,15 @@ final class TypeParamItemNode extends NamedItemNode, TypeItemNode instanceof Typ

ItemNode resolveABound() { result = resolvePath(this.getABoundPath()) }

pragma[nomagic]
ItemNode resolveAdditionalBound(ItemNode constrainingItem) {
result =
resolvePath(this.getAdditionalTypeBoundAt(constrainingItem, _, _)
.getTypeRepr()
.(PathTypeRepr)
.getPath())
}

override string getName() { result = TypeParam.super.getName().getText() }

override Namespace getNamespace() { result.isType() }
Expand Down
99 changes: 85 additions & 14 deletions rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ private newtype TAssocFunctionType =
}

bindingset[abs, constraint, tp]
pragma[inline_late]
private Type getTraitConstraintTypeAt(
TypeAbstraction abs, TypeMention constraint, TypeParameter tp, TypePath path
) {
Expand Down Expand Up @@ -203,7 +204,7 @@ class AssocFunctionType extends MkAssocFunctionType {
}

pragma[nomagic]
Trait getALookupTrait(Type t) {
private Trait getALookupTrait(Type t) {
result = t.(TypeParamTypeParameter).getTypeParam().(TypeParamItemNode).resolveABound()
or
result = t.(SelfTypeParameter).getTrait()
Expand All @@ -213,23 +214,47 @@ Trait getALookupTrait(Type t) {
result = t.(DynTraitType).getTrait()
}

/**
* Gets the type obtained by substituting in relevant traits in which to do function
* lookup, or `t` itself when no such trait exist.
*/
pragma[nomagic]
Type substituteLookupTraits(Type t) {
private Trait getAdditionalLookupTrait(ItemNode i, Type t) {
result =
t.(TypeParamTypeParameter)
.getTypeParam()
.(TypeParamItemNode)
.resolveAdditionalBound(i.getImmediateParent*())
}

bindingset[n, t]
pragma[inline_late]
Trait getALookupTrait(AstNode n, Type t) {
result = getALookupTrait(t)
or
result = getAdditionalLookupTrait(any(ItemNode i | n = i.getADescendant()), t)
}

bindingset[i, t]
pragma[inline_late]
private Type substituteLookupTraits0(ItemNode i, Type t) {
not exists(getALookupTrait(t)) and
not exists(getAdditionalLookupTrait(i, t)) and
result = t
or
result = TTrait(getALookupTrait(t))
or
result = TTrait(getAdditionalLookupTrait(i, t))
}

/**
* Gets the `n`th `substituteLookupTraits` type for `t`, per some arbitrary order.
* Gets the type obtained by substituting in relevant traits in which to do function
* lookup, or `t` itself when no such trait exist, in the context of AST node `n`.
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this doc comment, grammar: "when no such trait exist" should be "when no such trait exists".

Suggested change
* lookup, or `t` itself when no such trait exist, in the context of AST node `n`.
* lookup, or `t` itself when no such trait exists, in the context of AST node `n`.

Copilot uses AI. Check for mistakes.
*/
bindingset[n, t]
pragma[inline_late]
Type substituteLookupTraits(AstNode n, Type t) {
result = substituteLookupTraits0(any(ItemNode i | n = i.getADescendant()), t)
}

pragma[nomagic]
Type getNthLookupType(Type t, int n) {
private Type getNthLookupType(Type t, int n) {
not exists(getALookupTrait(t)) and
result = t and
n = 0
Expand All @@ -244,24 +269,66 @@ Type getNthLookupType(Type t, int n) {
}

/**
* Gets the index of the last `substituteLookupTraits` type for `t`.
* Gets the `n`th `substituteLookupTraits` type for `t`, per some arbitrary order,
* in the context of AST node `node`.
*/
bindingset[node, t]
pragma[inline_late]
Type getNthLookupType(AstNode node, Type t, int n) {
exists(ItemNode i | node = i.getADescendant() |
if exists(getAdditionalLookupTrait(i, t))
then
result =
TTrait(rank[n + 1](Trait trait, int j |
trait = [getALookupTrait(t), getAdditionalLookupTrait(i, t)] and
j = idOfTypeParameterAstNode(trait)
|
trait order by j
))
else result = getNthLookupType(t, n)
)
}

pragma[nomagic]
int getLastLookupTypeIndex(Type t) { result = max(int n | exists(getNthLookupType(t, n))) }
private int getLastLookupTypeIndex(Type t) { result = max(int n | exists(getNthLookupType(t, n))) }

/**
* Gets the index of the last `substituteLookupTraits` type for `t`,
* in the context of AST node `node`.
*/
bindingset[node, t]
pragma[inline_late]
int getLastLookupTypeIndex(AstNode node, Type t) {
if exists(getAdditionalLookupTrait(node, t))
then result = max(int n | exists(getNthLookupType(node, t, n)))
else result = getLastLookupTypeIndex(t)
}

signature class ArgSig {
/** Gets the type of this argument at `path`. */
Type getTypeAt(TypePath path);

/** Gets the enclosing item node of this argument. */
ItemNode getEnclosingItemNode();

/** Gets a textual representation of this argument. */
string toString();

/** Gets the location of this argument. */
Location getLocation();
}

/**
* A wrapper around `IsInstantiationOf` which ensures to substitute in lookup
* traits when checking whether argument types are instantiations of function
* types.
*/
module ArgIsInstantiationOf<
HasTypeTreeSig Arg, IsInstantiationOfInputSig<Arg, AssocFunctionType> Input>
{
module ArgIsInstantiationOf<ArgSig Arg, IsInstantiationOfInputSig<Arg, AssocFunctionType> Input> {
final private class ArgFinal = Arg;

private class ArgSubst extends ArgFinal {
Type getTypeAt(TypePath path) {
result = substituteLookupTraits(super.getTypeAt(path)) and
result = substituteLookupTraits0(this.getEnclosingItemNode(), super.getTypeAt(path)) and
not result = TNeverType() and
not result = TUnknownType()
}
Expand Down Expand Up @@ -318,6 +385,8 @@ signature module ArgsAreInstantiationsOfInputSig {

Location getLocation();

ItemNode getEnclosingItemNode();

Type getArgType(FunctionPosition pos, TypePath path);

predicate hasTargetCand(ImplOrTraitItemNode i, Function f);
Expand Down Expand Up @@ -366,6 +435,8 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {

FunctionPosition getPos() { result = pos }

ItemNode getEnclosingItemNode() { result = call.getEnclosingItemNode() }

Location getLocation() { result = call.getLocation() }

Type getTypeAt(TypePath path) { result = call.getArgType(pos, path) }
Expand Down
Loading
Loading