Use populate_overwritten_args to build argument type info and overwritten-args flags
diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp
index 115fbd1..e1e20e9 100644
--- a/enzyme/Enzyme/Enzyme.cpp
+++ b/enzyme/Enzyme/Enzyme.cpp
@@ -1948,40 +1948,9 @@
     constants.push_back(DIFFE_TYPE::DUP_ARG);
 
     std::vector<bool> overwritten_args;
-    FnTypeInfo type_args(newFunc);
-    for (auto &a : type_args.Function->args()) {
-      overwritten_args.push_back(
-          !(diffeMode == DerivativeMode::ReverseModeCombined));
-      TypeTree dt;
-      if (a.getType()->isFPOrFPVectorTy()) {
-        dt = ConcreteType(a.getType()->getScalarType());
-      } else if (a.getType()->isPointerTy()) {
-#if LLVM_VERSION_MAJOR >= 15
-        if (a.getContext().supportsTypedPointers()) {
-#endif
-          auto et = a.getType()->getPointerElementType();
-          if (et->isFPOrFPVectorTy()) {
-            dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr);
-          } else if (et->isPointerTy()) {
-            dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr);
-          }
-#if LLVM_VERSION_MAJOR >= 15
-        }
-#endif
-        dt.insert({}, BaseType::Pointer);
-      } else if (a.getType()->isIntOrIntVectorTy()) {
-        dt = ConcreteType(BaseType::Integer);
-      }
-      type_args.Arguments.insert(
-          std::pair<Argument *, TypeTree>(&a, dt.Only(-1, nullptr)));
-      // TODO note that here we do NOT propagate constants in type info (and
-      // should consider whether we should)
-      type_args.KnownValues.insert(
-          std::pair<Argument *, std::set<int64_t>>(&a, {}));
-    }
-
     TypeAnalysis TA(Logic.PPC.FAM);
-    type_args = TA.analyzeFunction(type_args).getAnalyzedTypeInfo();
+    FnTypeInfo type_args =
+        populate_overwritten_args(TA, newFunc, diffeMode, overwritten_args);
 
     auto diffeFunc = Logic.CreatePrimalAndGradient(
         (ReverseCacheKey){.todiff = newFunc,