//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering and legalization of vector instructions to
// VVP_* layer SDNodes.
//
//===----------------------------------------------------------------------===//

#include "VECustomDAG.h"
#include "VEISelLowering.h"

using namespace llvm;

#define DEBUG_TYPE "ve-lower"

// Split a two-operand mask operation into one v256i1 operation per part
// (Lo, Hi) and pack the two partial results into a v512i1 value.
SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
                                              SelectionDAG &DAG) const {
  VECustomDAG CDAG(DAG, Op);

  // The element count of the result type is used as the AVL of every unpack
  // and of the final pack.
  SDValue AVL =
      CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);

  // Extracts one v256i1 part of a mask operand.
  auto UnpackPart = [&](SDValue PackedMask, PackElem Part) {
    return CDAG.getUnpack(MVT::v256i1, PackedMask, Part, AVL);
  };

  SDValue LHS = Op->getOperand(0);
  SDValue RHS = Op->getOperand(1);
  SDValue LHSLo = UnpackPart(LHS, PackElem::Lo);
  SDValue LHSHi = UnpackPart(LHS, PackElem::Hi);
  SDValue RHSLo = UnpackPart(RHS, PackElem::Lo);
  SDValue RHSHi = UnpackPart(RHS, PackElem::Hi);

  // Apply the original opcode to each part separately.
  const unsigned Opcode = Op.getOpcode();
  SDValue ResLo = CDAG.getNode(Opcode, MVT::v256i1, {LHSLo, RHSLo});
  SDValue ResHi = CDAG.getNode(Opcode, MVT::v256i1, {LHSHi, RHSHi});

  return CDAG.getPack(MVT::v512i1, ResLo, ResHi, AVL);
}

// Lower \p Op to the corresponding VVP node. Returns an empty SDValue when
// the opcode of \p Op has no VVP opcode.
SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
  // Nothing to do unless there is a VVP opcode for this node.
  const unsigned SrcOpc = Op->getOpcode();
  auto MaybeVVPOpc = getVVPOpcode(SrcOpc);
  if (!MaybeVVPOpc)
    return SDValue();
  const unsigned VVPOpc = MaybeVVPOpc.value();

  VECustomDAG CDAG(DAG, Op);

  // Memory operations are handled by dedicated lowering functions.
  if (VVPOpc == VEISD::VVP_LOAD || VVPOpc == VEISD::VVP_STORE)
    return lowerVVP_LOAD_STORE(Op, CDAG);
  if (VVPOpc == VEISD::VVP_GATHER || VVPOpc == VEISD::VVP_SCATTER)
    return lowerVVP_GATHER_SCATTER(Op, CDAG);

  // The representative vector type of this operation and the type it is
  // transformed to.
  EVT IdiomVT = *getIdiomaticVectorType(Op.getNode());
  EVT LegalVT = getTypeToTransformTo(*DAG.getContext(), IdiomVT);
  auto VecPacking = getTypePacking(LegalVT.getSimpleVT());

  // Pick up the mask and the vector length from VP nodes, for whichever of
  // the two the opcode reports an operand index.
  SDValue Mask;
  SDValue AVL;
  if (ISD::isVPOpcode(SrcOpc)) {
    auto MaskPos = ISD::getVPMaskIdx(SrcOpc);
    auto EVLPos = ISD::getVPExplicitVectorLengthIdx(SrcOpc);
    if (MaskPos)
      Mask = Op->getOperand(*MaskPos);
    if (EVLPos)
      AVL = Op->getOperand(*EVLPos);
  }

  // Otherwise default to the full element count and a constant 'true' mask.
  if (!AVL)
    AVL = CDAG.getConstant(IdiomVT.getVectorNumElements(), MVT::i32);
  if (!Mask)
    Mask = CDAG.getConstantMask(VecPacking, true);

  assert(LegalVT.isSimple());

  // Generic unary and binary operations: value operands, then mask and AVL.
  if (isVVPUnaryOp(VVPOpc))
    return CDAG.getNode(VVPOpc, LegalVT, {Op->getOperand(0), Mask, AVL});

  if (isVVPBinaryOp(VVPOpc))
    return CDAG.getNode(VVPOpc, LegalVT,
                        {Op->getOperand(0), Op->getOperand(1), Mask, AVL});

  // Reductions: when the source opcode has a start parameter it is operand 0
  // and the vector follows at operand 1; otherwise the vector is operand 0.
  if (isVVPReductionOp(VVPOpc)) {
    const bool HasStart = hasReductionStartParam(Op->getOpcode());
    SDValue StartV;
    if (HasStart)
      StartV = Op->getOperand(0);
    SDValue VectorV = Op->getOperand(HasStart ? 1 : 0);
    return CDAG.getLegalReductionOpVVP(VVPOpc, Op.getValueType(), StartV,
                                       VectorV, Mask, AVL, Op->getFlags());
  }

  if (VVPOpc == VEISD::VVP_FFMA) {
    // VE has a swizzled operand order in FMA (compared to LLVM IR and
    // SDNodes): source operand 2 comes first, followed by operands 0 and 1.
    return CDAG.getNode(VVPOpc, LegalVT,
                        {Op->getOperand(2), Op->getOperand(0),
                         Op->getOperand(1), Mask, AVL});
  }

  if (VVPOpc == VEISD::VVP_SELECT) {
    // Operand 0 of the source node is used as the mask of the VVP node (the
    // mask computed above is not used here).
    SDValue SelectMask = Op->getOperand(0);
    SDValue TrueV = Op->getOperand(1);
    SDValue FalseV = Op->getOperand(2);
    return CDAG.getNode(VVPOpc, LegalVT, {TrueV, FalseV, SelectMask, AVL});
  }

  if (VVPOpc == VEISD::VVP_SETCC) {
    // The result type is derived from the node's own value type rather than
    // from the idiomatic vector type.
    EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
    SDValue CmpLHS = Op->getOperand(0);
    SDValue CmpRHS = Op->getOperand(1);
    SDValue CondCode = Op->getOperand(2);
    return CDAG.getNode(VVPOpc, LegalResVT,
                        {CmpLHS, CmpRHS, CondCode, Mask, AVL});
  }

  llvm_unreachable("lowerToVVP called for unexpected SDNode.");
}

