Trie
diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index b80888c..aac0500 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h
@@ -62,8 +62,17 @@ class TypeTree; +class Trie : public std::map<int, Trie> { +public: + int depth; + ConcreteType value; + Trie(ConcreteType CT) : depth(0), value(CT) {} + Trie() : Trie(BaseType::Unknown) {} +}; + typedef std::shared_ptr<const TypeTree> TypeResult; -typedef std::map<const std::vector<int>, ConcreteType> ConcreteTypeMapType; +//typedef std::map<const std::vector<int>, ConcreteType> ConcreteTypeMapType; +typedef Trie ConcreteTypeMapType; typedef std::map<const std::vector<int>, const TypeResult> TypeTreeMapType; /// Class representing the underlying types of values as @@ -76,10 +85,7 @@ public: TypeTree() {} - TypeTree(ConcreteType dat) { - if (dat != ConcreteType(BaseType::Unknown)) { - mapping.insert(std::pair<const std::vector<int>, ConcreteType>({}, dat)); - } + TypeTree(ConcreteType dat) : mapping(dat) { } /// Utility helper to lookup the mapping @@ -88,25 +94,36 @@ /// Lookup the underlying ConcreteType at a given offset sequence /// or Unknown if none exists ConcreteType operator[](const std::vector<int> Seq) const { - auto Found = mapping.find(Seq); - if (Found != mapping.end()) { - return Found->second; - } - for (const auto &pair : mapping) { - if (pair.first.size() != Seq.size()) - continue; - bool Match = true; - for (unsigned i = 0, size = pair.first.size(); i < size; ++i) { - if (pair.first[i] == -1) - continue; - if (pair.first[i] != Seq[i]) { - Match = false; - break; + if (Seq.empty()) return mapping.value; + + std::vector<const Trie*> todo[2]; + todo[0].push_back(&mapping); + int parity = 0; + for (size_t i = 0, Len = Seq.size(); i < Len - 1; ++i) { + for (auto prev : todo[parity]) { + auto found = prev->find(-1); + if (found != prev->end()) + todo[1 - parity].push_back(&found->second); + if (Seq[i] != -1) { + found = prev->find(Seq[i]); + if (found != prev->end()) + todo[1 - parity].push_back(&found->second); } } - if (!Match) - continue; - return pair.second; + todo[parity].clear(); + parity = 1 - parity; + } + + size_t i = 
Seq.size() - 1; + for (auto prev : todo[parity]) { + auto found = prev->find(-1); + if (found != prev->end() && found->second.value != BaseType::Unknown) + return found->second.value; + if (Seq[i] != -1) { + found = prev->find(Seq[i]); + if (found != prev->end() && found->second.value != BaseType::Unknown) + return found->second.value; + } } return BaseType::Unknown; } @@ -114,22 +131,31 @@ // Return true if this type tree is fully known (i.e. there // is no more information which could be added). bool IsFullyDetermined() const { - std::vector<int> offsets = {-1}; - while (1) { - auto found = mapping.find(offsets); - if (found == mapping.end()) - return false; - if (found->second != BaseType::Pointer) - return true; - offsets.push_back(-1); + const Trie* T = &mapping; + while (true) { + auto found = T->find(-1); + if (found == T->end()) return false; + T = &found->second; + if (T->value != BaseType::Unknown && T->value != BaseType::Pointer) return true; } } /// Return if changed bool insert(const std::vector<int> Seq, ConcreteType CT, bool intsAreLegalSubPointer = false) { - bool changed = false; - if (Seq.size() > 0) { + size_t SeqSize = Seq.size(); + if (SeqSize > EnzymeMaxTypeDepth) { + if (EnzymeTypeWarning) + llvm::errs() << "not handling more than " << EnzymeMaxTypeDepth + << " pointer lookups deep dt:" << str() + << " adding v: " << to_string(Seq) << ": " << CT.str() + << "\n"; + return false; + } + if (SeqSize == 0) { + mapping.value = CT; + return true; + } // check types at lower pointer offsets are either pointer or // anything. 
@@ -394,29 +420,9 @@ /// Prepend an offset to all mappings TypeTree Only(int Off) const { TypeTree Result; - Result.minIndices.reserve(1 + minIndices.size()); - Result.minIndices.push_back(Off); - for (auto midx : minIndices) - Result.minIndices.push_back(midx); - - if (Result.minIndices.size() > EnzymeMaxTypeDepth) { - Result.minIndices.pop_back(); - if (EnzymeTypeWarning) - llvm::errs() << "not handling more than " << EnzymeMaxTypeDepth - << " pointer lookups deep dt:" << str() << " only(" << Off - << "): " << str() << "\n"; - } - - for (const auto &pair : mapping) { - if (pair.first.size() == EnzymeMaxTypeDepth) - continue; - std::vector<int> Vec; - Vec.reserve(pair.first.size() + 1); - Vec.push_back(Off); - for (auto Val : pair.first) - Vec.push_back(Val); - Result.mapping.insert( - std::pair<const std::vector<int>, ConcreteType>(Vec, pair.second)); + if (mapping.depth > 0 || mapping.value != BaseType::Unknown) { + Result.mapping.emplace(Off, mapping); + Result.mapping.depth = mapping.depth + 1; } return Result; } @@ -425,6 +431,7 @@ TypeTree Data0() const { TypeTree Result; + for (const auto &pair : mapping) { if (pair.first.size() == 0) { llvm::errs() << str() << "\n"; @@ -456,9 +463,7 @@ /// Optimized version of Data0()[{}] ConcreteType Inner0() const { - ConcreteType CT = operator[]({-1}); - CT |= operator[]({0}); - return CT; + return operator[]({0}); } /// Remove any mappings in the range [start, end) or [len, inf)