blob: 222311dc040a8672331d8bcdd595cbda7342cc92 [file] [log] [blame]
//===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "gtest/gtest.h"
#include <memory>
namespace llvm {
class FunctionSpecializationTest : public testing::Test {
protected:
LLVMContext Ctx;
FunctionAnalysisManager FAM;
std::unique_ptr<Module> M;
std::unique_ptr<SCCPSolver> Solver;
FunctionSpecializationTest() {
FAM.registerPass([&] { return TargetLibraryAnalysis(); });
FAM.registerPass([&] { return TargetIRAnalysis(); });
FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
FAM.registerPass([&] { return LoopAnalysis(); });
FAM.registerPass([&] { return AssumptionAnalysis(); });
FAM.registerPass([&] { return DominatorTreeAnalysis(); });
FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
}
Module &parseModule(const char *ModuleString) {
SMDiagnostic Err;
M = parseAssemblyString(ModuleString, Err, Ctx);
EXPECT_TRUE(M);
return *M;
}
FunctionSpecializer getSpecializerFor(Function *F) {
auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
return FAM.getResult<TargetLibraryAnalysis>(F);
};
auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
return FAM.getResult<TargetIRAnalysis>(F);
};
auto GetAC = [this](Function &F) -> AssumptionCache & {
return FAM.getResult<AssumptionAnalysis>(F);
};
auto GetDT = [this](Function &F) -> DominatorTree & {
return FAM.getResult<DominatorTreeAnalysis>(F);
};
auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
return FAM.getResult<BlockFrequencyAnalysis>(F);
};
Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
DominatorTree &DT = GetDT(*F);
AssumptionCache &AC = GetAC(*F);
Solver->addPredicateInfo(*F, DT, AC);
Solver->markBlockExecutable(&F->front());
for (Argument &Arg : F->args())
Solver->markOverdefined(&Arg);
Solver->solveWhileResolvedUndefsIn(*M);
return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
GetAC);
}
Cost getInstCost(Instruction &I) {
auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
}
};
} // namespace llvm
using namespace llvm;
TEST_F(FunctionSpecializationTest, SwitchInst) {
const char *ModuleString = R"(
define void @foo(i32 %a, i32 %b, i32 %i) {
entry:
br label %loop
loop:
switch i32 %i, label %default
[ i32 1, label %case1
i32 2, label %case2 ]
case1:
%0 = mul i32 %a, 2
%1 = sub i32 6, 5
br label %bb1
case2:
%2 = and i32 %b, 3
%3 = sdiv i32 8, 2
br label %bb2
bb1:
%4 = add i32 %0, %b
br label %loop
bb2:
%5 = or i32 %2, %a
br label %loop
default:
ret void
}
)";
Module &M = parseModule(ModuleString);
Function *F = M.getFunction("foo");
FunctionSpecializer Specializer = getSpecializerFor(F);
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
auto FuncIter = F->begin();
++FuncIter;
BasicBlock &Case1 = *++FuncIter;
BasicBlock &Case2 = *++FuncIter;
BasicBlock &BB1 = *++FuncIter;
BasicBlock &BB2 = *++FuncIter;
Instruction &Mul = Case1.front();
Instruction &And = Case2.front();
Instruction &Sdiv = *++Case2.begin();
Instruction &BrBB2 = Case2.back();
Instruction &Add = BB1.front();
Instruction &Or = BB2.front();
Instruction &BrLoop = BB2.back();
// mul
Cost Ref = getInstCost(Mul);
Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
// and + or + add
Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
// sdiv + br + br
Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrLoop);
Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
}
TEST_F(FunctionSpecializationTest, BranchInst) {
const char *ModuleString = R"(
define void @foo(i32 %a, i32 %b, i1 %cond) {
entry:
br label %loop
loop:
br i1 %cond, label %bb0, label %bb2
bb0:
%0 = mul i32 %a, 2
%1 = sub i32 6, 5
br label %bb1
bb1:
%2 = add i32 %0, %b
%3 = sdiv i32 8, 2
br label %loop
bb2:
ret void
}
)";
Module &M = parseModule(ModuleString);
Function *F = M.getFunction("foo");
FunctionSpecializer Specializer = getSpecializerFor(F);
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
Constant *False = ConstantInt::getFalse(M.getContext());
auto FuncIter = F->begin();
++FuncIter;
BasicBlock &BB0 = *++FuncIter;
BasicBlock &BB1 = *++FuncIter;
Instruction &Mul = BB0.front();
Instruction &Sub = *++BB0.begin();
Instruction &BrBB1 = BB0.back();
Instruction &Add = BB1.front();
Instruction &Sdiv = *++BB1.begin();
Instruction &BrLoop = BB1.back();
// mul
Cost Ref = getInstCost(Mul);
Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
// add
Ref = getInstCost(Add);
Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
// sub + br + sdiv + br
Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
getInstCost(BrLoop);
Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
}
TEST_F(FunctionSpecializationTest, Misc) {
const char *ModuleString = R"(
%struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
@g = constant %struct_t zeroinitializer, align 16
declare i32 @llvm.smax.i32(i32, i32)
declare i32 @bar(i32)
define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
%cmp = icmp eq i8 %a, 10
%ext = zext i1 %cmp to i64
%sel = select i1 %cond, i64 %ext, i64 1
%gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
%ld = load i32, ptr %gep
%fr = freeze i32 %ld
%smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
%call = call i32 @bar(i32 %smax)
%fr2 = freeze i32 %c
%add = add i32 %call, %fr2
ret i32 %add
}
)";
Module &M = parseModule(ModuleString);
Function *F = M.getFunction("foo");
FunctionSpecializer Specializer = getSpecializerFor(F);
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
GlobalVariable *GV = M.getGlobalVariable("g");
Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
Constant *True = ConstantInt::getTrue(M.getContext());
Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
auto BlockIter = F->front().begin();
Instruction &Icmp = *BlockIter++;
Instruction &Zext = *BlockIter++;
Instruction &Select = *BlockIter++;
Instruction &Gep = *BlockIter++;
Instruction &Load = *BlockIter++;
Instruction &Freeze = *BlockIter++;
Instruction &Smax = *BlockIter++;
// icmp + zext
Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
// select
Ref = getInstCost(Select);
Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
// gep + load + freeze + smax
Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
getInstCost(Smax);
Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
EXPECT_TRUE(Bonus == 0);
}