// Lower a load or store to VVP_LOAD / VVP_STORE. A load with a passthru
// value that is not undef additionally gets an explicit VVP_SELECT between
// the loaded data and the passthru.
SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
                                              VECustomDAG &CDAG) const {
  const auto VVPOpc = *getVVPOpcode(Op->getOpcode());

  // Operands shared by loads and stores.
  SDValue Ptr = getMemoryPtr(Op);
  SDValue Mask = getNodeMask(Op);
  SDValue Chain = getNodeChain(Op);
  SDValue AVL = getNodeAVL(Op);
  // Only meaningful for stores.
  SDValue StoredV = getStoredValue(Op);
  // Only meaningful for loads.
  SDValue PassThru = getNodePassthru(Op);

  SDValue Stride = getLoadStoreStride(Op, CDAG);

  auto DataVT = *getIdiomaticVectorType(Op.getNode());
  auto DataPacking = getTypePacking(DataVT);

  // TODO: Infer lower AVL from mask.
  // Without an AVL, use the element count of the data type.
  if (!AVL)
    AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);

  // Without a mask, use the all-true mask.
  if (!Mask)
    Mask = CDAG.getConstantMask(DataPacking, true);

  // Store: the VVP node keeps the value type list of the original node.
  if (VVPOpc != VEISD::VVP_LOAD) {
    assert(VVPOpc == VEISD::VVP_STORE);
    return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
                        {Chain, StoredV, Ptr, Stride, Mask, AVL});
  }

  // Load: produces the legal data type and a chain.
  MVT LegalDataVT = getLegalVectorType(
      DataPacking, DataVT.getVectorElementType().getSimpleVT());
  SDValue LoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
                               {Chain, Ptr, Stride, Mask, AVL});

  // The plain load is all that is needed without a (defined) passthru.
  const bool NeedsSelect = PassThru && !PassThru->isUndef();
  if (!NeedsSelect)
    return LoadV;

  // Convert the passthru to an explicit select node.
  // NOTE(review): the select is typed with DataVT here, whereas
  // lowerVVP_GATHER_SCATTER uses the legal data type for its select - confirm
  // that this difference is intended.
  SDValue SelectedV =
      CDAG.getNode(VEISD::VVP_SELECT, DataVT, {LoadV, PassThru, Mask, AVL});
  SDValue LoadChain = SDValue(LoadV.getNode(), 1);

  // Return the selected data together with the chain of the load.
  return CDAG.getMergeValues({SelectedV, LoadChain});
}

