//=======- PtrTypesSemantics.cpp ---------------------------------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PtrTypesSemantics.h"
#include "ASTUtils.h"
#include "clang/AST/Attr.h"
#include "clang/AST/CXXInheritance.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/Analysis/DomainSpecific/CocoaConventions.h"
#include <optional>

using namespace clang;

namespace {

// Returns true when R itself declares a public method named NameToMatch.
// Only the methods declared directly on R are scanned; walking the bases is
// the caller's job. R must be non-null and must have a definition.
bool hasPublicMethodInBaseClass(const CXXRecordDecl *R, StringRef NameToMatch) {
  assert(R);
  assert(R->hasDefinition());

  for (const CXXMethodDecl *Method : R->methods()) {
    const bool NameMatches = safeGetName(Method) == NameToMatch;
    if (!NameMatches)
      continue;
    if (Method->getAccess() == AS_public)
      return true;
  }
  return false;
}

} // namespace

namespace clang {

std::optional<const clang::CXXRecordDecl *>
hasPublicMethodInBase(const CXXBaseSpecifier *Base, StringRef NameToMatch) {
  assert(Base);

  const Type *T = Base->getType().getTypePtrOrNull();
  if (!T)
    return std::nullopt;

  const CXXRecordDecl *R = T->getAsCXXRecordDecl();
  if (!R) {
    auto CT = Base->getType().getCanonicalType();
    if (auto *TST = dyn_cast<TemplateSpecializationType>(CT)) {
      auto TmplName = TST->getTemplateName();
      if (!TmplName.isNull()) {
        if (auto *TD = TmplName.getAsTemplateDecl())
          R = dyn_cast_or_null<CXXRecordDecl>(TD->getTemplatedDecl());
      }
    }
    if (!R)
      return std::nullopt;
  }
  if (!R->hasDefinition())
    return std::nullopt;

  return hasPublicMethodInBaseClass(R, NameToMatch) ? R : nullptr;
}

// Determines whether R, or its bases taken together, expose both a public
// IncMethodName and a public DecMethodName method.
//
// Returns std::nullopt when R has no definition or when a base visited
// during the lookup could not be resolved; otherwise returns whether both
// methods were found.
std::optional<bool> isSmartPtrCompatible(const CXXRecordDecl *R,
                                         StringRef IncMethodName,
                                         StringRef DecMethodName) {
  assert(R);

  R = R->getDefinition();
  if (!R)
    return std::nullopt;

  // First look at the methods declared directly on the class.
  bool FoundInc = hasPublicMethodInBaseClass(R, IncMethodName);
  bool FoundDec = hasPublicMethodInBaseClass(R, DecMethodName);
  if (FoundInc && FoundDec)
    return true;

  CXXBasePaths Paths;
  Paths.setOrigin(const_cast<CXXRecordDecl *>(R));

  bool SawInconclusiveBase = false;
  // Produces the base-matching callback for one method name. A base that
  // cannot be resolved sets SawInconclusiveBase and is reported as a
  // non-match.
  const auto MakeBaseMatcher = [&SawInconclusiveBase](StringRef MethodName) {
    return [&SawInconclusiveBase, MethodName](const CXXBaseSpecifier *Base,
                                              CXXBasePath &) {
      auto Owner = clang::hasPublicMethodInBase(Base, MethodName);
      if (!Owner) {
        SawInconclusiveBase = true;
        return false;
      }
      return (*Owner) != nullptr;
    };
  };

  // The base lookup is skipped for a method already found on R itself.
  if (!FoundInc)
    FoundInc = R->lookupInBases(MakeBaseMatcher(IncMethodName), Paths,
                                /*LookupInDependent =*/true);
  if (SawInconclusiveBase)
    return std::nullopt;

  Paths.clear();
  if (!FoundDec)
    FoundDec = R->lookupInBases(MakeBaseMatcher(DecMethodName), Paths,
                                /*LookupInDependent =*/true);
  if (SawInconclusiveBase)
    return std::nullopt;

  return FoundInc && FoundDec;
}

// Reports whether R exposes public "ref" and "deref" methods, directly or via
// its bases. The result is whatever isSmartPtrCompatible returns for that
// method pair, including std::nullopt.
std::optional<bool> isRefCountable(const clang::CXXRecordDecl *R) {
  static constexpr char IncrementName[] = "ref";
  static constexpr char DecrementName[] = "deref";
  return isSmartPtrCompatible(R, IncrementName, DecrementName);
}

// Reports whether R exposes public "incrementCheckedPtrCount" and
// "decrementCheckedPtrCount" methods, directly or via its bases. The result
// is whatever isSmartPtrCompatible returns for that method pair, including
// std::nullopt.
std::optional<bool> isCheckedPtrCapable(const clang::CXXRecordDecl *R) {
  static constexpr char IncrementName[] = "incrementCheckedPtrCount";
  static constexpr char DecrementName[] = "decrementCheckedPtrCount";
  return isSmartPtrCompatible(R, IncrementName, DecrementName);
}

// Returns true when Name is exactly one of the Ref / RefPtr class names
// listed below (case-sensitive).
bool isRefType(const std::string &Name) {
  static constexpr const char *RefTypeNames[] = {
      "Ref", "RefAllowingPartiallyDestroyed", "RefPtr",
      "RefPtrAllowingPartiallyDestroyed"};
  for (const char *Candidate : RefTypeNames) {
    if (Name == Candidate)
      return true;
  }
  return false;
}

// Returns true when Name is exactly one of the RetainPtr / OSObjectPtr class
// names listed below (case-sensitive).
bool isRetainPtrOrOSPtr(const std::string &Name) {
  static constexpr const char *RetainingNames[] = {
      "RetainPtr", "RetainPtrArc", "OSObjectPtr", "OSObjectPtrArc"};
  for (const char *Candidate : RetainingNames) {
    if (Name == Candidate)
      return true;
  }
  return false;
}

// Returns true when Name is exactly "CheckedPtr" or "CheckedRef".
bool isCheckedPtr(const std::string &Name) {
  if (Name == "CheckedPtr")
    return true;
  return Name == "CheckedRef";
}

// Returns true when Name is a Ref-type name, a CheckedPtr-type name, or one
// of the unique-ownership class names listed below.
bool isOwnerPtr(const std::string &Name) {
  if (isRefType(Name))
    return true;
  if (isCheckedPtr(Name))
    return true;

  static constexpr const char *UniqueOwnerNames[] = {"unique_ptr", "UniqueRef",
                                                     "LazyUniqueRef"};
  for (const char *Candidate : UniqueOwnerNames) {
    if (Name == Candidate)
      return true;
  }
  return false;
}

// Returns true when Name is a Ref-type, CheckedPtr-type or RetainPtr/OSObjectPtr
// name, or one of the weak-pointer related class names listed below.
bool isSmartPtrClass(const std::string &Name) {
  if (isRefType(Name) || isCheckedPtr(Name) || isRetainPtrOrOSPtr(Name))
    return true;

  static constexpr const char *WeakPtrRelatedNames[] = {
      "WeakPtr",
      "WeakPtrFactory",
      "WeakPtrFactoryWithBitField",
      "WeakPtrImplBase",
      "WeakPtrImplBaseSingleThread",
      "ThreadSafeWeakPtr",
      "ThreadSafeWeakOrStrongPtr",
      "ThreadSafeWeakPtrControlBlock",
      "ThreadSafeRefCountedAndCanMakeThreadSafeWeakPtr"};
  for (const char *Candidate : WeakPtrRelatedNames) {
    if (Name == Candidate)
      return true;
  }
  return false;
}

// Returns true when the name of F is a Ref-type name or one of the
// constructor / factory function names listed below. F must be non-null.
bool isCtorOfRefCounted(const clang::FunctionDecl *F) {
  assert(F);
  const std::string &Name = safeGetName(F);

  if (isRefType(Name))
    return true;

  static constexpr const char *KnownNames[] = {
      "adoptRef", "UniqueRef", "makeUniqueRef",
      "makeUniqueRefWithoutFastMallocCheck",

      "String", "AtomString", "UniqueString",
      // FIXME: Implement as attribute.
      "Identifier"};
  for (const char *Known : KnownNames) {
    if (Name == Known)
      return true;
  }
  return false;
}

// Returns true when the name of F is "CheckedPtr" or "CheckedRef".
// F must be non-null.
bool isCtorOfCheckedPtr(const clang::FunctionDecl *F) {
  assert(F);
  const auto Name = safeGetName(F);
  return isCheckedPtr(Name);
}

// Returns true when the name of F is one of the RetainPtr / OSObjectPtr
// constructor or adoption function names listed below.
bool isCtorOfRetainPtrOrOSPtr(const clang::FunctionDecl *F) {
  const std::string &Name = safeGetName(F);

  static constexpr const char *KnownNames[] = {
      "RetainPtr",    "adoptNS",    "adoptCF",       "retainPtr",
      "RetainPtrArc", "adoptNSArc", "adoptOSObject", "adoptOSObjectArc"};
  for (const char *Known : KnownNames) {
    if (Name == Known)
      return true;
  }
  return false;
}

// Returns true when F is recognized by any of the three constructor
// predicates: ref-counted, checked pointer, or RetainPtr/OSObjectPtr.
bool isCtorOfSafePtr(const clang::FunctionDecl *F) {
  if (isCtorOfRefCounted(F))
    return true;
  if (isCtorOfCheckedPtr(F))
    return true;
  return isCtorOfRetainPtrOrOSPtr(F);
}

// Returns true when F is named "move" and its parent context is named "std"
// or "WTF" and sits directly inside the translation unit.
bool isStdOrWTFMove(const clang::FunctionDecl *F) {
  auto FunctionName = safeGetName(F);

  auto *EnclosingContext = F->getParent();
  if (!EnclosingContext)
    return false;

  // The enclosing context must itself be a direct child of the translation
  // unit; nested contexts are rejected.
  auto *OuterContext = EnclosingContext->getParent();
  if (!isa_and_nonnull<TranslationUnitDecl>(OuterContext))
    return false;

  auto EnclosingName = safeGetName(EnclosingContext);
  const bool InStdOrWTF = EnclosingName == "WTF" || EnclosingName == "std";
  return InStdOrWTF && FunctionName == "move";
}

// Applies Pred to the template name behind T when T is a template
// specialization type or a deduced template specialization type.
// Returns false for a null type, for any other kind of type, or when the
// template declaration cannot be obtained.
template <typename Predicate>
static bool isPtrOfType(const clang::QualType T, Predicate Pred) {
  if (T.isNull())
    return false;

  if (auto *Specialization = T->getAs<TemplateSpecializationType>()) {
    auto *Template = Specialization->getTemplateName().getAsTemplateDecl();
    if (!Template)
      return false;
    return Pred(Template->getNameAsString());
  }

  if (auto *Deduced = T->getAs<DeducedTemplateSpecializationType>()) {
    auto *Template = Deduced->getTemplateName().getAsTemplateDecl();
    if (!Template)
      return false;
    return Pred(Template->getNameAsString());
  }

  return false;
}

// Returns true when T is a (deduced) template specialization whose template
// name is a Ref-type or CheckedPtr-type name.
bool isRefOrCheckedPtrType(const clang::QualType T) {
  const auto IsRefOrCheckedName = [](const auto &Name) {
    if (isRefType(Name))
      return true;
    return isCheckedPtr(Name);
  };
  return isPtrOfType(T, IsRefOrCheckedName);
}

// Returns true when T is a (deduced) template specialization whose template
// name is a RetainPtr / OSObjectPtr name.
bool isRetainPtrOrOSPtrType(const clang::QualType T) {
  const auto IsRetainingName = [](const auto &Name) {
    return isRetainPtrOrOSPtr(Name);
  };
  return isPtrOfType(T, IsRetainingName);
}

// Returns true when T is a (deduced) template specialization whose template
// name is accepted by isOwnerPtr.
bool isOwnerPtrType(const clang::QualType T) {
  const auto IsOwnerName = [](const auto &Name) { return isOwnerPtr(Name); };
  return isPtrOfType(T, IsOwnerName);
}

// Type-level variant of isUncounted. A type substituted for a template
// parameter whose associated declaration has a Ref-type name is reported as
// not uncounted; every other type is judged by its C++ record declaration.
std::optional<bool> isUncounted(const QualType T) {
  auto *Substituted = dyn_cast<SubstTemplateTypeParmType>(T);
  auto *Associated = Substituted ? Substituted->getAssociatedDecl() : nullptr;
  if (Associated && isRefType(safeGetName(Associated)))
    return false;

  return isUncounted(T->getAsCXXRecordDecl());
}

// Type-level variant of isUnchecked. A type substituted for a template
// parameter whose associated declaration has a CheckedPtr-type name is
// reported as not unchecked; every other type is judged by its C++ record
// declaration.
std::optional<bool> isUnchecked(const QualType T) {
  auto *Substituted = dyn_cast<SubstTemplateTypeParmType>(T);
  auto *Associated = Substituted ? Substituted->getAssociatedDecl() : nullptr;
  if (Associated && isCheckedPtr(safeGetName(Associated)))
    return false;

  return isUnchecked(T->getAsCXXRecordDecl());
}

// Captures the ObjCAutoRefCount and ObjCDefaultSynthProperties language
// options of the translation unit for later queries.
void RetainTypeChecker::visitTranslationUnitDecl(
    const TranslationUnitDecl *TUD) {
  const auto &Options = TUD->getLangOpts();
  IsARCEnabled = Options.ObjCAutoRefCount;
  DefaultSynthProperties = Options.ObjCDefaultSynthProperties;
}

void RetainTypeChecker::visitTypedef(const TypedefDecl *TD) {
  auto QT = TD->getUnderlyingType();
  if (!QT->isPointerType())
    return;

  auto PointeeQT = QT->getPointeeType();
  const RecordType *RT = PointeeQT->getAsCanonical<RecordType>();
  if (!RT) {
    if (TD->hasAttr<ObjCBridgeAttr>() || TD->hasAttr<ObjCBridgeMutableAttr>()) {
      RecordlessTypes.insert(TD->getASTContext()
                                 .getTypedefType(ElaboratedTypeKeyword::None,
                                                 /*Qualifier=*/std::nullopt, TD)
                                 .getTypePtr());
    }
    return;
  }

  for (auto *Redecl : RT->getDecl()->getMostRecentDecl()->redecls()) {
    if (Redecl->getAttr<ObjCBridgeAttr>() ||
        Redecl->getAttr<ObjCBridgeMutableAttr>()) {
      CFPointees.insert(RT);
      return;
    }
  }
}

// Returns true when QT is a Cocoa object reference and either ARC is off or
// ignoreARC is set. Otherwise the answer comes from the sets filled by
// visitTypedef: CFPointees when the canonical pointee is a record type,
// RecordlessTypes when it is not.
bool RetainTypeChecker::isUnretained(const QualType QT, bool ignoreARC) {
  if (ento::cocoa::isCocoaObjectRef(QT)) {
    if (!IsARCEnabled || ignoreARC)
      return true;
  }

  auto CanonicalPointee = QT.getCanonicalType()->getPointeeType();
  auto *PointeeRecord =
      dyn_cast_or_null<RecordType>(CanonicalPointee.getTypePtrOrNull());
  if (PointeeRecord)
    return CFPointees.contains(PointeeRecord);

  return RecordlessTypes.contains(QT.getTypePtr());
}

std::optional<bool> isUncounted(const CXXRecordDecl* Class)
{
  // Keep isRefCounted first as it's cheaper.
  if (!Class || isRefCounted(Class))
    return false;

  std::optional<bool> IsRefCountable = isRefCountable(Class);
  if (!IsRefCountable)
    return std::nullopt;

  return (*IsRefCountable);
}

// Returns false for a null class or a class that is itself a CheckedPtr-type
// instantiation; otherwise forwards the result of isCheckedPtrCapable,
// including std::nullopt when that is inconclusive.
std::optional<bool> isUnchecked(const CXXRecordDecl *Class) {
  if (!Class)
    return false;

  // Cheaper than isCheckedPtrCapable, so test it first.
  if (isCheckedPtr(Class))
    return false;

  return isCheckedPtrCapable(Class);
}

// For a pointer or reference type whose pointee is a C++ record, returns
// isUncounted of that record. Returns false for every other type.
std::optional<bool> isUncountedPtr(const QualType T) {
  const bool IsPtrOrRef = T->isPointerType() || T->isReferenceType();
  if (!IsPtrOrRef)
    return false;

  auto *PointeeRecord = T->getPointeeCXXRecordDecl();
  if (!PointeeRecord)
    return false;

  return isUncounted(PointeeRecord);
}

// For a pointer or reference type whose pointee is a C++ record, returns
// isUnchecked of that record. Returns false for every other type.
std::optional<bool> isUncheckedPtr(const QualType T) {
  const bool IsPtrOrRef = T->isPointerType() || T->isReferenceType();
  if (!IsPtrOrRef)
    return false;

  auto *PointeeRecord = T->getPointeeCXXRecordDecl();
  if (!PointeeRecord)
    return false;

  return isUnchecked(PointeeRecord);
}

std::optional<bool> isGetterOfSafePtr(const CXXMethodDecl *M) {
  assert(M);

  if (isa<CXXMethodDecl>(M)) {
    const CXXRecordDecl *calleeMethodsClass = M->getParent();
    auto className = safeGetName(calleeMethodsClass);
    auto method = safeGetName(M);

    if (isCheckedPtr(className) && (method == "get" || method == "ptr"))
      return true;

    if ((isRefType(className) && (method == "get" || method == "ptr")) ||
        ((className == "String" || className == "AtomString" ||
          className == "AtomStringImpl" || className == "UniqueString" ||
          className == "UniqueStringImpl" || className == "Identifier") &&
         method == "impl"))
      return true;

    if (isRetainPtrOrOSPtr(className) && method == "get")
      return true;

    // Ref<T> -> T conversion
    // FIXME: Currently allowing any Ref<T> -> whatever cast.
    if (isRefType(className)) {
      if (auto *maybeRefToRawOperator = dyn_cast<CXXConversionDecl>(M)) {
        auto QT = maybeRefToRawOperator->getConversionType();
        auto *T = QT.getTypePtrOrNull();
        return T && (T->isPointerType() || T->isReferenceType());
      }
    }

    if (isCheckedPtr(className)) {
      if (auto *maybeRefToRawOperator = dyn_cast<CXXConversionDecl>(M)) {
        auto QT = maybeRefToRawOperator->getConversionType();
        auto *T = QT.getTypePtrOrNull();
        return T && (T->isPointerType() || T->isReferenceType());
      }
    }

    if (isRetainPtrOrOSPtr(className)) {
      if (auto *maybeRefToRawOperator = dyn_cast<CXXConversionDecl>(M)) {
        auto QT = maybeRefToRawOperator->getConversionType();
        auto *T = QT.getTypePtrOrNull();
        return T && (T->isPointerType() || T->isReferenceType() ||
                     T->isObjCObjectPointerType());
      }
    }
  }
  return false;
}

// Returns true when R is instantiated from a template whose pattern has a
// Ref-type name. Classes without a template instantiation pattern yield
// false. R must be non-null.
bool isRefCounted(const CXXRecordDecl *R) {
  assert(R);
  auto *Pattern = R->getTemplateInstantiationPattern();
  if (!Pattern)
    return false;

  // FIXME: String/AtomString/UniqueString
  return isRefType(safeGetName(Pattern));
}

// Returns true when R is instantiated from a template whose pattern is named
// "CheckedPtr" or "CheckedRef". Classes without a template instantiation
// pattern yield false. R must be non-null.
bool isCheckedPtr(const CXXRecordDecl *R) {
  assert(R);
  auto *Pattern = R->getTemplateInstantiationPattern();
  if (!Pattern)
    return false;

  return isCheckedPtr(safeGetName(Pattern));
}

// Returns true when R is instantiated from a template whose pattern has a
// RetainPtr / OSObjectPtr name. Classes without a template instantiation
// pattern yield false. R must be non-null.
bool isRetainPtrOrOSPtr(const CXXRecordDecl *R) {
  assert(R);
  auto *Pattern = R->getTemplateInstantiationPattern();
  if (!Pattern)
    return false;

  const auto PatternName = safeGetName(Pattern);
  return isRetainPtrOrOSPtr(PatternName);
}

// Returns true when R is instantiated from a template whose pattern name is
// accepted by isSmartPtrClass. Classes without a template instantiation
// pattern yield false. R must be non-null.
bool isSmartPtr(const CXXRecordDecl *R) {
  assert(R);
  auto *Pattern = R->getTemplateInstantiationPattern();
  if (!Pattern)
    return false;

  const auto PatternName = safeGetName(Pattern);
  return isSmartPtrClass(PatternName);
}

// Classification of the annotate_type annotation found on a function's
// return type, as computed by typeAnnotationForReturnType.
enum class WebKitAnnotation : uint8_t {
  // No recognized annotation on the return type.
  None,
  // Return type is annotated with "webkit.pointerconversion".
  PointerConversion,
  // Return type is annotated with "webkit.nodelete".
  NoDelete,
};

// Inspects the return type of FD for an annotate_type attribute and maps its
// annotation string to a WebKitAnnotation. A macro-qualified return type is
// desugared once before the attribute is looked up. Anything that is not an
// attributed type carrying a recognized annotation yields None.
static WebKitAnnotation typeAnnotationForReturnType(const FunctionDecl *FD) {
  auto ReturnTy = FD->getReturnType();
  auto *ReturnTypePtr = ReturnTy.getTypePtrOrNull();

  // Look through a macro wrapper around the type, if there is one.
  if (auto *Wrapped = dyn_cast_or_null<MacroQualifiedType>(ReturnTypePtr))
    ReturnTypePtr = Wrapped->desugar().getTypePtrOrNull();

  auto *Attributed = dyn_cast_or_null<AttributedType>(ReturnTypePtr);
  if (!Attributed)
    return WebKitAnnotation::None;

  auto *Annotate = dyn_cast_or_null<AnnotateTypeAttr>(Attributed->getAttr());
  if (!Annotate)
    return WebKitAnnotation::None;

  auto Text = Annotate->getAnnotation();
  if (Text == "webkit.pointerconversion")
    return WebKitAnnotation::PointerConversion;
  return Text == "webkit.nodelete" ? WebKitAnnotation::NoDelete
                                   : WebKitAnnotation::None;
}

// Returns true when F is treated as a pointer conversion: it is a ref-counted
// constructor/factory (see isCtorOfRefCounted), its name is one of the cast
// helpers listed below, or its return type is annotated with
// "webkit.pointerconversion". F must be non-null.
bool isPtrConversion(const FunctionDecl *F) {
  assert(F);
  if (isCtorOfRefCounted(F))
    return true;

  // FIXME: check # of params == 1
  static constexpr const char *ConversionNames[] = {
      "getPtr",           "WeakPtr",           "dynamicDowncast",
      "downcast",         "checkedDowncast",   "bit_cast",
      "uncheckedDowncast", "bitwise_cast",     "bridge_cast",
      "bridge_id_cast",   "dynamic_cf_cast",   "checked_cf_cast",
      "dynamic_objc_cast", "checked_objc_cast"};
  const auto Name = safeGetName(F);
  for (const char *Candidate : ConversionNames) {
    if (Name == Candidate)
      return true;
  }

  return typeAnnotationForReturnType(F) == WebKitAnnotation::PointerConversion;
}

static bool isNoDeleteFunctionDecl(const FunctionDecl *F) {
  return typeAnnotationForReturnType(F) == WebKitAnnotation::NoDelete;
}

// Returns true when any redeclaration of F carries the "webkit.nodelete"
// return-type annotation, or, for a virtual method, when any method it
// overrides (directly or transitively) carries it.
bool isNoDeleteFunction(const FunctionDecl *F) {
  for (auto *Redecl : F->redecls()) {
    if (isNoDeleteFunctionDecl(Redecl))
      return true;
  }

  const auto *Method = dyn_cast<CXXMethodDecl>(F);
  if (!Method)
    return false;
  if (!Method->isVirtual())
    return false;

  // Depth-first walk over everything this method overrides.
  auto Pending = llvm::to_vector(Method->overridden_methods());
  while (!Pending.empty()) {
    const auto *Overridden = Pending.pop_back_val();
    llvm::append_range(Pending, Overridden->overridden_methods());
    if (isNoDeleteFunctionDecl(Overridden))
      return true;
  }

  return false;
}

// Returns true when F has a plain identifier name that is
// "__libcpp_verbose_abort" or starts with "__builtin", "os_log" or "_os_log".
// A null F, or a declaration whose name is not an identifier, yields false.
bool isTrivialBuiltinFunction(const FunctionDecl *F) {
  if (!F)
    return false;
  if (!F->getDeclName().isIdentifier())
    return false;

  auto Name = F->getName();
  if (Name.starts_with("__builtin"))
    return true;
  if (Name == "__libcpp_verbose_abort")
    return true;
  return Name.starts_with("os_log") || Name.starts_with("_os_log");
}

bool isSingleton(const NamedDecl *F) {
  assert(F);
  // FIXME: check # of params == 1
  if (auto *MethodDecl = dyn_cast<CXXMethodDecl>(F)) {
    if (!MethodDecl->isStatic())
      return false;
  }
  const auto &NameStr = safeGetName(F);
  StringRef Name = NameStr; // FIXME: Make safeGetName return StringRef.
  return Name == "singleton" || Name.ends_with("Singleton");
}

// Visitor that decides whether a function, statement or variable destructor
// is "trivial". Each Visit* method returns true when the visited node is
// considered trivial; node kinds without a dedicated override fall through
// to VisitStmt, which returns false.
//
// We only care about statements so let's use the simple
// (non-recursive) visitor.
class TrivialFunctionAnalysisVisitor
    : public ConstStmtVisitor<TrivialFunctionAnalysisVisitor, bool> {

  // Returns false if at least one child is non-trivial.
  // When an offending-statement out-pointer is set and still empty, the first
  // non-trivial child is stored through it.
  bool VisitChildren(const Stmt *S) {
    for (const Stmt *Child : S->children()) {
      if (Child && !Visit(Child)) {
        if (OffendingStmt && !*OffendingStmt)
          *OffendingStmt = Child;
        return false;
      }
    }

    return true;
  }

  // Runs Function for S with memoization and recursion tracking.
  // A previously cached answer is returned only while no offending-statement
  // out-pointer is active. While Function runs, S is registered in
  // RecursiveFn so that a nested query for the same S returns the provisional
  // answer instead of recursing again. The final answer is stored in Cache.
  template <typename StmtOrDecl, typename CheckFunction>
  bool WithCachedResult(const StmtOrDecl *S, CheckFunction Function) {
    auto CacheIt = Cache.find(S);
    if (CacheIt != Cache.end() && !OffendingStmt)
      return CacheIt->second;

    // Treat a recursive statement to be trivial until proven otherwise.
    auto [RecursiveIt, IsNew] = RecursiveFn.insert(std::make_pair(S, true));
    if (!IsNew)
      return RecursiveIt->second;

    bool Result = Function();

    // A non-trivial result invalidates the provisional "trivial" answer of
    // every entry that is currently in flight.
    if (!Result) {
      for (auto &It : RecursiveFn)
        It.second = false;
    }
    // Look S up again rather than reusing the earlier iterator; Function()
    // may have inserted into or erased from RecursiveFn in the meantime.
    RecursiveIt = RecursiveFn.find(S);
    assert(RecursiveIt != RecursiveFn.end());
    Result = RecursiveIt->second;
    RecursiveFn.erase(RecursiveIt);
    Cache[S] = Result;

    return Result;
  }

  // Returns true when destroying an object of type Ty is considered trivial.
  // A null type is treated as non-trivial.
  bool CanTriviallyDestruct(QualType Ty) {
    if (Ty.isNull())
      return false;

    // T*, T& or T&& does not run its destructor.
    if (Ty->isPointerOrReferenceType())
      return true;

    // Fundamental types (integral, nullptr_t, etc...) don't have destructors.
    if (Ty->isFundamentalType() || Ty->isIntegralOrEnumerationType())
      return true;

    if (const auto *R = Ty->getAsCXXRecordDecl()) {
      // C++ trivially destructible classes are fine.
      if (R->hasDefinition() && R->hasTrivialDestructor())
        return true;

      if (HasFieldWithNonTrivialDtor(R))
        return false;

      // For Webkit, side-effects are fine as long as we don't delete objects,
      // so check recursively.
      if (const auto *Dtor = R->getDestructor())
        return IsFunctionTrivial(Dtor);
    }

    // Structs in C are trivial.
    if (Ty->isRecordType())
      return true;

    // For arrays it depends on the element type.
    // FIXME: We should really use ASTContext::getAsArrayType instead.
    if (const auto *AT = Ty->getAsArrayTypeUnsafe())
      return CanTriviallyDestruct(AT->getElementType());

    return false; // Otherwise it's likely not trivial.
  }

  // Returns true when Cls, or one of the base classes found by
  // lookupInBases, has a field whose type cannot be trivially destructed
  // (per CanTriviallyDestruct). Results are memoized in FieldDtorCache.
  bool HasFieldWithNonTrivialDtor(const CXXRecordDecl *Cls) {
    auto CacheIt = FieldDtorCache.find(Cls);
    if (CacheIt != FieldDtorCache.end())
      return CacheIt->second;

    bool Result = ([&] {
      // Checks only the fields declared directly in R.
      auto HasNonTrivialField = [&](const CXXRecordDecl *R) {
        for (const FieldDecl *F : R->fields()) {
          if (!CanTriviallyDestruct(F->getType()))
            return true;
        }
        return false;
      };

      if (HasNonTrivialField(Cls))
        return true;

      // Bases are only inspected when the class has a definition.
      if (!Cls->hasDefinition())
        return false;

      CXXBasePaths Paths;
      Paths.setOrigin(const_cast<CXXRecordDecl *>(Cls));
      return Cls->lookupInBases(
          [&](const CXXBaseSpecifier *B, CXXBasePath &) {
            auto *T = B->getType().getTypePtrOrNull();
            if (!T)
              return false;
            auto *R = T->getAsCXXRecordDecl();
            return R && HasNonTrivialField(R);
          },
          Paths, /*LookupInDependent =*/true);
    })();

    FieldDtorCache[Cls] = Result;

    return Result;
  }

public:
  using CacheTy = TrivialFunctionAnalysis::CacheTy;

  // Cache is shared with the caller and receives the computed results.
  // OffendingStmt, when non-null, is an out-pointer through which
  // VisitChildren reports the first non-trivial child statement it finds.
  TrivialFunctionAnalysisVisitor(CacheTy &Cache,
                                 const Stmt **OffendingStmt = nullptr)
      : Cache(Cache), OffendingStmt(OffendingStmt) {}

  // Returns true when the function (or other declaration with a body) D is
  // considered trivial. OffendingStmt is set to nullptr for the duration of
  // the analysis of D and restored afterwards, so statements inside D are not
  // reported through it.
  bool IsFunctionTrivial(const Decl *D) {
    const Stmt **SavedOffendingStmt = std::exchange(OffendingStmt, nullptr);
    auto Result = WithCachedResult(D, [&]() {
      if (auto *FnDecl = dyn_cast<FunctionDecl>(D)) {
        // Functions annotated as no-delete are accepted without looking at
        // their body.
        if (isNoDeleteFunction(FnDecl))
          return true;
        // Virtual methods without that annotation are rejected.
        if (auto *MD = dyn_cast<CXXMethodDecl>(D); MD && MD->isVirtual())
          return false;
        // Every parameter must be trivially destructible.
        for (auto *Param : FnDecl->parameters()) {
          if (!HasTrivialDestructor(Param))
            return false;
        }
      }
      // Constructor initializers are checked in addition to the body.
      if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
        for (auto *CtorInit : CtorDecl->inits()) {
          if (!Visit(CtorInit->getInit()))
            return false;
        }
      }
      // A declaration without a body is treated as non-trivial.
      const Stmt *Body = D->getBody();
      if (!Body)
        return false;
      return Visit(Body);
    });
    OffendingStmt = SavedOffendingStmt;
    return Result;
  }

