clang  20.0.0git
LoopUnrolling.cpp
Go to the documentation of this file.
1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// This file contains functions which are used to decide if a loop is worth
10 /// unrolling. Moreover, these functions manage the stack of loops which is
11 /// tracked by the ProgramState.
12 ///
13 //===----------------------------------------------------------------------===//
14 
20 #include <optional>
21 
22 using namespace clang;
23 using namespace ento;
24 using namespace clang::ast_matchers;
25 
/// Upper bound on the combined step count up to which a loop is still marked
/// as completely unrolled.
static constexpr int MAXIMUM_STEP_UNROLLED{128};
27 
28 namespace {
29 struct LoopState {
30 private:
31  enum Kind { Normal, Unrolled } K;
32  const Stmt *LoopStmt;
33  const LocationContext *LCtx;
34  unsigned maxStep;
35  LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
36  : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
37 
38 public:
39  static LoopState getNormal(const Stmt *S, const LocationContext *L,
40  unsigned N) {
41  return LoopState(Normal, S, L, N);
42  }
43  static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
44  unsigned N) {
45  return LoopState(Unrolled, S, L, N);
46  }
47  bool isUnrolled() const { return K == Unrolled; }
48  unsigned getMaxStep() const { return maxStep; }
49  const Stmt *getLoopStmt() const { return LoopStmt; }
50  const LocationContext *getLocationContext() const { return LCtx; }
51  bool operator==(const LoopState &X) const {
52  return K == X.K && LoopStmt == X.LoopStmt;
53  }
54  void Profile(llvm::FoldingSetNodeID &ID) const {
55  ID.AddInteger(K);
56  ID.AddPointer(LoopStmt);
57  ID.AddPointer(LCtx);
58  ID.AddInteger(maxStep);
59  }
60 };
61 } // namespace
62 
63 // The tracked stack of loops. The stack indicates which loops contain the
64 // simulated element. The loops are marked depending on whether we decided
65 // to unroll them.
66 // TODO: The loop stack should not need to be in the program state since it is
67 // lexical in nature. Instead, the stack of loops should be tracked in the
68 // LocationContext.
69 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
70 
71 namespace clang {
72 namespace ento {
73 
74 static bool isLoopStmt(const Stmt *S) {
75  return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(S);
76 }
77 
79  auto LS = State->get<LoopStack>();
80  if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
81  State = State->set<LoopStack>(LS.getTail());
82  return State;
83 }
84 
85 static internal::Matcher<Stmt> simpleCondition(StringRef BindName,
86  StringRef RefName) {
87  return binaryOperator(
88  anyOf(hasOperatorName("<"), hasOperatorName(">"),
89  hasOperatorName("<="), hasOperatorName(">="),
90  hasOperatorName("!=")),
91  hasEitherOperand(ignoringParenImpCasts(
92  declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName)))
93  .bind(RefName))),
94  hasEitherOperand(
95  ignoringParenImpCasts(integerLiteral().bind("boundNum"))))
96  .bind("conditionOperator");
97 }
98 
99 static internal::Matcher<Stmt>
100 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
101  return anyOf(
102  unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
103  hasUnaryOperand(ignoringParenImpCasts(
104  declRefExpr(to(varDecl(VarNodeMatcher)))))),
105  binaryOperator(isAssignmentOperator(),
106  hasLHS(ignoringParenImpCasts(
107  declRefExpr(to(varDecl(VarNodeMatcher)))))));
108 }
109 
110 static internal::Matcher<Stmt>
111 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
112  return callExpr(forEachArgumentWithParam(
113  declRefExpr(to(varDecl(VarNodeMatcher))),
114  parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
115 }
116 
117 static internal::Matcher<Stmt>
118 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
120  allOf(hasType(referenceType()),
121  hasInitializer(anyOf(
122  initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
123  declRefExpr(to(varDecl(VarNodeMatcher)))))))));
124 }
125 
126 static internal::Matcher<Stmt>
127 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
128  return unaryOperator(
129  hasOperatorName("&"),
130  hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
131 }
132 
133 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
134  return hasDescendant(stmt(
136  // Escaping and not known mutation of the loop counter is handled
137  // by exclusion of assigning and address-of operators and
138  // pass-by-ref function calls on the loop counter from the body.
139  changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
140  callByRef(equalsBoundNode(std::string(NodeName))),
141  getAddrTo(equalsBoundNode(std::string(NodeName))),
142  assignedToRef(equalsBoundNode(std::string(NodeName))))));
143 }
144 
145 static internal::Matcher<Stmt> forLoopMatcher() {
146  return forStmt(
147  hasCondition(simpleCondition("initVarName", "initVarRef")),
148  // Initialization should match the form: 'int i = 6' or 'i = 42'.
149  hasLoopInit(
150  anyOf(declStmt(hasSingleDecl(
151  varDecl(allOf(hasInitializer(ignoringParenImpCasts(
152  integerLiteral().bind("initNum"))),
153  equalsBoundNode("initVarName"))))),
155  equalsBoundNode("initVarName"))))),
156  hasRHS(ignoringParenImpCasts(
157  integerLiteral().bind("initNum")))))),
158  // Incrementation should be a simple increment or decrement
159  // operator call.
160  hasIncrement(unaryOperator(
161  anyOf(hasOperatorName("++"), hasOperatorName("--")),
162  hasUnaryOperand(declRefExpr(
163  to(varDecl(allOf(equalsBoundNode("initVarName"),
164  hasType(isInteger())))))))),
165  unless(hasBody(hasSuspiciousStmt("initVarName"))))
166  .bind("forLoop");
167 }
168 
169 static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) {
170 
171  // Get the lambda CXXRecordDecl
173  const LocationContext *LocCtxt = N->getLocationContext();
174  const Decl *D = LocCtxt->getDecl();
175  const auto *MD = cast<CXXMethodDecl>(D);
176  assert(MD && MD->getParent()->isLambda() &&
177  "Captured variable should only be seen while evaluating a lambda");
178  const CXXRecordDecl *LambdaCXXRec = MD->getParent();
179 
180  // Lookup the fields of the lambda
181  llvm::DenseMap<const ValueDecl *, FieldDecl *> LambdaCaptureFields;
182  FieldDecl *LambdaThisCaptureField;
183  LambdaCXXRec->getCaptureFields(LambdaCaptureFields, LambdaThisCaptureField);
184 
185  // Check if the counter is captured by reference
186  const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
187  assert(VD);
188  const FieldDecl *FD = LambdaCaptureFields[VD];
189  assert(FD && "Captured variable without a corresponding field");
190  return FD->getType()->isReferenceType();
191 }
192 
193 static bool isFoundInStmt(const Stmt *S, const VarDecl *VD) {
194  if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
195  for (const Decl *D : DS->decls()) {
196  // Once we reach the declaration of the VD we can return.
197  if (D->getCanonicalDecl() == VD)
198  return true;
199  }
200  }
201  return false;
202 }
203 
204 // A loop counter is considered escaped if:
205 // case 1: It is a global variable.
206 // case 2: It is a reference parameter or a reference capture.
207 // case 3: It is assigned to a non-const reference variable or parameter.
208 // case 4: Has its address taken.
209 static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) {
210  const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
211  assert(VD);
212  // Case 1:
213  if (VD->hasGlobalStorage())
214  return true;
215 
216  const bool IsRefParamOrCapture =
217  isa<ParmVarDecl>(VD) || DR->refersToEnclosingVariableOrCapture();
218  // Case 2:
220  isCapturedByReference(N, DR)) ||
221  (IsRefParamOrCapture && VD->getType()->isReferenceType()))
222  return true;
223 
224  while (!N->pred_empty()) {
225  // FIXME: getStmtForDiagnostics() does nasty things in order to provide
226  // a valid statement for body farms, do we need this behavior here?
227  const Stmt *S = N->getStmtForDiagnostics();
228  if (!S) {
229  N = N->getFirstPred();
230  continue;
231  }
232 
233  if (isFoundInStmt(S, VD)) {
234  return false;
235  }
236 
237  if (const auto *SS = dyn_cast<SwitchStmt>(S)) {
238  if (const auto *CST = dyn_cast<CompoundStmt>(SS->getBody())) {
239  for (const Stmt *CB : CST->body()) {
240  if (isFoundInStmt(CB, VD))
241  return false;
242  }
243  }
244  }
245 
246  // Check the usage of the pass-by-ref function calls and adress-of operator
247  // on VD and reference initialized by VD.
248  ASTContext &ASTCtx =
250  // Case 3 and 4:
251  auto Match =
252  match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
253  assignedToRef(equalsNode(VD)))),
254  *S, ASTCtx);
255  if (!Match.empty())
256  return true;
257 
258  N = N->getFirstPred();
259  }
260 
261  // Reference parameter and reference capture will not be found.
262  if (IsRefParamOrCapture)
263  return false;
264 
265  llvm_unreachable("Reached root without finding the declaration of VD");
266 }
267 
268 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
269  ExplodedNode *Pred, unsigned &maxStep) {
270 
271  if (!isLoopStmt(LoopStmt))
272  return false;
273 
274  // TODO: Match the cases where the bound is not a concrete literal but an
275  // integer with known value
276  auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
277  if (Matches.empty())
278  return false;
279 
280  const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>("initVarRef");
281  llvm::APInt BoundNum =
282  Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
283  llvm::APInt InitNum =
284  Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
285  auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
286  if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
287  InitNum = InitNum.zext(BoundNum.getBitWidth());
288  BoundNum = BoundNum.zext(InitNum.getBitWidth());
289  }
290 
291  if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
292  maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
293  else
294  maxStep = (BoundNum - InitNum).abs().getZExtValue();
295 
296  // Check if the counter of the loop is not escaped before.
297  return !isPossiblyEscaped(Pred, CounterVarRef);
298 }
299 
300 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
301  const Stmt *S = nullptr;
302  while (!N->pred_empty()) {
303  if (N->succ_size() > 1)
304  return true;
305 
306  ProgramPoint P = N->getLocation();
307  if (std::optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
308  S = BE->getBlock()->getTerminatorStmt();
309 
310  if (S == LoopStmt)
311  return false;
312 
313  N = N->getFirstPred();
314  }
315 
316  llvm_unreachable("Reached root without encountering the previous step");
317 }
318 
319 // updateLoopStack is called on every basic block, therefore it needs to be fast
321  ExplodedNode *Pred, unsigned maxVisitOnPath) {
322  auto State = Pred->getState();
323  auto LCtx = Pred->getLocationContext();
324 
325  if (!isLoopStmt(LoopStmt))
326  return State;
327 
328  auto LS = State->get<LoopStack>();
329  if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
330  LCtx == LS.getHead().getLocationContext()) {
331  if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
332  State = State->set<LoopStack>(LS.getTail());
333  State = State->add<LoopStack>(
334  LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
335  }
336  return State;
337  }
338  unsigned maxStep;
339  if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
340  State = State->add<LoopStack>(
341  LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
342  return State;
343  }
344 
345  unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
346 
347  unsigned innerMaxStep = maxStep * outerStep;
348  if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
349  State = State->add<LoopStack>(
350  LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
351  else
352  State = State->add<LoopStack>(
353  LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
354  return State;
355 }
356 
358  auto LS = State->get<LoopStack>();
359  if (LS.isEmpty() || !LS.getHead().isUnrolled())
360  return false;
361  return true;
362 }
363 }
364 }
StringRef P
static char ID
Definition: Arena.cpp:183
const Decl * D
enum clang::sema::@1659::IndirectLocalPathEntry::EntryKind Kind
#define X(type, name)
Definition: Value.h:143
static const int MAXIMUM_STEP_UNROLLED
This header contains the declarations of functions which are used to decide which loops should be com...
#define REGISTER_LIST_WITH_PROGRAMSTATE(Name, Elem)
Declares an immutable list type NameTy, suitable for placement into the ProgramState.
LineState State
__DEVICE__ long long abs(long long __n)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition: ASTContext.h:187
ASTContext & getASTContext() const
A builtin binary operation expression such as "x + y" or "x <= y".
Definition: Expr.h:3912
Represents a C++ struct/union/class.
Definition: DeclCXX.h:258
void getCaptureFields(llvm::DenseMap< const ValueDecl *, FieldDecl * > &Captures, FieldDecl *&ThisCapture) const
For a closure type, retrieve the mapping from captured variables and this to the non-static data memb...
Definition: DeclCXX.cpp:1680
DeclContext * getParent()
getParent - Returns the containing DeclContext.
Definition: DeclBase.h:2090
A reference to a declared variable, function, enum, etc.
Definition: Expr.h:1265
ValueDecl * getDecl()
Definition: Expr.h:1333
bool refersToEnclosingVariableOrCapture() const
Does this DeclRefExpr refer to an enclosing local or a captured variable?
Definition: Expr.h:1463
DeclStmt - Adaptor class for mixing declarations with statements and expressions.
Definition: Stmt.h:1497
Decl - This represents one declaration (or definition), e.g.
Definition: DeclBase.h:86
virtual Decl * getCanonicalDecl()
Retrieves the "canonical" declaration of the given declaration.
Definition: DeclBase.h:968
Represents a member of a struct/union/class.
Definition: Decl.h:3031
It wraps the AnalysisDeclContext to represent both the call stack with the help of StackFrameContext ...
const Decl * getDecl() const
LLVM_ATTRIBUTE_RETURNS_NONNULL AnalysisDeclContext * getAnalysisDeclContext() const
Stmt - This represents one statement.
Definition: Stmt.h:84
bool isReferenceType() const
Definition: Type.h:8031
QualType getType() const
Definition: Decl.h:679
Represents a variable declaration or definition.
Definition: Decl.h:880
bool hasGlobalStorage() const
Returns true for all variables that do not have local storage.
Definition: Decl.h:1175
ExplodedNode * getFirstPred()
const ProgramStateRef & getState() const
const LocationContext * getLocationContext() const
const Stmt * getStmtForDiagnostics() const
If the node's program point corresponds to a statement, retrieve that statement.
ProgramPoint getLocation() const
getLocation - Returns the edge associated with the given node.
unsigned succ_size() const
const internal::VariadicDynCastAllOfMatcher< Decl, VarDecl > varDecl
Matches variable declarations.
const internal::VariadicDynCastAllOfMatcher< Stmt, DeclRefExpr > declRefExpr
Matches expressions that refer to declarations.
const internal::VariadicOperatorMatcherFunc< 1, 1 > unless
Matches if the provided matcher does not match.
const internal::ArgumentAdaptingMatcherFunc< internal::HasDescendantMatcher > hasDescendant
Matches AST nodes that have descendant AST nodes that match the provided matcher.
const internal::VariadicDynCastAllOfMatcher< Decl, ParmVarDecl > parmVarDecl
Matches parameter variable declarations.
const internal::VariadicDynCastAllOfMatcher< Stmt, ReturnStmt > returnStmt
Matches return statements.
const internal::VariadicDynCastAllOfMatcher< Stmt, CallExpr > callExpr
Matches call expressions.
SmallVector< BoundNodes, 1 > match(MatcherT Matcher, const NodeT &Node, ASTContext &Context)
Returns the results of matching Matcher on Node.
const internal::VariadicDynCastAllOfMatcher< Stmt, UnaryOperator > unaryOperator
Matches unary operator expressions.
const internal::VariadicDynCastAllOfMatcher< Stmt, InitListExpr > initListExpr
Matches init list expressions.
const internal::VariadicDynCastAllOfMatcher< Stmt, ForStmt > forStmt
Matches for statements.
const internal::VariadicDynCastAllOfMatcher< Stmt, GotoStmt > gotoStmt
Matches goto statements.
const internal::VariadicDynCastAllOfMatcher< Stmt, BinaryOperator > binaryOperator
Matches binary operator expressions.
const internal::ArgumentAdaptingMatcherFunc< internal::HasMatcher > has
Matches AST nodes that have child AST nodes that match the provided matcher.
internal::PolymorphicMatcher< internal::HasDeclarationMatcher, void(internal::HasDeclarationSupportedTypes), internal::Matcher< Decl > > hasDeclaration(const internal::Matcher< Decl > &InnerMatcher)
Matches a node if the declaration associated with that node matches the given matcher.
Definition: ASTMatchers.h:3653
const internal::VariadicOperatorMatcherFunc< 2, std::numeric_limits< unsigned >::max()> allOf
Matches if all given matchers match.
const internal::VariadicDynCastAllOfMatcher< Stmt, SwitchStmt > switchStmt
Matches switch statements.
const AstTypeMatcher< ReferenceType > referenceType
Matches both lvalue and rvalue reference types.
const internal::VariadicDynCastAllOfMatcher< Stmt, IntegerLiteral > integerLiteral
Matches integer literals of all sizes / encodings, e.g.
const internal::VariadicDynCastAllOfMatcher< Stmt, DeclStmt > declStmt
Matches declaration statements.
const internal::VariadicAllOfMatcher< Stmt > stmt
Matches statements.
const internal::VariadicOperatorMatcherFunc< 2, std::numeric_limits< unsigned >::max()> anyOf
Matches if any of the given matchers matches.
const internal::VariadicAllOfMatcher< QualType > qualType
Matches QualTypes in the clang AST.
static internal::Matcher< Stmt > forLoopMatcher()
static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR)
static internal::Matcher< Stmt > hasSuspiciousStmt(StringRef NodeName)
static bool isLoopStmt(const Stmt *S)
static internal::Matcher< Stmt > changeIntBoundNode(internal::Matcher< Decl > VarNodeMatcher)
ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State)
Updates the given ProgramState.
bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, ExplodedNode *Pred, unsigned &maxStep)
static internal::Matcher< Stmt > simpleCondition(StringRef BindName, StringRef RefName)
static bool isFoundInStmt(const Stmt *S, const VarDecl *VD)
bool isUnrolledState(ProgramStateRef State)
Returns if the given State indicates that is inside a completely unrolled loop.
static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR)
static internal::Matcher< Stmt > assignedToRef(internal::Matcher< Decl > VarNodeMatcher)
static internal::Matcher< Stmt > callByRef(internal::Matcher< Decl > VarNodeMatcher)
static internal::Matcher< Stmt > getAddrTo(internal::Matcher< Decl > VarNodeMatcher)
ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, ExplodedNode *Pred, unsigned maxVisitOnPath)
Updates the stack of loops contained by the ProgramState.
bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt)
llvm::APInt APInt
Definition: Integral.h:29
The JSON file list parser is used to communicate input to InstallAPI.
bool operator==(const CallGraphNode::CallRecord &LHS, const CallGraphNode::CallRecord &RHS)
Definition: CallGraph.h:223