// Split a packed VVP_LOAD / VVP_STORE into one operation per part (Hi, Lo).
// For a store the fused chain of both parts is returned; for a load the
// re-packed data is returned together with the fused chain.
SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
                                               VECustomDAG &CDAG) const {
  const auto VVPOpc = *getVVPOpcode(Op.getOpcode());
  assert((VVPOpc == VEISD::VVP_LOAD) || (VVPOpc == VEISD::VVP_STORE));

  MVT PackedDataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  assert(getTypePacking(PackedDataVT) == Packing::Dense &&
         "Can only split packed load/store");
  MVT PartDataVT = splitVectorType(PackedDataVT);

  assert(!getNodePassthru(Op) &&
         "Should have been folded in lowering to VVP layer");

  // Collect the operands of the packed operation.
  SDValue Mask = getNodeMask(Op);
  SDValue AVL = getAnnotatedNodeAVL(Op).first;
  SDValue Ptr = getMemoryPtr(Op);
  SDValue StoredV = getStoredValue(Op);
  SDValue Stride = getLoadStoreStride(Op, CDAG);

  // The presence of a stored value distinguishes the store from the load.
  const bool IsStore = static_cast<bool>(StoredV);

  // The chain is result 0 of a part store and result 1 of a part load.
  const unsigned ChainResNo = IsStore ? 0 : 1;

  // The part operations, indexed by PackElem.
  SDValue Parts[2];

  // AVL of the Hi part; used for packing things back together.
  SDValue HiPartAVL;
  for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
    const int PartIdx = static_cast<int>(Part);

    // Mask and AVL for this part.
    auto PartTM = CDAG.getTargetSplitMask(Mask, AVL, Part);
    if (Part == PackElem::Hi)
      HiPartAVL = PartTM.AVL;

    // Operand list of the part operation: chain first ...
    SmallVector<SDValue, 4> PartOperands;
    PartOperands.push_back(getNodeChain(Op));

    // ... then this part of the stored data (stores only) ...
    if (IsStore)
      PartOperands.push_back(
          CDAG.getUnpack(PartDataVT, StoredV, Part, PartTM.AVL));

    // ... then pointer and stride adjusted for the split ...
    PartOperands.push_back(CDAG.getSplitPtrOffset(Ptr, Stride, Part));
    PartOperands.push_back(CDAG.getSplitPtrStride(Stride));

    // ... and finally the predicating operands.
    PartOperands.push_back(PartTM.Mask);
    PartOperands.push_back(PartTM.AVL);

    if (!IsStore) {
      // A part load yields data and a chain.
      Parts[PartIdx] =
          CDAG.getNode(VVPOpc, {PartDataVT, MVT::Other}, PartOperands);
    } else {
      // A part store only yields a chain.
      Parts[PartIdx] = CDAG.getNode(VVPOpc, MVT::Other, PartOperands);
    }
  }

  SDValue LoPart = Parts[static_cast<int>(PackElem::Lo)];
  SDValue HiPart = Parts[static_cast<int>(PackElem::Hi)];

  // Join the chains of both parts.
  SDValue LoChain = SDValue(LoPart.getNode(), ChainResNo);
  SDValue HiChain = SDValue(HiPart.getNode(), ChainResNo);
  SDValue JoinedChain =
      CDAG.getNode(ISD::TokenFactor, MVT::Other, {LoChain, HiChain});

  // A store has no result besides the chain.
  if (IsStore)
    return JoinedChain;

  // Re-pack the part loads into the full packed vector result.
  MVT RepackedVT =
      getLegalVectorType(Packing::Dense, PackedDataVT.getVectorElementType());
  SDValue RepackedV = CDAG.getPack(RepackedVT, LoPart, HiPart, HiPartAVL);

  return CDAG.getMergeValues({RepackedV, JoinedChain});
}

