| #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
| #include "llvm/Support/Compiler.h" |
| #include "gmock/gmock.h" |
| #include "gtest/gtest.h" |
| #include <memory> |
| |
| using namespace mlir; |
| using namespace mlir::sparse_tensor; |
| |
| // Silence 'warning C4002: 'too many arguments for function-liked macro |
| // invocation' |
| // as MSVC handles ##__VA_ARGS__ differently as gcc/clang |
| |
| #if defined(_MSC_VER) && !defined(__clang__) |
| #pragma warning(push) |
| #pragma warning(disable : 4002) |
| #endif |
| |
| namespace { |
| |
| /// |
| /// Defines macros to iterate binary and the combination of binary operations. |
| /// |
| |
| #define FOREVERY_BINOP(DO) \ |
| DO(mulf, Kind::kMulF) \ |
| DO(mulc, Kind::kMulC) \ |
| DO(muli, Kind::kMulI) \ |
| DO(addf, Kind::kAddF) \ |
| DO(addc, Kind::kAddC) \ |
| DO(addi, Kind::kAddI) \ |
| DO(subf, Kind::kSubF) \ |
| DO(subc, Kind::kSubC) \ |
| DO(subi, Kind::kSubI) \ |
| DO(andi, Kind::kAndI) \ |
| DO(xori, Kind::kXorI) \ |
| DO(ori, Kind::kOrI) |
| |
| // TODO: Disjunctive binary operations that need special handling are not |
| // included, e.g., Division are not tested (for now) as it need a constant |
| // non-zero dividend. |
| // ##__VA_ARGS__ handles cases when __VA_ARGS__ is empty. |
| #define FOREVERY_COMMON_DISJ_BINOP(TEST, ...) \ |
| TEST(addf, ##__VA_ARGS__) \ |
| TEST(addc, ##__VA_ARGS__) \ |
| TEST(addi, ##__VA_ARGS__) \ |
| TEST(xori, ##__VA_ARGS__) \ |
| TEST(ori, ##__VA_ARGS__) |
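
// For illustration: FOREVERY_COMMON_DISJ_BINOP(FOO) expands to
//   FOO(addf) FOO(addc) FOO(addi) FOO(xori) FOO(ori)
// (the ## elides the trailing comma), while FOREVERY_COMMON_DISJ_BINOP(FOO, x)
// expands to FOO(addf, x) FOO(addc, x) FOO(addi, x) FOO(xori, x) FOO(ori, x).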

// TODO: Conjunctive binary operations that need special handling are not
// included, e.g., division is not tested (for now) as it requires a constant,
// non-zero divisor.
#define FOREVERY_COMMON_CONJ_BINOP(TEST, ...) \
  TEST(mulf, ##__VA_ARGS__) \
  TEST(mulc, ##__VA_ARGS__) \
  TEST(muli, ##__VA_ARGS__) \
  TEST(andi, ##__VA_ARGS__)

#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addf) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addc) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addi) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, xori) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, ori)

#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulf) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulc) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, muli) \
  FOREVERY_COMMON_CONJ_BINOP(TEST, andi)

#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST) \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addf) \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addc) \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addi) \
  FOREVERY_COMMON_DISJ_BINOP(TEST, ori) \
  FOREVERY_COMMON_DISJ_BINOP(TEST, xori)
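
// For illustration: FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST) invokes
// TEST(conj, disj) for every (conjunction, disjunction) pair, i.e.
// TEST(mulf, addf), TEST(mulc, addf), ..., TEST(andi, ori); the CONJ_CONJ and
// DISJ_DISJ variants enumerate the analogous pairs.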

///
/// Helper classes/functions for testing Merger.
///

/// Simple recursive data structure used to match expressions in Mergers.
struct Pattern {
  Kind kind;

  /// Expressions representing tensors simply have a tensor number.
  unsigned tensorNum;

  /// Tensor operations point to their children.
  std::shared_ptr<Pattern> e0;
  std::shared_ptr<Pattern> e1;

  /// Constructors.
  /// Rather than calling these directly, prefer the builder functions below;
  /// they keep the tests readable.
  Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
  Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
          const std::shared_ptr<Pattern> &e1)
      : kind(kind), e0(e0), e1(e1) {
    assert(kind >= Kind::kMulF);
    assert(e0 && e1);
  }
};
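
// For example (illustrative only): the expression t0 + t1 is matched by a
// Pattern tree with kind kAddF whose children are tensor patterns for t0 and
// t1, which the builders below let tests write as
// addfPattern(tensorPattern(t0), tensorPattern(t1)).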

///
/// Readable Pattern builder functions.
/// These should be preferred over the actual constructors.
///

static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
  return std::make_shared<Pattern>(tensorNum);
}

#define IMPL_BINOP_PATTERN(OP, KIND) \
  LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern( \
      const std::shared_ptr<Pattern> &e0, \
      const std::shared_ptr<Pattern> &e1) { \
    return std::make_shared<Pattern>(KIND, e0, e1); \
  }
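
// Applied to FOREVERY_BINOP below, IMPL_BINOP_PATTERN generates one builder
// per operation, e.g. addfPattern(e0, e1), which constructs a
// Pattern(Kind::kAddF, e0, e1).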

FOREVERY_BINOP(IMPL_BINOP_PATTERN)

#undef IMPL_BINOP_PATTERN

class MergerTestBase : public ::testing::Test {
protected:
  MergerTestBase(unsigned numTensors, unsigned numLoops)
      : numTensors(numTensors), numLoops(numLoops),
        merger(numTensors, numLoops) {}

  ///
  /// Expression construction helpers.
  ///

  unsigned tensor(unsigned tensor) {
    return merger.addExp(Kind::kTensor, tensor);
  }

#define IMPL_BINOP_EXPR(OP, KIND) \
  LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) { \
    return merger.addExp(KIND, e0, e1); \
  }
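
  // Applied to FOREVERY_BINOP below, this generates one expression builder
  // per operation, e.g. addfExpr(e0, e1), which adds a Kind::kAddF expression
  // over e0 and e1 to the merger.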

  FOREVERY_BINOP(IMPL_BINOP_EXPR)

#undef IMPL_BINOP_EXPR

  ///
  /// Comparison helpers.
  ///

  /// For readability of tests.
  unsigned lat(unsigned lat) { return lat; }

  /// Returns true if a lattice point with an expression matching the given
  /// pattern and bits matching the given bits is present in lattice points
  /// [p, p+n) of lattice set s. This is useful for testing partial ordering
  /// constraints between lattice points. We generally know how contiguous
  /// groups of lattice points should be ordered with respect to other groups,
  /// but there is no required ordering within groups.
  /// If simple is true, then compare the lat.simple field instead to test the
  /// result after optimization.
  bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
                           const std::shared_ptr<Pattern> &pattern,
                           const BitVector &bits, bool simple) {
    for (unsigned i = p; i < p + n; ++i) {
      if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
          compareBits(s, i, bits, simple))
        return true;
    }
    return false;
  }

  /// Wrapper over latPointWithinRange for readability of tests.
  void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
                                 const std::shared_ptr<Pattern> &pattern,
                                 const BitVector &bits, bool simple = false) {
    EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
  }

  /// Wrapper over expectLatPointWithinRange for a single lat point.
  void expectLatPoint(unsigned s, unsigned p,
                      const std::shared_ptr<Pattern> &pattern,
                      const BitVector &bits, bool simple = false) {
    EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
  }

  /// Converts a vector of (loop, tensor) pairs to a bitvector with the
  /// corresponding bits set.
  BitVector
  loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
    BitVector testBits = BitVector(numTensors + 1, false);
    for (auto l : loops) {
      auto loop = std::get<0>(l);
      auto tensor = std::get<1>(l);
      testBits.set(numTensors * loop + tensor);
    }
    return testBits;
  }
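
  // For example, with numTensors == 3, loopsToBits({{0, 0}, {0, 1}}) sets
  // bits 0 and 1: each (loop, tensor) pair maps to bit numTensors * loop +
  // tensor.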

  /// Returns true if the bits of lattice point p in set s match the given
  /// bits. If simple is true, then compare the lat.simple field instead to
  /// test the result after optimization.
  bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
    if (simple)
      return merger.lat(merger.set(s)[p]).simple == bits;
    return merger.lat(merger.set(s)[p]).bits == bits;
  }

  /// Check that there are n lattice points in set s.
  void expectNumLatPoints(unsigned s, unsigned n) {
    EXPECT_THAT(merger.set(s).size(), n);
  }

  /// Compares expressions for equality. Equality is defined recursively as:
  /// - Operations are equal if they have the same kind and children.
  /// - Leaf tensors are equal if they refer to the same tensor.
  bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
    auto tensorExp = merger.exp(e);
    if (tensorExp.kind != pattern->kind)
      return false;
    switch (tensorExp.kind) {
    // Leaf.
    case kTensor:
      return tensorExp.tensor == pattern->tensorNum;
    case kInvariant:
    case kIndex:
      llvm_unreachable("invariant not handled yet");
    // Unary operations.
    case kAbsF:
    case kAbsC:
    case kCeilF:
    case kFloorF:
    case kSqrtF:
    case kSqrtC:
    case kExpm1F:
    case kExpm1C:
    case kLog1pF:
    case kLog1pC:
    case kSinF:
    case kSinC:
    case kTanhF:
    case kTanhC:
    case kNegF:
    case kNegC:
    case kNegI:
    case kTruncF:
    case kExtF:
    case kCastFS:
    case kCastFU:
    case kCastSF:
    case kCastUF:
    case kCastS:
    case kCastU:
    case kCastIdx:
    case kTruncI:
    case kCIm:
    case kCRe:
    case kBitCast:
    case kBinaryBranch:
    case kUnary:
    case kShlI:
    case kBinary:
      return compareExpression(tensorExp.children.e0, pattern->e0);
    // Binary operations.
    case kMulF:
    case kMulC:
    case kMulI:
    case kDivF:
    case kDivC:
    case kDivS:
    case kDivU:
    case kAddF:
    case kAddC:
    case kAddI:
    case kSubF:
    case kSubC:
    case kSubI:
    case kAndI:
    case kOrI:
    case kXorI:
    case kShrS:
    case kShrU:
      return compareExpression(tensorExp.children.e0, pattern->e0) &&
             compareExpression(tensorExp.children.e1, pattern->e1);
    }
    llvm_unreachable("unexpected kind");
  }

  unsigned numTensors;
  unsigned numLoops;
  Merger merger;
};

///
/// Tests with all sparse inputs.
///

class MergerTest3T1L : public MergerTestBase {
protected:
  // Our three tensors (two inputs, one output).
  const unsigned t0 = 0, t1 = 1, t2 = 2;

  // Our single loop.
  const unsigned l0 = 0;

  MergerTest3T1L() : MergerTestBase(3, 1) {
    // Tensor 0: sparse input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
    merger.setDim(t0, l0, Dim::kSparse);

    // Tensor 1: sparse input vector.
    merger.addExp(Kind::kTensor, t1, -1u);
    merger.setDim(t1, l0, Dim::kSparse);

    // Tensor 2: dense output vector.
    merger.addExp(Kind::kTensor, t2, -1u);
    merger.setDim(t2, l0, Dim::kDense);
  }
};

class MergerTest4T1L : public MergerTestBase {
protected:
  // Our four tensors (three inputs, one output).
  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;

  // Our single loop.
  const unsigned l0 = 0;

  MergerTest4T1L() : MergerTestBase(4, 1) {
    // Tensor 0: sparse input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
    merger.setDim(t0, l0, Dim::kSparse);

    // Tensor 1: sparse input vector.
    merger.addExp(Kind::kTensor, t1, -1u);
    merger.setDim(t1, l0, Dim::kSparse);

    // Tensor 2: sparse input vector.
    merger.addExp(Kind::kTensor, t2, -1u);
    merger.setDim(t2, l0, Dim::kSparse);

    // Tensor 3: dense output vector.
    merger.addExp(Kind::kTensor, t3, -1u);
    merger.setDim(t3, l0, Dim::kDense);
  }
};

///
/// Tests with both sparse and dense input.
///

class MergerTest3T1LD : public MergerTestBase {
protected:
  // Our three tensors (two inputs, one output).
  const unsigned t0 = 0, t1 = 1, t2 = 2;

  // Our single loop.
  const unsigned l0 = 0;

  MergerTest3T1LD() : MergerTestBase(3, 1) {
    // Tensor 0: sparse input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
    merger.setDim(t0, l0, Dim::kSparse);

    // Tensor 1: dense input vector.
    merger.addExp(Kind::kTensor, t1, -1u);
    merger.setDim(t1, l0, Dim::kDense);

    // Tensor 2: dense output vector.
    merger.addExp(Kind::kTensor, t2, -1u);
    merger.setDim(t2, l0, Dim::kDense);
  }
};

} // namespace

/// Vector addition (disjunction) of 2 vectors, i.e.,
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
/// and after optimization, the lattice points do not change (as there are no
/// duplicate points and all input vectors are sparse):
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
#define IMPL_MERGER_TEST_DISJ(OP) \
  TEST_F(MergerTest3T1L, vector_##OP) { \
    auto e = OP##Expr(tensor(t0), tensor(t1)); \
    auto p0 = tensorPattern(t0); \
    auto p1 = tensorPattern(t1); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}}), true); \
    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}), \
                              true); \
    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}), \
                              true); \
  }
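
// For reference (illustrative only), IMPL_MERGER_TEST_DISJ(addf) expands
// roughly to:
//
//   TEST_F(MergerTest3T1L, vector_addf) {
//     auto e = addfExpr(tensor(t0), tensor(t1));
//     auto p0 = tensorPattern(t0);
//     auto p1 = tensorPattern(t1);
//     auto s = merger.buildLattices(e, l0);
//     expectNumLatPoints(s, 3);
//     ...
//   }
//
// and the invocation below instantiates one such test per disjunctive
// operation (addf, addc, addi, xori, ori).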

FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)

#undef IMPL_MERGER_TEST_DISJ

/// Vector multiplication (conjunction) of 2 vectors, i.e.,
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
#define IMPL_MERGER_TEST_CONJ(OP) \
  TEST_F(MergerTest3T1L, vector_##OP) { \
    auto e = OP##Expr(t0, t1); \
    auto p0 = tensorPattern(t0); \
    auto p1 = tensorPattern(t1); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}}), true); \
  }

FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)

#undef IMPL_MERGER_TEST_CONJ

/// Vector multiplication (conjunction) then addition (disjunction), i.e.,
///   a(i) = b(i) * c(i) + d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
///   lat( i_02 / tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
  TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
    auto em = CONJ##Expr(t0, t1); \
    auto e = DISJ##Expr(em, t2); \
    auto p0 = tensorPattern(t0); \
    auto p1 = tensorPattern(t1); \
    auto p2 = tensorPattern(t2); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
                              loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
                              loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
  }

FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)

#undef IMPL_MERGER_TEST_CONJ_DISJ

/// Vector addition (disjunction) then addition (disjunction), i.e.,
///   a(i) = b(i) + c(i) + d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
///   lat( i_02 i_01 / tensor_2 + tensor_1 )
///   lat( i_02 i_00 / tensor_2 + tensor_0 )
///   lat( i_01 i_00 / tensor_1 + tensor_0 )
///   lat( i_02 / tensor_2 )
///   lat( i_01 / tensor_1 )
///   lat( i_00 / tensor_0 )
/// }
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
  TEST_F(MergerTest4T1L, vector_##DISJ1##_##DISJ2) { \
    auto em = DISJ1##Expr(t0, t1); \
    auto e = DISJ2##Expr(em, t2); \
    auto p0 = tensorPattern(t0); \
    auto p1 = tensorPattern(t1); \
    auto p2 = tensorPattern(t2); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 7); \
    expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
                              loopsToBits({{l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
                              loopsToBits({{l0, t0}, {l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
                              loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
    expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 7); \
    expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
                              loopsToBits({{l0, t1}, {l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
                              loopsToBits({{l0, t0}, {l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
                              loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
    expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
    expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
  }

FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)

#undef IMPL_MERGER_TEST_DISJ_DISJ

/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.,
///   a(i) = b(i) * c(i) * d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
  TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
    auto em = CONJ1##Expr(t0, t1); \
    auto e = CONJ2##Expr(em, t2); \
    auto p0 = tensorPattern(t0); \
    auto p1 = tensorPattern(t1); \
    auto p2 = tensorPattern(t2); \
    auto s = merger.buildLattices(e, l0); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
  }

FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)

#undef IMPL_MERGER_TEST_CONJ_CONJ

/// Vector addition (disjunction) of 2 vectors, i.e.,
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
///   lat( i_00 / sparse_tensor_0 )
///   lat( i_01 / dense_tensor_1 )
/// }
/// which should be optimized to
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
///   lat( i_01 / dense_tensor_1 ) (no sparse dimension)
/// }
///
/// lat( i_00 / sparse_tensor_0 ) is optimized away, as it differs from
/// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) only in dense
/// dimensions.
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP) \
  TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
    auto e = OP##Expr(tensor(t0), tensor(t1)); \
    auto p0 = tensorPattern(t0); \
    auto p1 = tensorPattern(t1); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 2); \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}}), true); \
    expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true); \
  }

FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)

#undef IMPL_MERGER_TEST_OPTIMIZED_DISJ

/// Vector multiplication (conjunction) of 2 vectors, i.e.,
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// which should be optimized to
/// {
///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP) \
  TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
    auto e = OP##Expr(t0, t1); \
    auto p0 = tensorPattern(t0); \
    auto p1 = tensorPattern(t1); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 1); \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), \
                   true); \
  }

FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)

#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ

// TODO: multi-dim tests.

// Restore warning status.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif