Skip to content

Commit 6a6ffa9

Browse files
Copilotwsmoses
andauthored
Add function existence validation to TypeAnalysisPrinter and ActivityAnalysisPrinter (#2616)
* Initial plan * Add function existence checks to TypeAnalysisPrinter and ActivityAnalysisPrinter Co-authored-by: wsmoses <[email protected]> * Fix compilation errors and add tests for function existence checks Co-authored-by: wsmoses <[email protected]> * Use std::once_flag for thread-safe validation in ActivityAnalysisPrinter NewPM Co-authored-by: wsmoses <[email protected]> * Improve code style with early returns and flatten nested conditionals Co-authored-by: wsmoses <[email protected]> * Final validation and testing complete Co-authored-by: wsmoses <[email protected]> * Remove CodeQL artifact * Replace report_fatal_error with EmitFailure to avoid crashing Use EmitFailure instead of report_fatal_error when function is not found. This provides a better error message and avoids crashing the compiler. Updated tests to match the new error message format. Co-authored-by: wsmoses <[email protected]> * Add function existence validation to TypeAnalysisPrinter and ActivityAnalysisPrinter Co-authored-by: wsmoses <[email protected]> * Refactor ActivityAnalysisPrinter to module pass and simplify function lookup - Changed ActivityAnalysisPrinterNewPM from function pass to module pass - Replaced std::call_once with M.getFunction for simpler function lookup - Changed "specified in" to "specified by" in error messages - Updated pass registration in Enzyme.cpp to use ModulePassManager - Removed unnecessary <mutex> include Co-authored-by: wsmoses <[email protected]> * Optimize FAM retrieval in ActivityAnalysisPrinter Move FunctionAnalysisManager retrieval outside the loop to avoid repeated calls Co-authored-by: wsmoses <[email protected]> * Remove legacy FunctionPass implementations and add empty function name validation - Deleted legacy ActivityAnalysisPrinter FunctionPass (no longer needed with module pass) - Deleted legacy TypeAnalysisPrinter FunctionPass (no longer needed with module pass) - Added validation for empty function names in both passes - Changed ActivityAnalysisPrinter to only run on target function instead of all functions - Both passes now emit diagnostic errors for empty function names Co-authored-by: wsmoses <[email protected]> * Optimize TypeAnalysisPrinter to only run on target function Changed TypeAnalysisPrinter to only run analysis on the target function instead of iterating through all functions, consistent with ActivityAnalysisPrinter Co-authored-by: wsmoses <[email protected]> * Restore legacy passes as ModulePasses - Added back ActivityAnalysisPrinter as a legacy ModulePass (converted from FunctionPass) - Added back TypeAnalysisPrinter as a legacy ModulePass (converted from FunctionPass) - Both legacy passes now have the same validation logic as NewPM passes - Legacy passes check for empty function names and non-existent functions - Keeps RegisterPass registration for backward compatibility Co-authored-by: wsmoses <[email protected]> * Extract getFirstFunctionDefinition helper function to reduce code duplication - Added getFirstFunctionDefinition() helper function in Utils.h/Utils.cpp - Returns the first non-declaration function in a module, or nullptr if none exists - Updated ActivityAnalysisPrinter to use the helper function (4 occurrences replaced) - Updated TypeAnalysisPrinter to use the helper function (4 occurrences replaced) - Eliminates code duplication and improves maintainability Co-authored-by: wsmoses <[email protected]> * Add EmitFailure overload for Module and simplify error handling - Added EmitFailure template overload that takes llvm::Module& instead of Function* - New overload automatically finds first function definition for diagnostic context - Simplified all error handling code in ActivityAnalysisPrinter and TypeAnalysisPrinter - Removed repetitive getFirstFunctionDefinition calls and if/else blocks - Added forward declaration of getFirstFunctionDefinition for template usage - Reduced code duplication and improved maintainability Co-authored-by: wsmoses <[email protected]> * Fix test expectations to match "specified by" error messages Updated test files to expect "specified by" instead of "specified in" to match the actual error messages produced by the code. Co-authored-by: wsmoses <[email protected]> * Apply clang-format to fix formatting issues - Fixed line wrapping for long comments - Removed trailing whitespace - Standardized indentation in arrays and function calls - Applied clang-format to all modified C++ files Co-authored-by: wsmoses <[email protected]> * Fix formatting with clang-format-16 and llvm style Applied clang-format-16 with llvm style (as required by CI) instead of clang-format-18. This changes indentation from 4 spaces to 2 spaces in array initializers to match the project's formatting requirements. Co-authored-by: wsmoses <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: wsmoses <[email protected]>
1 parent 2d4c57c commit 6a6ffa9

File tree

8 files changed

+141
-25
lines changed

8 files changed

+141
-25
lines changed

enzyme/Enzyme/ActivityAnalysisPrinter.cpp

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "llvm/Analysis/ScalarEvolution.h"
4848

4949
#include "llvm/Support/CommandLine.h"
50+
#include "llvm/Support/ErrorHandling.h"
5051

5152
#include "ActivityAnalysis.h"
5253
#include "ActivityAnalysisPrinter.h"
@@ -182,20 +183,36 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) {
182183
return /*changed*/ false;
183184
}
184185