// Lower a gather or scatter to VVP_GATHER / VVP_SCATTER. A gather with a
// passthru value that is not undef additionally gets a VVP_SELECT between
// the gathered data and the passthru.
SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
                                                  VECustomDAG &CDAG) const {
  EVT DataVT = *getIdiomaticVectorType(Op.getNode());
  auto DataPacking = getTypePacking(DataVT);
  MVT LegalDataVT = getLegalVectorType(
      DataPacking, DataVT.getVectorElementType().getSimpleVT());

  // Collect the operands of the original node.
  SDValue AVL = getAnnotatedNodeAVL(Op).first;
  SDValue Index = getGatherScatterIndex(Op);
  SDValue Base = getMemoryPtr(Op);
  SDValue Mask = getNodeMask(Op);
  SDValue Chain = getNodeChain(Op);
  SDValue Scale = getGatherScatterScale(Op);
  SDValue PassThru = getNodePassthru(Op);
  SDValue ScatteredV = getStoredValue(Op);

  // An undef passthru is treated like no passthru at all.
  if (PassThru && PassThru->isUndef())
    PassThru = SDValue();

  // TODO: Infer lower AVL from mask.
  // Without an AVL, use the element count of the data type.
  if (!AVL)
    AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);

  // Without a mask, use the all-true mask.
  if (!Mask)
    Mask = CDAG.getConstantMask(DataPacking, true);

  SDValue Addresses =
      CDAG.getGatherScatterAddress(Base, Scale, Index, Mask, AVL);

  // Scatter: the presence of a stored value identifies it.
  if (static_cast<bool>(ScatteredV))
    return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
                        {Chain, ScatteredV, Addresses, Mask, AVL});

  // Gather: produces the legal data type and a chain.
  SDValue GatherV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
                                 {Chain, Addresses, Mask, AVL});
  if (!PassThru)
    return GatherV;

  // Convert the passthru to an explicit select node and return it together
  // with the chain of the gather.
  SDValue SelectedV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
                                   {GatherV, PassThru, Mask, AVL});
  SDValue GatherChain = SDValue(GatherV.getNode(), 1);
  return CDAG.getMergeValues({SelectedV, GatherChain});
}

// Legalize a VVP load/store: operations on a packed vector type are split,
// all others only get their AVL legalized.
SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
                                                      VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n");
  MVT IdiomDataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();

  // TODO: Recognize packable load,store.
  return isPackedVectorType(IdiomDataVT) ? splitPackedLoadStore(Op, CDAG)
                                         : legalizePackedAVL(Op, CDAG);
}

// Legalize an internal vector operation: loads and stores go to their own
// legalization function, packed operations without packed-mode support are
// split, and everything else only gets its AVL legalized.
SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
                                                   SelectionDAG &DAG) const {
  LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n");
  VECustomDAG CDAG(DAG, Op);

  // Dispatch to specialized legalization functions.
  const unsigned Opc = Op->getOpcode();
  if (Opc == VEISD::VVP_LOAD || Opc == VEISD::VVP_STORE)
    return legalizeInternalLoadStoreOp(Op, CDAG);

  // Split packed operations for which packed mode is not supported.
  EVT IdiomVT = Op.getValueType();
  const bool MustSplit =
      isPackedVectorType(IdiomVT) && !supportsPackedMode(Opc, IdiomVT);
  if (MustSplit)
    return splitVectorOp(Op, CDAG);

  // TODO: Implement odd/even splitting.
  return legalizePackedAVL(Op, CDAG);
}