  // Returns true when the type of VD can be trivially destructed; the answer
  // is memoized per variable declaration.
  bool HasTrivialDestructor(const VarDecl *VD) {
    return WithCachedResult(
        VD, [&] { return CanTriviallyDestruct(VD->getType()); });
  }

  // Returns true when statement S is considered trivial. Unlike
  // WithCachedResult, a cached answer is used here regardless of
  // OffendingStmt, and no recursion tracking is done at this level.
  bool IsStatementTrivial(const Stmt *S) {
    auto CacheIt = Cache.find(S);
    if (CacheIt != Cache.end())
      return CacheIt->second;
    bool Result = Visit(S);
    Cache[S] = Result;
    return Result;
  }

  bool VisitStmt(const Stmt *S) {
    // All statements are non-trivial unless overridden later.
    // Don't even recurse into children by default.
    return false;
  }

  bool VisitAttributedStmt(const AttributedStmt *AS) {
    // Ignore attributes.
    return Visit(AS->getSubStmt());
  }

  bool VisitCompoundStmt(const CompoundStmt *CS) {
    // A compound statement is allowed as long each individual sub-statement
    // is trivial.
    return WithCachedResult(CS, [&]() { return VisitChildren(CS); });
  }

  // A coroutine body is trivial when all of its children are.
  bool VisitCoroutineBodyStmt(const CoroutineBodyStmt *CBS) {
    return WithCachedResult(CBS, [&]() { return VisitChildren(CBS); });
  }