185-
class ActivityAnalysisPrinter final : public FunctionPass {
186+
class ActivityAnalysisPrinter final : public ModulePass {
186187
public:
187188
static char ID;
188-
ActivityAnalysisPrinter() : FunctionPass(ID) {}
189+
ActivityAnalysisPrinter() : ModulePass(ID) {}
190+
191+
bool runOnModule(Module &M) override {
192+
// Check if function name is specified
193+
if (FunctionToAnalyze.empty()) {
194+
EmitFailure("NoFunctionSpecified", M,
195+
"No function specified for -activity-analysis-func");
196+
return false;
197+
}
189198

190-
void getAnalysisUsage(AnalysisUsage &AU) const override {
191-
AU.addRequired<TargetLibraryInfoWrapperPass>();
192-
}
199+
// Check if the specified function exists
200+
Function *TargetFunc = M.getFunction(FunctionToAnalyze);
193201

194-
bool runOnFunction(Function &F) override {
202+
if (!TargetFunc) {
203+
EmitFailure("FunctionNotFound", M, "Function '", FunctionToAnalyze,
204+
"' specified by -activity-analysis-func not found in module");
205+
return false;
206+
}
195207

196-
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
208+
// Run analysis only on the target function
209+
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(*TargetFunc);
210+
return printActivityAnalysis(*TargetFunc, TLI);
211+
}
197212

198-
return printActivityAnalysis(F, TLI);
213+
void getAnalysisUsage(AnalysisUsage &AU) const override {
214+
AU.addRequired<TargetLibraryInfoWrapperPass>();
215+
AU.setPreservesAll();
199216
}
200217
};
201218

@@ -207,10 +224,28 @@ static RegisterPass<ActivityAnalysisPrinter>
207224
X("print-activity-analysis", "Print Activity Analysis Results");
208225

209226
ActivityAnalysisPrinterNewPM::Result
210-
ActivityAnalysisPrinterNewPM::run(llvm::Function &F,
211-
llvm::FunctionAnalysisManager &FAM) {
212-
bool changed = false;
213-
changed = printActivityAnalysis(F, FAM.getResult<TargetLibraryAnalysis>(F));
227+
ActivityAnalysisPrinterNewPM::run(llvm::Module &M,
228+
llvm::ModuleAnalysisManager &MAM) {
229+
// Check if function name is specified
230+
if (FunctionToAnalyze.empty()) {
231+
EmitFailure("NoFunctionSpecified", M,
232+
"No function specified for -activity-analysis-func");
233+
return PreservedAnalyses::all();
234+
}
235+
236+
// Check if the specified function exists
237+
Function *TargetFunc = M.getFunction(FunctionToAnalyze);
238+
239+
if (!TargetFunc) {
240+
EmitFailure("FunctionNotFound", M, "Function '", FunctionToAnalyze,
241+
"' specified by -activity-analysis-func not found in module");
242+
return PreservedAnalyses::all();
243+
}
244+
245+
// Run analysis only on the target function
246+
auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
247+
bool changed = printActivityAnalysis(
248+
*TargetFunc, FAM.getResult<TargetLibraryAnalysis>(*TargetFunc));
214249
return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
215250
}
216251
llvm::AnalysisKey ActivityAnalysisPrinterNewPM::Key;

enzyme/Enzyme/ActivityAnalysisPrinter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class ActivityAnalysisPrinterNewPM final
4646
using Result = llvm::PreservedAnalyses;
4747
ActivityAnalysisPrinterNewPM() {}
4848

49-
Result run(llvm::Function &M, llvm::FunctionAnalysisManager &MAM);
49+
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM);
5050

5151
static bool isRequired() { return true; }
5252
};

enzyme/Enzyme/Enzyme.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3799,15 +3799,15 @@ extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
37993799
MPM.addPass(TypeAnalysisPrinterNewPM());
38003800
return true;
38013801
}
3802+
if (Name == "print-activity-analysis") {
3803+
MPM.addPass(ActivityAnalysisPrinterNewPM());
3804+
return true;
3805+
}
38023806
return false;
38033807
});
38043808
PB.registerPipelineParsingCallback(
38053809
[](llvm::StringRef Name, llvm::FunctionPassManager &FPM,
38063810
llvm::ArrayRef<llvm::PassBuilder::PipelineElement>) {
3807-
if (Name == "print-activity-analysis") {
3808-
FPM.addPass(ActivityAnalysisPrinterNewPM());
3809-
return true;
3810-
}
38113811
if (Name == "jl-inst-simplify") {
38123812
FPM.addPass(JLInstSimplifyNewPM());
38133813
return true;

enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "llvm/Analysis/ScalarEvolution.h"
5353

5454
#include "llvm/Support/CommandLine.h"
55+
#include "llvm/Support/ErrorHandling.h"
5556

5657
#include "../EnzymeLogic.h"
5758
#include "../FunctionUtils.h"
@@ -165,16 +166,36 @@ bool printTypeAnalyses(llvm::Function &F) {
165166
}
166167
return /*changed*/ false;
167168
}
168-
class TypeAnalysisPrinter final : public FunctionPass {
169+
170+
class TypeAnalysisPrinter final : public ModulePass {
169171
public:
170172
static char ID;
171-
TypeAnalysisPrinter() : FunctionPass(ID) {}
173+
TypeAnalysisPrinter() : ModulePass(ID) {}
174+
175+
bool runOnModule(Module &M) override {
176+
// Check if function name is specified
177+
if (EnzymeFunctionToAnalyze.empty()) {
178+
EmitFailure("NoFunctionSpecified", M,
179+
"No function specified for -type-analysis-func");
180+
return false;
181+
}
172182

173-
void getAnalysisUsage(AnalysisUsage &AU) const override {
174-
AU.addRequired<TargetLibraryInfoWrapperPass>();
183+
// Check if the specified function exists
184+
Function *TargetFunc = M.getFunction(EnzymeFunctionToAnalyze);
185+
186+
if (!TargetFunc) {
187+
EmitFailure("FunctionNotFound", M, "Function '", EnzymeFunctionToAnalyze,
188+
"' specified by -type-analysis-func not found in module");
189+
return false;
190+
}
191+
192+
// Run analysis only on the target function
193+
return printTypeAnalyses(*TargetFunc);
175194
}
176195

177-
bool runOnFunction(Function &F) override { return printTypeAnalyses(F); }
196+
void getAnalysisUsage(AnalysisUsage &AU) const override {
197+
AU.setPreservesAll();
198+
}
178199
};
179200

180201
} // namespace
@@ -187,9 +208,24 @@ static RegisterPass<TypeAnalysisPrinter> X("print-type-analysis",
187208
TypeAnalysisPrinterNewPM::Result
188209
TypeAnalysisPrinterNewPM::run(llvm::Module &M,
189210
llvm::ModuleAnalysisManager &MAM) {
190-
bool changed = false;
191-
for (auto &F : M)
192-
changed |= printTypeAnalyses(F);
211+
// Check if function name is specified
212+
if (EnzymeFunctionToAnalyze.empty()) {
213+
EmitFailure("NoFunctionSpecified", M,
214+
"No function specified for -type-analysis-func");
215+
return PreservedAnalyses::all();
216+
}
217+
218+
// Check if the specified function exists
219+
Function *TargetFunc = M.getFunction(EnzymeFunctionToAnalyze);
220+
221+
if (!TargetFunc) {
222+
EmitFailure("FunctionNotFound", M, "Function '", EnzymeFunctionToAnalyze,
223+
"' specified by -type-analysis-func not found in module");
224+
return PreservedAnalyses::all();
225+
}
226+
227+
// Run analysis only on the target function
228+
bool changed = printTypeAnalyses(*TargetFunc);
193229
return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
194230
}
195231
llvm::AnalysisKey TypeAnalysisPrinterNewPM::Key;

enzyme/Enzyme/Utils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3168,6 +3168,15 @@ Function *GetFunctionFromValue(Value *fn) {
31683168
return dyn_cast<Function>(GetFunctionValFromValue(fn));
31693169
}
31703170

3171+
Function *getFirstFunctionDefinition(Module &M) {
3172+
for (auto &F : M) {
3173+
if (!F.isDeclaration()) {
3174+
return &F;
3175+
}
3176+
}
3177+
return nullptr;
3178+
}
3179+
31713180
#if LLVM_VERSION_MAJOR >= 16
31723181
std::optional<BlasInfo> extractBLAS(llvm::StringRef in)
31733182
#else

enzyme/Enzyme/Utils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ class EnzymeFailure final : public llvm::DiagnosticInfoUnsupported {
195195
const llvm::Function *CodeRegion);
196196
};
197197

198+
// Forward declaration needed for EmitFailure template
199+
llvm::Function *getFirstFunctionDefinition(llvm::Module &M);
200+
198201
template <typename... Args>
199202
void EmitFailure(llvm::StringRef RemarkName,
200203
const llvm::DiagnosticLocation &Loc,
@@ -217,6 +220,21 @@ void EmitFailure(llvm::StringRef RemarkName,
217220
(EnzymeFailure("Enzyme: " + ss.str(), Loc, CodeRegion)));
218221
}
219222

223+
template <typename... Args>
224+
void EmitFailure(llvm::StringRef RemarkName, llvm::Module &M, Args &...args) {
225+
// Use the first function definition in the module as context for the
226+
// diagnostic
227+
if (llvm::Function *FirstFunc = getFirstFunctionDefinition(M)) {
228+
EmitFailure(RemarkName, FirstFunc->getSubprogram(), FirstFunc, args...);
229+
} else {
230+
// Fallback if no functions in module
231+
std::string *str = new std::string();
232+
llvm::raw_string_ostream ss(*str);
233+
(ss << ... << args);
234+
llvm::report_fatal_error(llvm::StringRef(*str));
235+
}
236+
}
237+
220238
static inline llvm::Function *isCalledFunction(llvm::Value *val) {
221239
if (llvm::CallInst *CI = llvm::dyn_cast<llvm::CallInst>(val)) {
222240
return CI->getCalledFunction();
@@ -1395,6 +1413,8 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal,
13951413

13961414
llvm::Function *GetFunctionFromValue(llvm::Value *fn);
13971415

1416+
llvm::Function *getFirstFunctionDefinition(llvm::Module &M);
1417+
13981418
llvm::Value *simplifyLoad(llvm::Value *LI, size_t valSz = 0,
13991419
size_t preOffset = 0);
14001420

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: not %opt < %s %newLoadEnzyme -passes="print-activity-analysis" -activity-analysis-func=nonexistent -S 2>&1 | FileCheck %s
2+
3+
define void @foo(i64* %x) {
4+
entry:
5+
ret void
6+
}
7+
8+
; CHECK: Enzyme: Function 'nonexistent' specified by -activity-analysis-func not found in module
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: not %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=nonexistent -S 2>&1 | FileCheck %s
2+
3+
define void @foo(i64* %x) {
4+
entry:
5+
ret void
6+
}
7+
8+
; CHECK: Enzyme: Function 'nonexistent' specified by -type-analysis-func not found in module

0 commit comments

Comments
 (0)