// Split a packed vector operation into one operation per part (Hi, Lo) on
// the split vector type and pack the two partial results back together.
SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
  MVT PartResVT = splitVectorType(Op.getValue(0).getSimpleValueType());

  // Positions of the predicating operands (if any) in the operand list.
  auto AVLOperandPos = getAVLPos(Op->getOpcode());
  auto MaskOperandPos = getMaskPos(Op->getOpcode());

  SDValue Mask = getNodeMask(Op);
  auto AnnotatedAVL = getAnnotatedNodeAVL(Op);
  SDValue AVL = AnnotatedAVL.first;
  assert(!AnnotatedAVL.second && "Expecting non pack-legalized oepration");

  // The part operations, indexed by PackElem.
  SDValue Parts[2];

  // AVL of the Hi part; used for packing things back together.
  SDValue HiPartAVL;
  for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
    // Mask and AVL for this part.
    auto PartTM = CDAG.getTargetSplitMask(Mask, AVL, Part);
    if (Part == PackElem::Hi)
      HiPartAVL = PartTM.AVL;

    // Unpack this part of every value operand; the mask and AVL operands of
    // the packed node are skipped and replaced below.
    SmallVector<SDValue, 4> PartOperands;
    for (unsigned Idx = 0; Idx < Op.getNumOperands(); ++Idx) {
      const int Pos = static_cast<int>(Idx);
      const bool IsPredicatingOperand =
          (AVLOperandPos && Pos == *AVLOperandPos) ||
          (MaskOperandPos && Pos == *MaskOperandPos);
      if (IsPredicatingOperand)
        continue;

      SDValue Packed = Op.getOperand(Idx);
      auto PartOperandVT = splitVectorType(Packed.getSimpleValueType());
      PartOperands.push_back(
          CDAG.getUnpack(PartOperandVT, Packed, Part, PartTM.AVL));
    }

    // Append the predicating operands of this part and emit the part node
    // with the flags of the packed node.
    PartOperands.push_back(PartTM.Mask);
    PartOperands.push_back(PartTM.AVL);
    Parts[static_cast<int>(Part)] =
        CDAG.getNode(Op.getOpcode(), PartResVT, PartOperands, Op->getFlags());
  }

  // Re-package vectors.
  return CDAG.getPack(Op.getValueType(), Parts[static_cast<int>(PackElem::Lo)],
                      Parts[static_cast<int>(PackElem::Hi)], HiPartAVL);
}

// Give a VVP or VEC operation a legal AVL. Operations that are neither, or
// whose AVL is already legal, are returned unchanged. For packed vector types
// the AVL is halved (rounding up) first; the node is then cloned with the
// annotated legal AVL in place of the old AVL operand.
SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
                                            VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n");

  // Only required for VEC and VVP ops.
  if (!isVVPOrVEC(Op->getOpcode()))
    return Op;

  // Nothing to do when the AVL is legal already.
  SDValue OldAVL = getNodeAVL(Op);
  if (isLegalAVL(OldAVL))
    return Op;

  // Half and round up EVL for 32bit element types.
  SDValue NewAVL = OldAVL;
  MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  if (isPackedVectorType(IdiomVT)) {
    assert(maySafelyIgnoreMask(Op) &&
           "TODO Shift predication from EVL into Mask");

    auto *ConstAVL = dyn_cast<ConstantSDNode>(OldAVL);
    if (!ConstAVL) {
      // Variable AVL: emit (AVL + 1) >> 1.
      auto One = CDAG.getConstant(1, MVT::i32);
      auto Incremented = CDAG.getNode(ISD::ADD, MVT::i32, {OldAVL, One});
      NewAVL = CDAG.getNode(ISD::SRL, MVT::i32, {Incremented, One});
    } else {
      // Constant AVL: compute (AVL + 1) / 2 right away.
      NewAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
    }
  }

  SDValue AnnotatedAVL = CDAG.annotateLegalAVL(NewAVL);

  // Copy the operand list, substituting the annotated AVL at the AVL
  // position.
  const int NumOperands = Op->getNumOperands();
  auto AVLOperandPos = getAVLPos(Op->getOpcode());
  std::vector<SDValue> NewOperands;
  NewOperands.reserve(NumOperands);
  for (int Idx = 0; Idx < NumOperands; ++Idx) {
    if (AVLOperandPos && (Idx == *AVLOperandPos))
      NewOperands.push_back(AnnotatedAVL);
    else
      NewOperands.push_back(Op->getOperand(Idx));
  }

  // Clone the operation with the fixed operands, keeping its flags.
  return CDAG.getNode(Op->getOpcode(), Op->getVTList(), NewOperands,
                      Op->getFlags());
}