  bool VisitReturnStmt(const ReturnStmt *RS) {
    // A return statement is allowed as long as the return value is trivial.
    if (auto *RV = RS->getRetValue())
      return Visit(RV);
    return true;
  }

  // A declaration statement is trivial when every declared variable has a
  // trivial destructor and all child statements (e.g. initializers) are
  // trivial.
  bool VisitDeclStmt(const DeclStmt *DS) {
    for (auto &Decl : DS->decls()) {
      // FIXME: Handle DecompositionDecls.
      if (auto *VD = dyn_cast<VarDecl>(Decl)) {
        if (!HasTrivialDestructor(VD))
          return false;
      }
    }
    return VisitChildren(DS);
  }
  // Control-flow statements are trivial when all of their children are.
  bool VisitDoStmt(const DoStmt *DS) { return VisitChildren(DS); }
  bool VisitIfStmt(const IfStmt *IS) {
    return WithCachedResult(IS, [&]() { return VisitChildren(IS); });
  }
  bool VisitForStmt(const ForStmt *FS) {
    return WithCachedResult(FS, [&]() { return VisitChildren(FS); });
  }
  bool VisitCXXForRangeStmt(const CXXForRangeStmt *FS) {
    return WithCachedResult(FS, [&]() { return VisitChildren(FS); });
  }
  bool VisitWhileStmt(const WhileStmt *WS) {
    return WithCachedResult(WS, [&]() { return VisitChildren(WS); });
  }
  bool VisitSwitchStmt(const SwitchStmt *SS) { return VisitChildren(SS); }
  bool VisitCaseStmt(const CaseStmt *CS) { return VisitChildren(CS); }
  bool VisitDefaultStmt(const DefaultStmt *DS) { return VisitChildren(DS); }

  // break, continue, goto, and label statements are always trivial.
  bool VisitBreakStmt(const BreakStmt *) { return true; }
  bool VisitContinueStmt(const ContinueStmt *) { return true; }
  bool VisitGotoStmt(const GotoStmt *) { return true; }
  bool VisitLabelStmt(const LabelStmt *) { return true; }

  bool VisitUnaryOperator(const UnaryOperator *UO) {
    // Unary operators are trivial if its operand is trivial except co_await.
    return UO->getOpcode() != UO_Coawait && Visit(UO->getSubExpr());
  }

  bool VisitBinaryOperator(const BinaryOperator *BO) {
    // Binary operators are trivial if their operands are trivial.
    return Visit(BO->getLHS()) && Visit(BO->getRHS());
  }

  bool VisitCompoundAssignOperator(const CompoundAssignOperator *CAO) {
    // Compound assignment operator such as |= is trivial if its
    // subexpressions are trivial.
    return VisitChildren(CAO);
  }

  // An array subscript is trivial when its children are trivial.
  bool VisitArraySubscriptExpr(const ArraySubscriptExpr *ASE) {
    return VisitChildren(ASE);
  }

  bool VisitConditionalOperator(const ConditionalOperator *CO) {
    // Ternary operators are trivial if their conditions & values are trivial.
    return VisitChildren(CO);
  }

  // An atomic expression is trivial when its children are trivial.
  bool VisitAtomicExpr(const AtomicExpr *E) { return VisitChildren(E); }

  bool VisitStaticAssertDecl(const StaticAssertDecl *SAD) {
    // Any static_assert is considered trivial.
    return true;
  }

  // A call is trivial when all arguments are trivial and the direct callee is
  // either on one of the allow-lists below or is itself a trivial function.
  // Calls without a direct callee are treated as non-trivial.
  bool VisitCallExpr(const CallExpr *CE) {
    if (!checkArguments(CE))
      return false;

    auto *Callee = CE->getDirectCallee();
    if (!Callee)
      return false;

    if (isPtrConversion(Callee))
      return true;

    const auto &Name = safeGetName(Callee);

    if (Callee->isInStdNamespace() &&
        (Name == "addressof" || Name == "forward" || Name == "move"))
      return true;

    // Functions accepted by name without inspecting their bodies.
    if (Name == "WTFCrashWithInfo" || Name == "WTFBreakpointTrap" ||
        Name == "WTFReportBacktrace" ||
        Name == "WTFCrashWithSecurityImplication" || Name == "WTFCrash" ||
        Name == "WTFReportAssertionFailure" || Name == "isMainThread" ||
        Name == "isMainThreadOrGCThread" || Name == "isMainRunLoop" ||
        Name == "isWebThread" || Name == "isUIThread" ||
        Name == "mayBeGCThread" || Name == "compilerFenceForCrash" ||
        isTrivialBuiltinFunction(Callee))
      return true;

    return IsFunctionTrivial(Callee);
  }

  // Inline assembly is only accepted when its asm string is exactly
  // "brk #0xc471".
  bool VisitGCCAsmStmt(const GCCAsmStmt *AS) {
    return AS->getAsmString() == "brk #0xc471";
  }

  bool
  VisitSubstNonTypeTemplateParmExpr(const SubstNonTypeTemplateParmExpr *E) {
    // Non-type template parameter is compile time constant and trivial.
    return true;
  }

  // sizeof/alignof-style expressions are trivial when their children are.
  bool VisitUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr *E) {
    return VisitChildren(E);
  }

  bool VisitPredefinedExpr(const PredefinedExpr *E) {
    // A predefined identifier such as "func" is considered trivial.
    return true;
  }

  bool VisitOffsetOfExpr(const OffsetOfExpr *OE) {
    // offsetof(T, D) is considered trivial.
    return true;
  }

  // A member call is trivial when its arguments and implicit object argument
  // are trivial and the callee is an accepted method (ref /
  // incrementCheckedPtrCount, a safe-pointer getter) or a trivial function.
  // An explicit destructor call additionally requires the object type to be
  // trivially destructible.
  bool VisitCXXMemberCallExpr(const CXXMemberCallExpr *MCE) {
    if (!checkArguments(MCE))
      return false;

    bool TrivialThis = Visit(MCE->getImplicitObjectArgument());
    if (!TrivialThis)
      return false;

    auto *Callee = MCE->getMethodDecl();
    if (!Callee)
      return false;

    if (isa<CXXDestructorDecl>(Callee) &&
        !CanTriviallyDestruct(MCE->getObjectType()))
      return false;

    auto Name = safeGetName(Callee);
    if (Name == "ref" || Name == "incrementCheckedPtrCount")
      return true;

    std::optional<bool> IsGetterOfRefCounted = isGetterOfSafePtr(Callee);
    if (IsGetterOfRefCounted && *IsGetterOfRefCounted)
      return true;

    // Recursively descend into the callee to confirm that it's trivial as well.
    return IsFunctionTrivial(Callee);
  }

  // An overloaded operator call is trivial when its arguments are trivial and
  // the callee declaration is a trivial function.
  bool VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE) {
    if (!checkArguments(OCE))
      return false;
    auto *Callee = OCE->getCalleeDecl();
    if (!Callee)
      return false;
    // Recursively descend into the callee to confirm that it's trivial as well.
    return IsFunctionTrivial(Callee);
  }

  // A rewritten binary operator is judged by its semantic form.
  bool VisitCXXRewrittenBinaryOperator(const CXXRewrittenBinaryOperator *Op) {
    auto *SemanticExpr = Op->getSemanticForm();
    return SemanticExpr && Visit(SemanticExpr);
  }

  // A default argument is trivial when its expression, if any, is trivial.
  bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
    if (auto *Expr = E->getExpr()) {
      if (!Visit(Expr))
        return false;
    }
    return true;
  }

  // A default member initializer is judged by its initializer expression.
  bool VisitCXXDefaultInitExpr(const CXXDefaultInitExpr *E) {
    return Visit(E->getExpr());
  }

  // Returns true when every non-null argument of the call is trivial.
  bool checkArguments(const CallExpr *CE) {
    for (const Expr *Arg : CE->arguments()) {
      if (Arg && !Visit(Arg))
        return false;
    }
    return true;
  }

  // A construct expression is trivial when its arguments are trivial and the
  // selected constructor is a trivial function.
  bool VisitCXXConstructExpr(const CXXConstructExpr *CE) {
    for (const Expr *Arg : CE->arguments()) {
      if (Arg && !Visit(Arg))
        return false;
    }

    // Recursively descend into the callee to confirm that it's trivial.
    return IsFunctionTrivial(CE->getConstructor());
  }

  // A delete expression is trivial when the destroyed type can be trivially
  // destructed.
  bool VisitCXXDeleteExpr(const CXXDeleteExpr *DE) {
    return CanTriviallyDestruct(DE->getDestroyedType());
  }

  // An inherited-constructor initialization is judged by its constructor.
  bool VisitCXXInheritedCtorInitExpr(const CXXInheritedCtorInitExpr *E) {
    return IsFunctionTrivial(E->getConstructor());
  }

  // A new expression is trivial when its children are trivial.
  bool VisitCXXNewExpr(const CXXNewExpr *NE) { return VisitChildren(NE); }

  // Casts and temporary materialization are judged by their sub-expression.
  bool VisitImplicitCastExpr(const ImplicitCastExpr *ICE) {
    return Visit(ICE->getSubExpr());
  }

  bool VisitExplicitCastExpr(const ExplicitCastExpr *ECE) {
    return Visit(ECE->getSubExpr());
  }

  bool VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *VMT) {
    return Visit(VMT->getSubExpr());
  }

  // A bound temporary is trivial when the destructor of the temporary is a
  // trivial function and the sub-expression is trivial.
  bool VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *BTE) {
    if (auto *Temp = BTE->getTemporary()) {
      if (!IsFunctionTrivial(Temp->getDestructor()))
        return false;
    }
    return Visit(BTE->getSubExpr());
  }

  // An array-init loop is trivial when both its common expression and its
  // per-element sub-expression are trivial.
  bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *AILE) {
    return Visit(AILE->getCommonExpr()) && Visit(AILE->getSubExpr());
  }

  bool VisitArrayInitIndexExpr(const ArrayInitIndexExpr *AIIE) {
    return true; // The current array index in VisitArrayInitLoopExpr is always
                 // trivial.
  }

  // An opaque value is judged by its source expression.
  bool VisitOpaqueValueExpr(const OpaqueValueExpr *OVE) {
    return Visit(OVE->getSourceExpr());
  }

  // Cleanups wrappers and parentheses are judged by their sub-expression.
  bool VisitExprWithCleanups(const ExprWithCleanups *EWC) {
    return Visit(EWC->getSubExpr());
  }

  bool VisitParenExpr(const ParenExpr *PE) { return Visit(PE->getSubExpr()); }

  // An initializer list is trivial when every non-null initializer is.
  bool VisitInitListExpr(const InitListExpr *ILE) {
    for (const Expr *Child : ILE->inits()) {
      if (Child && !Visit(Child))
        return false;
    }
    return true;
  }

  bool VisitMemberExpr(const MemberExpr *ME) {
    // Field access is allowed but the base pointer may itself be non-trivial.
    return Visit(ME->getBase());
  }

  bool VisitCXXThisExpr(const CXXThisExpr *CTE) {
    // The expression 'this' is always trivial, be it explicit or implicit.
    return true;
  }

  bool VisitCXXNullPtrLiteralExpr(const CXXNullPtrLiteralExpr *E) {
    // nullptr is trivial.
    return true;
  }

  bool VisitDeclRefExpr(const DeclRefExpr *DRE) {
    // The use of a variable is trivial.
    return true;
  }

  // Constant literal expressions are always trivial
  bool VisitIntegerLiteral(const IntegerLiteral *E) { return true; }
  bool VisitFloatingLiteral(const FloatingLiteral *E) { return true; }
  bool VisitFixedPointLiteral(const FixedPointLiteral *E) { return true; }
  bool VisitCharacterLiteral(const CharacterLiteral *E) { return true; }
  bool VisitStringLiteral(const StringLiteral *E) { return true; }
  bool VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr *E) { return true; }

  bool VisitConstantExpr(const ConstantExpr *CE) {
    // Constant expressions are trivial.
    return true;
  }

  bool VisitImplicitValueInitExpr(const ImplicitValueInitExpr *IVIE) {
    // An implicit value initialization is trivial.
    return true;
  }

private:
  // Result cache supplied by the caller (held by reference); keyed by the
  // statement or declaration that was analyzed.
  CacheTy &Cache;
  // Memoized answers of HasFieldWithNonTrivialDtor, local to this visitor.
  CacheTy FieldDtorCache;
  // Entries currently being evaluated by WithCachedResult, mapped to their
  // provisional answer.
  CacheTy RecursiveFn;
  // Optional out-pointer for reporting the first non-trivial child statement;
  // may be null.
  const Stmt **OffendingStmt;
};

bool TrivialFunctionAnalysis::isTrivialImpl(
    const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache,
    const Stmt **OffendingStmt) {
  TrivialFunctionAnalysisVisitor V(Cache, OffendingStmt);
  return V.IsFunctionTrivial(D);
}

bool TrivialFunctionAnalysis::isTrivialImpl(
    const Stmt *S, TrivialFunctionAnalysis::CacheTy &Cache,
    const Stmt **OffendingStmt) {
  TrivialFunctionAnalysisVisitor V(Cache, OffendingStmt);
  return V.IsStatementTrivial(S);
}

// Runs a fresh visitor (without an offending-statement out-pointer) to decide
// whether the type of VD can be trivially destructed. Results are stored in
// Cache.
bool TrivialFunctionAnalysis::hasTrivialDtorImpl(const VarDecl *VD,
                                                 CacheTy &Cache) {
  return TrivialFunctionAnalysisVisitor(Cache).HasTrivialDestructor(VD);
}

} // namespace clang
