//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Emit Stmt nodes as CIR code.
//
//===----------------------------------------------------------------------===//

#include "CIRGenBuilder.h"
#include "CIRGenFunction.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/CIR/MissingFeatures.h"

using namespace clang;
using namespace clang::CIRGen;
using namespace cir;

/// Emit the final statement of a statement expression, capturing its value.
///
/// Labels and attributed statements are statements, but when they sit at the
/// end of a statement expression they yield the value of the expression they
/// wrap. They are peeled off (labels are emitted along the way) until the
/// underlying expression is reached.
///
/// Aggregate results are emitted into \p slot. Any other result is written to
/// \p lastValue, which must be non-null in that case.
static mlir::LogicalResult emitStmtWithResult(CIRGenFunction &cgf,
                                              const Stmt *exprResult,
                                              AggValueSlot slot,
                                              Address *lastValue) {
  const Stmt *current = exprResult;
  while (!isa<Expr>(current)) {
    if (const auto *labelStmt = dyn_cast<LabelStmt>(current)) {
      if (cgf.emitLabel(*labelStmt->getDecl()).failed())
        return mlir::failure();
      current = labelStmt->getSubStmt();
      continue;
    }

    const auto *attributed = dyn_cast<AttributedStmt>(current);
    if (!attributed)
      llvm_unreachable("Unknown value statement");
    // FIXME: Update this if we ever have attributes that affect the
    // semantics of an expression.
    current = attributed->getSubStmt();
  }

  const Expr *resultExpr = cast<Expr>(current);
  if (!cgf.hasAggregateEvaluationKind(resultExpr->getType())) {
    // An RValue cannot be returned here because cleanups may still run at the
    // end of the StmtExpr. The result is stored into the temporary instead.
    cgf.emitAnyExprToMem(resultExpr, *lastValue, Qualifiers(),
                         /*IsInit*/ false);
    return mlir::success();
  }

  cgf.emitAggExpr(resultExpr, slot);
  return mlir::success();
}

/// Emit every statement of \p s into the current scope (no cir.scope is
/// created).
///
/// When \p lastValue is non-null, the compound statement is the body of a
/// statement expression. Its last statement is emitted through
/// emitStmtWithResult so that its value lands in \p lastValue or \p slot.
///
/// Emission continues past a failing statement. The return value is failure
/// if any statement failed.
mlir::LogicalResult CIRGenFunction::emitCompoundStmtWithoutScope(
    const CompoundStmt &s, Address *lastValue, AggValueSlot slot) {
  const Stmt *resultStmt = s.body_back();
  assert((!lastValue || resultStmt) &&
         "If lastValue is not null then the CompoundStmt must have a "
         "StmtExprResult");

  bool anyFailed = false;
  for (const Stmt *stmt : s.body()) {
    const bool isResult = lastValue && stmt == resultStmt;
    const mlir::LogicalResult status =
        isResult ? emitStmtWithResult(*this, resultStmt, slot, lastValue)
                 : emitStmt(stmt, /*useCurrentScope=*/false);
    if (status.failed())
      anyFailed = true;
  }
  return anyFailed ? mlir::failure() : mlir::success();
}

/// Emit an attributed statement.
///
/// Each statement attribute is inspected first:
/// - `[[assume(...)]]` (CXXAssume) may lower to a cir.assume.
/// - A fixed set of known attributes reports "not yet implemented".
/// - Everything else is ignored.
///
/// The wrapped sub-statement is then emitted in the current scope, with the
/// attributes forwarded to emitStmt.
mlir::LogicalResult
CIRGenFunction::emitAttributedStmt(const AttributedStmt &s) {
  for (const Attr *stmtAttr : s.getAttrs()) {
    const attr::Kind kind = stmtAttr->getKind();
    switch (kind) {
    case attr::CXXAssume: {
      // Only lower the assumption when C++ assumptions are enabled, there is
      // a block to insert into, and the assumed expression has no side
      // effects.
      const Expr *assumption =
          cast<CXXAssumeAttr>(stmtAttr)->getAssumption();
      const bool canLower = getLangOpts().CXXAssumptions &&
                            builder.getInsertionBlock() &&
                            !assumption->HasSideEffects(getContext());
      if (!canLower)
        break;
      mlir::Value assumed = emitCheckedArgForAssume(assumption);
      cir::AssumeOp::create(builder, getLoc(s.getSourceRange()), assumed);
      break;
    }
    case attr::NoMerge:
    case attr::NoInline:
    case attr::AlwaysInline:
    case attr::NoConvergent:
    case attr::MustTail:
    case attr::Atomic:
    case attr::HLSLControlFlowHint:
      cgm.errorNYI(s.getSourceRange(),
                   "Unimplemented statement attribute: ", kind);
      break;
    default:
      break;
    }
  }

  return emitStmt(s.getSubStmt(), /*useCurrentScope=*/true, s.getAttrs());
}

/// Emit a braced compound statement inside its own cir.scope.
///
/// The scope op is created first and the insertion point of its body is
/// recorded. The body is then populated from that point, under a new symbol
/// table scope and LexicalScope. The builder's previous insertion point is
/// restored when this function returns.
mlir::LogicalResult CIRGenFunction::emitCompoundStmt(const CompoundStmt &s,
                                                     Address *lastValue,
                                                     AggValueSlot slot) {
  // Declarations inside the braces get their own symbol-table scope.
  SymTableScopeTy symScope(symbolTable);
  mlir::Location loc = getLoc(s.getSourceRange());

  mlir::OpBuilder::InsertPoint bodyStart;
  cir::ScopeOp::create(builder, loc,
                       [&](mlir::OpBuilder &b, mlir::Type &, mlir::Location) {
                         bodyStart = b.saveInsertionPoint();
                       });

  mlir::OpBuilder::InsertionGuard restoreOnExit(builder);
  builder.restoreInsertionPoint(bodyStart);
  LexicalScope lexicalScope(*this, loc, builder.getInsertionBlock());
  return emitCompoundStmtWithoutScope(s, lastValue, slot);
}

/// Placeholder for emitting a debug-info stop point for \p s.
///
/// CIRGen does not produce debug info yet, so this does nothing beyond
/// asserting the missing-feature marker.
void CIRGenFunction::emitStopPoint(const Stmt *s) {
  (void)s; // Unused until debug info is implemented.
  assert(!cir::MissingFeatures::generateDebugInfo());
}

// Build CIR for a statement. useCurrentScope should be true if no new scopes
// need to be created when finding a compound statement.
//
// Statements that emitSimpleStmt handles are tried first. Everything else is
// dispatched on the statement class below. `attr` carries the attributes of
// an enclosing AttributedStmt; it is currently only forwarded to range-based
// for loops.
mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
                                             bool useCurrentScope,
                                             ArrayRef<const Attr *> attr) {
  if (mlir::succeeded(emitSimpleStmt(s, useCurrentScope)))
    return mlir::success();

  switch (s->getStmtClass()) {
  case Stmt::NoStmtClass:
  case Stmt::CXXCatchStmtClass:
  case Stmt::SEHExceptStmtClass:
  case Stmt::SEHFinallyStmtClass:
  case Stmt::MSDependentExistsStmtClass:
    llvm_unreachable("invalid statement class to emit generically");
  case Stmt::BreakStmtClass:
  case Stmt::NullStmtClass:
  case Stmt::CompoundStmtClass:
  case Stmt::ContinueStmtClass:
  case Stmt::DeclStmtClass:
  case Stmt::ReturnStmtClass:
    // These are always dispatched by emitSimpleStmt. It reports failure for
    // them only when their emission was attempted and failed, e.g. a
    // statement nested inside a CompoundStmt hit an NYI path. The failure is
    // expected to have been diagnosed where it occurred. Propagate it instead
    // of treating it as unreachable, which would crash (or be UB in release
    // builds) on any unsupported nested statement.
    return mlir::failure();

#define STMT(Type, Base)
#define ABSTRACT_STMT(Op)
#define EXPR(Type, Base) case Stmt::Type##Class:
#include "clang/AST/StmtNodes.inc"
    {
      assert(builder.getInsertionBlock() &&
             "expression emission must have an insertion point");

      emitIgnoredExpr(cast<Expr>(s));

      // Classic codegen has a check here to see if the emitter created a new
      // block that isn't used (comparing the incoming and outgoing insertion
      // points) and deletes the outgoing block if it's not used. In CIR, we
      // will handle that during the cir.canonicalize pass.
      return mlir::success();
    }
  case Stmt::IfStmtClass:
    return emitIfStmt(cast<IfStmt>(*s));
  case Stmt::SwitchStmtClass:
    return emitSwitchStmt(cast<SwitchStmt>(*s));
  case Stmt::ForStmtClass:
    return emitForStmt(cast<ForStmt>(*s));
  case Stmt::WhileStmtClass:
    return emitWhileStmt(cast<WhileStmt>(*s));
  case Stmt::DoStmtClass:
    return emitDoStmt(cast<DoStmt>(*s));
  case Stmt::CXXTryStmtClass:
    return emitCXXTryStmt(cast<CXXTryStmt>(*s));
  case Stmt::CXXForRangeStmtClass:
    return emitCXXForRangeStmt(cast<CXXForRangeStmt>(*s), attr);
  case Stmt::CoroutineBodyStmtClass:
    return emitCoroutineBody(cast<CoroutineBodyStmt>(*s));
  case Stmt::IndirectGotoStmtClass:
    return emitIndirectGotoStmt(cast<IndirectGotoStmt>(*s));
  case Stmt::CoreturnStmtClass:
    return emitCoreturnStmt(cast<CoreturnStmt>(*s));
  case Stmt::OpenACCComputeConstructClass:
    return emitOpenACCComputeConstruct(cast<OpenACCComputeConstruct>(*s));
  case Stmt::OpenACCLoopConstructClass:
    return emitOpenACCLoopConstruct(cast<OpenACCLoopConstruct>(*s));
  case Stmt::OpenACCCombinedConstructClass:
    return emitOpenACCCombinedConstruct(cast<OpenACCCombinedConstruct>(*s));
  case Stmt::OpenACCDataConstructClass:
    return emitOpenACCDataConstruct(cast<OpenACCDataConstruct>(*s));
  case Stmt::OpenACCEnterDataConstructClass:
    return emitOpenACCEnterDataConstruct(cast<OpenACCEnterDataConstruct>(*s));
  case Stmt::OpenACCExitDataConstructClass:
    return emitOpenACCExitDataConstruct(cast<OpenACCExitDataConstruct>(*s));
  case Stmt::OpenACCHostDataConstructClass:
    return emitOpenACCHostDataConstruct(cast<OpenACCHostDataConstruct>(*s));
  case Stmt::OpenACCWaitConstructClass:
    return emitOpenACCWaitConstruct(cast<OpenACCWaitConstruct>(*s));
  case Stmt::OpenACCInitConstructClass:
    return emitOpenACCInitConstruct(cast<OpenACCInitConstruct>(*s));
  case Stmt::OpenACCShutdownConstructClass:
    return emitOpenACCShutdownConstruct(cast<OpenACCShutdownConstruct>(*s));
  case Stmt::OpenACCSetConstructClass:
    return emitOpenACCSetConstruct(cast<OpenACCSetConstruct>(*s));
  case Stmt::OpenACCUpdateConstructClass:
    return emitOpenACCUpdateConstruct(cast<OpenACCUpdateConstruct>(*s));
  case Stmt::OpenACCCacheConstructClass:
    return emitOpenACCCacheConstruct(cast<OpenACCCacheConstruct>(*s));
  case Stmt::OpenACCAtomicConstructClass:
    return emitOpenACCAtomicConstruct(cast<OpenACCAtomicConstruct>(*s));
  case Stmt::GCCAsmStmtClass:
  case Stmt::MSAsmStmtClass:
    return emitAsmStmt(cast<AsmStmt>(*s));
  case Stmt::OMPScopeDirectiveClass:
    return emitOMPScopeDirective(cast<OMPScopeDirective>(*s));
  case Stmt::OMPErrorDirectiveClass:
    return emitOMPErrorDirective(cast<OMPErrorDirective>(*s));
  case Stmt::OMPParallelDirectiveClass:
    return emitOMPParallelDirective(cast<OMPParallelDirective>(*s));
  case Stmt::OMPTaskwaitDirectiveClass:
    return emitOMPTaskwaitDirective(cast<OMPTaskwaitDirective>(*s));
  case Stmt::OMPTaskyieldDirectiveClass:
    return emitOMPTaskyieldDirective(cast<OMPTaskyieldDirective>(*s));
  case Stmt::OMPBarrierDirectiveClass:
    return emitOMPBarrierDirective(cast<OMPBarrierDirective>(*s));
  case Stmt::OMPMetaDirectiveClass:
    return emitOMPMetaDirective(cast<OMPMetaDirective>(*s));
  case Stmt::OMPCanonicalLoopClass:
    return emitOMPCanonicalLoop(cast<OMPCanonicalLoop>(*s));
  case Stmt::OMPSimdDirectiveClass:
    return emitOMPSimdDirective(cast<OMPSimdDirective>(*s));
  case Stmt::OMPTileDirectiveClass:
    return emitOMPTileDirective(cast<OMPTileDirective>(*s));
  case Stmt::OMPUnrollDirectiveClass:
    return emitOMPUnrollDirective(cast<OMPUnrollDirective>(*s));
  case Stmt::OMPFuseDirectiveClass:
    return emitOMPFuseDirective(cast<OMPFuseDirective>(*s));
  case Stmt::OMPForDirectiveClass:
    return emitOMPForDirective(cast<OMPForDirective>(*s));
  case Stmt::OMPForSimdDirectiveClass:
    return emitOMPForSimdDirective(cast<OMPForSimdDirective>(*s));
  case Stmt::OMPSectionsDirectiveClass:
    return emitOMPSectionsDirective(cast<OMPSectionsDirective>(*s));
  case Stmt::OMPSectionDirectiveClass:
    return emitOMPSectionDirective(cast<OMPSectionDirective>(*s));
  case Stmt::OMPSingleDirectiveClass:
    return emitOMPSingleDirective(cast<OMPSingleDirective>(*s));
  case Stmt::OMPMasterDirectiveClass:
    return emitOMPMasterDirective(cast<OMPMasterDirective>(*s));
  case Stmt::OMPCriticalDirectiveClass:
    return emitOMPCriticalDirective(cast<OMPCriticalDirective>(*s));
  case Stmt::OMPParallelForDirectiveClass:
    return emitOMPParallelForDirective(cast<OMPParallelForDirective>(*s));
  case Stmt::OMPParallelForSimdDirectiveClass:
    return emitOMPParallelForSimdDirective(
        cast<OMPParallelForSimdDirective>(*s));
  case Stmt::OMPParallelMasterDirectiveClass:
    return emitOMPParallelMasterDirective(cast<OMPParallelMasterDirective>(*s));
  case Stmt::OMPParallelSectionsDirectiveClass:
    return emitOMPParallelSectionsDirective(
        cast<OMPParallelSectionsDirective>(*s));
  case Stmt::OMPTaskDirectiveClass:
    return emitOMPTaskDirective(cast<OMPTaskDirective>(*s));
  case Stmt::OMPTaskgroupDirectiveClass:
    return emitOMPTaskgroupDirective(cast<OMPTaskgroupDirective>(*s));
  case Stmt::OMPFlushDirectiveClass:
    return emitOMPFlushDirective(cast<OMPFlushDirective>(*s));
  case Stmt::OMPDepobjDirectiveClass:
    return emitOMPDepobjDirective(cast<OMPDepobjDirective>(*s));
  case Stmt::OMPScanDirectiveClass:
    return emitOMPScanDirective(cast<OMPScanDirective>(*s));
  case Stmt::OMPOrderedDirectiveClass:
    return emitOMPOrderedDirective(cast<OMPOrderedDirective>(*s));
  case Stmt::OMPAtomicDirectiveClass:
    return emitOMPAtomicDirective(cast<OMPAtomicDirective>(*s));
  case Stmt::OMPTargetDirectiveClass:
    return emitOMPTargetDirective(cast<OMPTargetDirective>(*s));
  case Stmt::OMPTeamsDirectiveClass:
    return emitOMPTeamsDirective(cast<OMPTeamsDirective>(*s));
  case Stmt::OMPCancellationPointDirectiveClass:
    return emitOMPCancellationPointDirective(
        cast<OMPCancellationPointDirective>(*s));
  case Stmt::OMPCancelDirectiveClass:
    return emitOMPCancelDirective(cast<OMPCancelDirective>(*s));
  case Stmt::OMPTargetDataDirectiveClass:
    return emitOMPTargetDataDirective(cast<OMPTargetDataDirective>(*s));
  case Stmt::OMPTargetEnterDataDirectiveClass:
    return emitOMPTargetEnterDataDirective(
        cast<OMPTargetEnterDataDirective>(*s));
  case Stmt::OMPTargetExitDataDirectiveClass:
    return emitOMPTargetExitDataDirective(cast<OMPTargetExitDataDirective>(*s));
  case Stmt::OMPTargetParallelDirectiveClass:
    return emitOMPTargetParallelDirective(cast<OMPTargetParallelDirective>(*s));
  case Stmt::OMPTargetParallelForDirectiveClass:
    return emitOMPTargetParallelForDirective(
        cast<OMPTargetParallelForDirective>(*s));
  case Stmt::OMPTaskLoopDirectiveClass:
    return emitOMPTaskLoopDirective(cast<OMPTaskLoopDirective>(*s));
  case Stmt::OMPTaskLoopSimdDirectiveClass:
    return emitOMPTaskLoopSimdDirective(cast<OMPTaskLoopSimdDirective>(*s));
  case Stmt::OMPMaskedTaskLoopDirectiveClass:
    return emitOMPMaskedTaskLoopDirective(cast<OMPMaskedTaskLoopDirective>(*s));
  case Stmt::OMPMaskedTaskLoopSimdDirectiveClass:
    return emitOMPMaskedTaskLoopSimdDirective(
        cast<OMPMaskedTaskLoopSimdDirective>(*s));
  case Stmt::OMPMasterTaskLoopDirectiveClass:
    return emitOMPMasterTaskLoopDirective(cast<OMPMasterTaskLoopDirective>(*s));
  case Stmt::OMPMasterTaskLoopSimdDirectiveClass:
    return emitOMPMasterTaskLoopSimdDirective(
        cast<OMPMasterTaskLoopSimdDirective>(*s));
  case Stmt::OMPParallelGenericLoopDirectiveClass:
    return emitOMPParallelGenericLoopDirective(
        cast<OMPParallelGenericLoopDirective>(*s));
  case Stmt::OMPParallelMaskedDirectiveClass:
    return emitOMPParallelMaskedDirective(cast<OMPParallelMaskedDirective>(*s));
  case Stmt::OMPParallelMaskedTaskLoopDirectiveClass:
    return emitOMPParallelMaskedTaskLoopDirective(
        cast<OMPParallelMaskedTaskLoopDirective>(*s));
  case Stmt::OMPParallelMaskedTaskLoopSimdDirectiveClass:
    return emitOMPParallelMaskedTaskLoopSimdDirective(
        cast<OMPParallelMaskedTaskLoopSimdDirective>(*s));
  case Stmt::OMPParallelMasterTaskLoopDirectiveClass:
    return emitOMPParallelMasterTaskLoopDirective(
        cast<OMPParallelMasterTaskLoopDirective>(*s));
  case Stmt::OMPParallelMasterTaskLoopSimdDirectiveClass:
    return emitOMPParallelMasterTaskLoopSimdDirective(
        cast<OMPParallelMasterTaskLoopSimdDirective>(*s));
  case Stmt::OMPDistributeDirectiveClass:
    return emitOMPDistributeDirective(cast<OMPDistributeDirective>(*s));
  case Stmt::OMPDistributeParallelForDirectiveClass:
    return emitOMPDistributeParallelForDirective(
        cast<OMPDistributeParallelForDirective>(*s));
  case Stmt::OMPDistributeParallelForSimdDirectiveClass:
    return emitOMPDistributeParallelForSimdDirective(
        cast<OMPDistributeParallelForSimdDirective>(*s));
  case Stmt::OMPDistributeSimdDirectiveClass:
    return emitOMPDistributeSimdDirective(cast<OMPDistributeSimdDirective>(*s));
  case Stmt::OMPTargetParallelGenericLoopDirectiveClass:
    return emitOMPTargetParallelGenericLoopDirective(
        cast<OMPTargetParallelGenericLoopDirective>(*s));
  case Stmt::OMPTargetParallelForSimdDirectiveClass:
    return emitOMPTargetParallelForSimdDirective(
        cast<OMPTargetParallelForSimdDirective>(*s));
  case Stmt::OMPTargetSimdDirectiveClass:
    return emitOMPTargetSimdDirective(cast<OMPTargetSimdDirective>(*s));
  case Stmt::OMPTargetTeamsGenericLoopDirectiveClass:
    return emitOMPTargetTeamsGenericLoopDirective(
        cast<OMPTargetTeamsGenericLoopDirective>(*s));
  case Stmt::OMPTargetUpdateDirectiveClass:
    return emitOMPTargetUpdateDirective(cast<OMPTargetUpdateDirective>(*s));
  case Stmt::OMPTeamsDistributeDirectiveClass:
    return emitOMPTeamsDistributeDirective(
        cast<OMPTeamsDistributeDirective>(*s));
  case Stmt::OMPTeamsDistributeSimdDirectiveClass:
    return emitOMPTeamsDistributeSimdDirective(
        cast<OMPTeamsDistributeSimdDirective>(*s));
  case Stmt::OMPTeamsDistributeParallelForSimdDirectiveClass:
    return emitOMPTeamsDistributeParallelForSimdDirective(
        cast<OMPTeamsDistributeParallelForSimdDirective>(*s));
  case Stmt::OMPTeamsDistributeParallelForDirectiveClass:
    return emitOMPTeamsDistributeParallelForDirective(
        cast<OMPTeamsDistributeParallelForDirective>(*s));
  case Stmt::OMPTeamsGenericLoopDirectiveClass:
    return emitOMPTeamsGenericLoopDirective(
        cast<OMPTeamsGenericLoopDirective>(*s));
  case Stmt::OMPTargetTeamsDirectiveClass:
    return emitOMPTargetTeamsDirective(cast<OMPTargetTeamsDirective>(*s));
  case Stmt::OMPTargetTeamsDistributeDirectiveClass:
    return emitOMPTargetTeamsDistributeDirective(
        cast<OMPTargetTeamsDistributeDirective>(*s));
  case Stmt::OMPTargetTeamsDistributeParallelForDirectiveClass:
    return emitOMPTargetTeamsDistributeParallelForDirective(
        cast<OMPTargetTeamsDistributeParallelForDirective>(*s));
  case Stmt::OMPTargetTeamsDistributeParallelForSimdDirectiveClass:
    return emitOMPTargetTeamsDistributeParallelForSimdDirective(
        cast<OMPTargetTeamsDistributeParallelForSimdDirective>(*s));
  case Stmt::OMPTargetTeamsDistributeSimdDirectiveClass:
    return emitOMPTargetTeamsDistributeSimdDirective(
        cast<OMPTargetTeamsDistributeSimdDirective>(*s));
  case Stmt::OMPInteropDirectiveClass:
    return emitOMPInteropDirective(cast<OMPInteropDirective>(*s));
  case Stmt::OMPDispatchDirectiveClass:
    return emitOMPDispatchDirective(cast<OMPDispatchDirective>(*s));
  case Stmt::OMPGenericLoopDirectiveClass:
    return emitOMPGenericLoopDirective(cast<OMPGenericLoopDirective>(*s));
  case Stmt::OMPReverseDirectiveClass:
    return emitOMPReverseDirective(cast<OMPReverseDirective>(*s));
  case Stmt::OMPInterchangeDirectiveClass:
    return emitOMPInterchangeDirective(cast<OMPInterchangeDirective>(*s));
  case Stmt::OMPAssumeDirectiveClass:
    return emitOMPAssumeDirective(cast<OMPAssumeDirective>(*s));
  case Stmt::OMPMaskedDirectiveClass:
    return emitOMPMaskedDirective(cast<OMPMaskedDirective>(*s));
  case Stmt::OMPStripeDirectiveClass:
    return emitOMPStripeDirective(cast<OMPStripeDirective>(*s));
  case Stmt::LabelStmtClass:
  case Stmt::AttributedStmtClass:
  case Stmt::GotoStmtClass:
  case Stmt::DefaultStmtClass:
  case Stmt::CaseStmtClass:
  case Stmt::SEHLeaveStmtClass:
  case Stmt::SYCLKernelCallStmtClass:
  case Stmt::CapturedStmtClass:
  case Stmt::ObjCAtTryStmtClass:
  case Stmt::ObjCAtThrowStmtClass:
  case Stmt::ObjCAtSynchronizedStmtClass:
  case Stmt::ObjCForCollectionStmtClass:
  case Stmt::ObjCAutoreleasePoolStmtClass:
  case Stmt::SEHTryStmtClass:
  case Stmt::ObjCAtCatchStmtClass:
  case Stmt::ObjCAtFinallyStmtClass:
  case Stmt::DeferStmtClass:
    cgm.errorNYI(s->getSourceRange(),
                 std::string("emitStmt: ") + s->getStmtClassName());
    return mlir::failure();
  }

  llvm_unreachable("Unexpected statement class");
}

/// Try to emit \p s as a "simple" statement: declarations, compound
/// statements, jumps, labels, case/default, return and attributed statements.
///
/// Returns failure for every other statement class, so that emitStmt falls
/// back to its generic dispatch. Returns the emitter's own result for the
/// classes handled here.
///
/// When \p useCurrentScope is true, a CompoundStmt is emitted without opening
/// a new cir.scope.
mlir::LogicalResult CIRGenFunction::emitSimpleStmt(const Stmt *s,
                                                   bool useCurrentScope) {
  switch (s->getStmtClass()) {
  default:
    // Not a simple statement; let the caller handle it.
    return mlir::failure();
  case Stmt::DeclStmtClass:
    return emitDeclStmt(cast<DeclStmt>(*s));
  case Stmt::CompoundStmtClass:
    if (useCurrentScope)
      return emitCompoundStmtWithoutScope(cast<CompoundStmt>(*s));
    return emitCompoundStmt(cast<CompoundStmt>(*s));
  case Stmt::GotoStmtClass:
    return emitGotoStmt(cast<GotoStmt>(*s));
  case Stmt::ContinueStmtClass:
    return emitContinueStmt(cast<ContinueStmt>(*s));

  // NullStmt doesn't need any handling, but we need to say we handled it.
  case Stmt::NullStmtClass:
    break;

  case Stmt::LabelStmtClass:
    return emitLabelStmt(cast<LabelStmt>(*s));
  case Stmt::CaseStmtClass:
  case Stmt::DefaultStmtClass:
    // Reaching a case/default here means it is not being emitted as a
    // top-level case of its enclosing switch.
    return emitSwitchCase(cast<SwitchCase>(*s),
                          /*buildingTopLevelCase=*/false);

  case Stmt::BreakStmtClass:
    return emitBreakStmt(cast<BreakStmt>(*s));
  case Stmt::ReturnStmtClass:
    return emitReturnStmt(cast<ReturnStmt>(*s));
  case Stmt::AttributedStmtClass:
    return emitAttributedStmt(cast<AttributedStmt>(*s));
  }

  return mlir::success();
}

/// Emit a labeled statement: first the label itself, then the labeled
/// sub-statement in the current scope.
mlir::LogicalResult CIRGenFunction::emitLabelStmt(const clang::LabelStmt &s) {
  if (mlir::failed(emitLabel(*s.getDecl())))
    return mlir::failure();

  // Side entries under asynchronous EH (EHAsynch) are not supported yet.
  const bool needsEHaHandling =
      getContext().getLangOpts().EHAsynch && s.isSideEntry();
  if (needsEHaHandling)
    getCIRGenModule().errorNYI(s.getSourceRange(), "IsEHa: not implemented.");

  return emitStmt(s.getSubStmt(), /*useCurrentScope=*/true);
}

// Ensure every block in a body region ends in a terminator by appending a
// cir.yield where one is missing.
//
// In a multi-block region, blocks that are empty and have no predecessors or
// successors are erased afterwards. Such blocks can be left behind when, for
// example, a return is the last statement emitted.
static void terminateBody(CIRGenBuilderTy &builder, mlir::Region &r,
                          mlir::Location loc) {
  if (r.empty())
    return;

  const bool hasSeveralBlocks = r.getBlocks().size() != 1;
  SmallVector<mlir::Block *, 4> deadBlocks;

  for (mlir::Block &blk : r.getBlocks()) {
    const bool isEmpty = blk.empty();
    if (hasSeveralBlocks && isEmpty && blk.hasNoPredecessors() &&
        blk.hasNoSuccessors())
      deadBlocks.push_back(&blk);

    const bool needsYield =
        isEmpty || !blk.back().hasTrait<mlir::OpTrait::IsTerminator>();
    if (!needsYield)
      continue;

    mlir::OpBuilder::InsertionGuard restoreInsertion(builder);
    builder.setInsertionPointToEnd(&blk);
    builder.createYield(loc);
  }

  for (mlir::Block *dead : deadBlocks)
    dead->erase();
}

/// Emit an if statement inside its own cir.scope.
///
/// Special cases:
/// - `if consteval` / `if !consteval`: only the branch that can run at
///   runtime is emitted, with no condition. If that branch is absent, nothing
///   is emitted.
/// - `if constexpr` whose condition folds to a constant: only the selected
///   branch, if any, is emitted.
///
/// Otherwise the init statement and condition variable are emitted, followed
/// by a conditional on the condition expression.
mlir::LogicalResult CIRGenFunction::emitIfStmt(const IfStmt &s) {
  mlir::LogicalResult res = mlir::success();
  // For `if consteval`, the else branch is the only one that can be evaluated
  // at runtime; for the negated form `if !consteval` it is the then branch.
  const Stmt *constevalExecuted = nullptr;
  if (s.isConsteval()) {
    constevalExecuted = s.isNegatedConsteval() ? s.getThen() : s.getElse();
    if (!constevalExecuted) {
      // No runtime code execution required
      return res;
    }
  }

  // C99 6.8.4.1: The first substatement is executed if the expression
  // compares unequal to 0.  The condition must be a scalar type.
  auto ifStmtBuilder = [&]() -> mlir::LogicalResult {
    if (s.isConsteval())
      return emitStmt(constevalExecuted, /*useCurrentScope=*/true);

    if (s.getInit())
      if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed())
        return mlir::failure();

    if (s.getConditionVariable())
      emitDecl(*s.getConditionVariable());

    // If this is an 'if constexpr' and its condition folds to a constant,
    // simplify it early in CIRGen to avoid emitting the full 'if'. For a
    // non-constexpr 'if' the folded value was never used, so folding is
    // skipped entirely.
    bool condConstant = false;
    if (s.isConstexpr() &&
        constantFoldsToBool(s.getCond(), condConstant, s.isConstexpr())) {
      // Handle "if constexpr" explicitly here to avoid generating some
      // ill-formed code since in CIR the "if" is no longer simplified
      // in this lambda like in Clang but postponed to other MLIR
      // passes.
      if (const Stmt *executed = condConstant ? s.getThen() : s.getElse())
        return emitStmt(executed, /*useCurrentScope=*/true);
      // There is nothing to execute at runtime.
      // TODO(cir): there is still an empty cir.scope generated by the caller.
      return mlir::success();
    }

    assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
    assert(!cir::MissingFeatures::incrementProfileCounter());
    return emitIfOnBoolExpr(s.getCond(), s.getThen(), s.getElse());
  };

  // TODO: Add a new scoped symbol table.
  // LexicalScope ConditionScope(*this, S.getCond()->getSourceRange());
  // The if scope contains the full source range for IfStmt.
  mlir::Location scopeLoc = getLoc(s.getSourceRange());
  cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/
                       [&](mlir::OpBuilder &b, mlir::Location loc) {
                         LexicalScope lexScope{*this, scopeLoc,
                                               builder.getInsertionBlock()};
                         res = ifStmtBuilder();
                       });

  return res;
}

/// Emit each declaration of a declaration statement, in source order.
/// Always succeeds.
mlir::LogicalResult CIRGenFunction::emitDeclStmt(const DeclStmt &s) {
  assert(builder.getInsertionBlock() && "expected valid insertion point");

  for (const Decl *decl : s.decls()) {
    emitDecl(*decl, /*evaluateConditionDecl=*/true);
  }
  return mlir::success();
}

/// Emit a return statement.
///
/// The return value, if any, is stored into the function's return slot:
/// - scalars and references go to `*fnRetAlloca`;
/// - complex and aggregate values go to `returnValue`;
/// - a return that the NRVO applies to only sets the NRVO flag.
///
/// Cleanups are forced, then a `cir.return` is emitted, loading the value
/// from the return alloca when one exists. A fresh block is opened afterwards
/// so that any (unreachable) statements that follow can still be emitted.
mlir::LogicalResult CIRGenFunction::emitReturnStmt(const ReturnStmt &s) {
  mlir::Location loc = getLoc(s.getSourceRange());
  const Expr *rv = s.getRetValue();

  RunCleanupsScope cleanupScope(*this);
  // A return value wrapped in ExprWithCleanups is unwrapped and emitted
  // inside a dedicated cir.scope (see below).
  bool createNewScope = false;
  if (const auto *ewc = dyn_cast_or_null<ExprWithCleanups>(rv)) {
    rv = ewc->getSubExpr();
    createNewScope = true;
  }

  auto handleReturnVal = [&]() {
    if (getContext().getLangOpts().ElideConstructors && s.getNRVOCandidate() &&
        s.getNRVOCandidate()->isNRVOVariable()) {
      assert(!cir::MissingFeatures::openMP());
      // Apply the named return value optimization for this return statement,
      // which means doing nothing: the appropriate result has already been
      // constructed into the NRVO variable.

      // If there is an NRVO flag for this variable, set it to 1 to indicate
      // that the cleanup code should not destroy the variable.
      if (auto nrvoFlag = nrvoFlags[s.getNRVOCandidate()])
        builder.createFlagStore(loc, true, nrvoFlag);
    } else if (!rv) {
      // No return expression. Do nothing.
    } else if (rv->getType()->isVoidType()) {
      // Make sure not to return anything, but evaluate the expression for
      // side effects. rv is known to be non-null on this path.
      emitAnyExpr(rv);
    } else if (cast<FunctionDecl>(curGD.getDecl())
                   ->getReturnType()
                   ->isReferenceType()) {
      // If this function returns a reference, take the address of the
      // expression rather than the value.
      RValue result = emitReferenceBindingToExpr(rv);
      builder.CIRBaseBuilderTy::createStore(loc, result.getValue(),
                                            *fnRetAlloca);
    } else {
      switch (CIRGenFunction::getEvaluationKind(rv->getType())) {
      case cir::TEK_Scalar: {
        mlir::Value value = emitScalarExpr(rv);
        // Change this to an assert once emitScalarExpr is complete.
        if (value)
          builder.CIRBaseBuilderTy::createStore(loc, value, *fnRetAlloca);
        break;
      }
      case cir::TEK_Complex:
        emitComplexExprIntoLValue(rv,
                                  makeAddrLValue(returnValue, rv->getType()),
                                  /*isInit=*/true);
        break;
      case cir::TEK_Aggregate:
        assert(!cir::MissingFeatures::aggValueSlotGC());
        emitAggExpr(rv, AggValueSlot::forAddr(returnValue, Qualifiers(),
                                              AggValueSlot::IsDestructed,
                                              AggValueSlot::IsNotAliased,
                                              getOverlapForReturnValue()));
        break;
      }
    }
  };

  if (!createNewScope) {
    handleReturnVal();
  } else {
    mlir::Location scopeLoc =
        getLoc(rv ? rv->getSourceRange() : s.getSourceRange());
    // First create the cir.scope and only later emit its body. Otherwise any
    // CIRGen dispatched by `handleReturnVal()` that needs to manipulate
    // blocks or look into parent ops would see them still unlinked.
    mlir::OpBuilder::InsertPoint scopeBody;
    cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/
                         [&](mlir::OpBuilder &b, mlir::Location loc) {
                           scopeBody = b.saveInsertionPoint();
                         });
    {
      mlir::OpBuilder::InsertionGuard guard(builder);
      builder.restoreInsertionPoint(scopeBody);
      CIRGenFunction::LexicalScope lexScope{*this, scopeLoc,
                                            builder.getInsertionBlock()};
      handleReturnVal();
    }
  }

  cleanupScope.forceCleanup();

  // Classic codegen emits a branch through any cleanups before continuing to
  // a shared return block. Because CIR handles branching through cleanups
  // during the CFG flattening phase, we can just emit the return statement
  // directly.
  // TODO(cir): Eliminate this redundant load and the store above when we can.
  if (fnRetAlloca) {
    // Load the value from `__retval` and return it via the `cir.return` op.
    cir::AllocaOp retAlloca =
        mlir::cast<cir::AllocaOp>(fnRetAlloca->getDefiningOp());
    auto value = cir::LoadOp::create(builder, loc, retAlloca.getAllocaType(),
                                     *fnRetAlloca);

    cir::ReturnOp::create(builder, loc, {value});
  } else {
    cir::ReturnOp::create(builder, loc);
  }

  // Insert the new block to continue codegen after the return statement.
  // This will get deleted if we don't populate it. This handles the case of
  // unreachable statements below a return.
  builder.createBlock(builder.getBlock()->getParent());
  return mlir::success();
}

/// Emit a `goto label;` as a cir.goto that refers to the label by name, then
/// open a fresh block for whatever follows.
mlir::LogicalResult CIRGenFunction::emitGotoStmt(const clang::GotoStmt &s) {
  // Classic LLVM codegen emits a debug stop point here when an insertion
  // point is available, and does nothing special otherwise. Revisit once CIR
  // supports debug info.
  assert(!cir::MissingFeatures::generateDebugInfo());

  const mlir::Location gotoLoc = getLoc(s.getSourceRange());
  cir::GotoOp::create(builder, gotoLoc, s.getLabel()->getName());

  // The goto ends the current block. Start a new block in the same region so
  // that codegen after emitGotoStmt has somewhere to resume.
  mlir::Region *parentRegion = builder.getBlock()->getParent();
  builder.createBlock(parentRegion);
  return mlir::success();
}

/// Emit a computed `goto *expr;`.
///
/// The target address is evaluated and passed as the operand of a branch to
/// the function's shared indirect-goto block, which must already exist. A
/// fresh block is then opened for any statements that follow.
mlir::LogicalResult
CIRGenFunction::emitIndirectGotoStmt(const IndirectGotoStmt &s) {
  mlir::Value val = emitScalarExpr(s.getTarget());
  assert(indirectGotoBlock &&
         "indirect goto block must be emitted before jumping to it");
  cir::BrOp::create(builder, getLoc(s.getSourceRange()), indirectGotoBlock,
                    val);
  builder.createBlock(builder.getBlock()->getParent());
  return mlir::success();
}

/// Emit a `continue` statement, then open a new block so that emission can
/// proceed after it.
mlir::LogicalResult
CIRGenFunction::emitContinueStmt(const clang::ContinueStmt &s) {
  const mlir::Location kwLoc = getLoc(s.getKwLoc());
  builder.createContinue(kwLoc);

  // `continue` ends the current block; any following statements go into a
  // fresh block in the same region.
  mlir::Region *parentRegion = builder.getBlock()->getParent();
  builder.createBlock(parentRegion);
  return mlir::success();
}

/// Start a new labeled block for label \p d and leave the builder positioned
/// at the end of that block, right after the cir.label op.
///
/// The current block is reused when it is empty and is not the region's
/// entry block. Otherwise a new block is created and the current block
/// branches to it. The label is also registered with the module's block
/// address map.
mlir::LogicalResult CIRGenFunction::emitLabel(const clang::LabelDecl &d) {
  mlir::Location labelLoc = getLoc(d.getSourceRange());
  mlir::Block *currBlock = builder.getBlock();
  mlir::Block *labelBlock = currBlock;

  if (!currBlock->empty() || currBlock->isEntryBlock()) {
    {
      mlir::OpBuilder::InsertionGuard guard(builder);
      labelBlock = builder.createBlock(currBlock->getParent());
    }
    cir::BrOp::create(builder, labelLoc, labelBlock);
  }

  // Creating the label op at the end of labelBlock leaves the insertion point
  // at the end of that block, so subsequent ops follow the label.
  builder.setInsertionPointToEnd(labelBlock);
  cir::LabelOp label = cir::LabelOp::create(builder, labelLoc, d.getName());

  auto func = cast<cir::FuncOp>(curFn);
  cgm.mapBlockAddress(cir::BlockAddrInfoAttr::get(builder.getContext(),
                                                  func.getSymNameAttr(),
                                                  label.getLabelAttr()),
                      label);
  //  FIXME: emit debug info for labels, incrementProfileCounter
  assert(!cir::MissingFeatures::incrementProfileCounter());
  assert(!cir::MissingFeatures::generateDebugInfo());
  return mlir::success();
}

/// Emit a `break` statement, then open a new block so that emission can
/// proceed after it.
mlir::LogicalResult CIRGenFunction::emitBreakStmt(const clang::BreakStmt &s) {
  const mlir::Location kwLoc = getLoc(s.getKwLoc());
  builder.createBreak(kwLoc);

  // `break` ends the current block; any following statements go into a fresh
  // block in the same region.
  mlir::Region *parentRegion = builder.getBlock()->getParent();
  builder.createBlock(parentRegion);
  return mlir::success();
}

/// Emit one `case`/`default` label (`stmt`) as a `cir.case` operation.
///
/// \param stmt                 The CaseStmt or DefaultStmt being emitted.
/// \param condType             Type of the switch condition (used by the
///                             recursive case emission).
/// \param value                Case values attached to the CaseOp (empty for
///                             `default`).
/// \param kind                 Equal / Range / Default.
/// \param buildingTopLevelCase True when this label sits directly in the
///                             switch's compound body. In that case, if the
///                             label's sub-statement is an ordinary statement,
///                             the insertion point is left inside the new
///                             case region on return.
/// \returns failure if emitting the sub-statement (or a cascaded label)
///          failed.
template <typename T>
mlir::LogicalResult
CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
                                       mlir::ArrayAttr value, CaseOpKind kind,
                                       bool buildingTopLevelCase) {

  assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
         "only case or default stmt go here");

  mlir::LogicalResult result = mlir::success();

  mlir::Location loc = getLoc(stmt->getBeginLoc());

  // Classifies the label's sub-statement: another label that will be emitted
  // as a sibling CaseOp (Case/Default), or an ordinary statement (Other).
  enum class SubStmtKind { Case, Default, Other };
  SubStmtKind subStmtKind = SubStmtKind::Other;
  const Stmt *sub = stmt->getSubStmt();

  // CaseOp::create fills in `insertPoint`, which is then restored below to
  // emit into the new case region.
  mlir::OpBuilder::InsertPoint insertPoint;
  CaseOp::create(builder, loc, value, kind, insertPoint);

  {
    // On scope exit the guard restores the insertion point to where it was
    // before we entered the case region.
    mlir::OpBuilder::InsertionGuard guardSwitch(builder);
    builder.restoreInsertionPoint(insertPoint);

    if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
      // `case X: default:`: this case region only yields; the default label
      // is emitted after this block as its own CaseOp.
      subStmtKind = SubStmtKind::Default;
      builder.createYield(loc);
    } else if (isa<CaseStmt>(sub) && isa<DefaultStmt, CaseStmt>(stmt)) {
      // `case X: case Y:` or `default: case Y:`: same flattening for a nested
      // case label.
      subStmtKind = SubStmtKind::Case;
      builder.createYield(loc);
    } else {
      // Ordinary sub-statement: emit it into this case region. A compound
      // sub-statement gets its own scope; anything else uses the current one.
      result = emitStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
    }

    // Remember where emission inside the case region stopped so top-level
    // cases can resume there (see the end of this function).
    insertPoint = builder.saveInsertionPoint();
  }

  // If the substmt is default stmt or case stmt, try to handle the special case
  // to make it into the simple form. e.g.
  //
  //  switch () {
  //    case 1:
  //    default:
  //      ...
  //  }
  //
  // we prefer generating
  //
  //  cir.switch() {
  //     cir.case(equal, 1) {
  //        cir.yield
  //     }
  //     cir.case(default) {
  //        ...
  //     }
  //  }
  //
  // than
  //
  //  cir.switch() {
  //     cir.case(equal, 1) {
  //       cir.case(default) {
  //         ...
  //       }
  //     }
  //  }
  //
  // We don't need to revert this if we find the current switch can't be in
  // simple form later since the conversion itself should be harmless.
  if (subStmtKind == SubStmtKind::Case) {
    result = emitCaseStmt(*cast<CaseStmt>(sub), condType, buildingTopLevelCase);
  } else if (subStmtKind == SubStmtKind::Default) {
    result = emitDefaultStmt(*cast<DefaultStmt>(sub), condType,
                             buildingTopLevelCase);
  } else if (buildingTopLevelCase) {
    // If we're building a top level case, try to restore the insert point to
    // the case we're building, then we can attach more random stmts to the
    // case to make generating `cir.switch` operation to be a simple form.
    builder.restoreInsertionPoint(insertPoint);
  }

  return result;
}

/// Emit a `case` label as a `cir.case`.
///
/// A plain `case N:` becomes an Equal case with the single value N. The GNU
/// range extension `case LO ... HI:` (a CaseStmt with an RHS) becomes a Range
/// case carrying the values [LO, HI]. All values are materialized as
/// `cir::IntAttr` of `condType`.
mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
                                                 mlir::Type condType,
                                                 bool buildingTopLevelCase) {
  const ASTContext &astCtx = getContext();

  llvm::APSInt lowValue = s.getLHS()->EvaluateKnownConstInt(astCtx);
  mlir::Attribute lowAttr = cir::IntAttr::get(condType, lowValue);

  const Expr *rangeEnd = s.getRHS();
  if (!rangeEnd) {
    // Ordinary single-value case.
    return emitCaseDefaultCascade(&s, condType, builder.getArrayAttr({lowAttr}),
                                  cir::CaseOpKind::Equal, buildingTopLevelCase);
  }

  // GNU case range: LHS is the first value, RHS the last.
  llvm::APSInt highValue = rangeEnd->EvaluateKnownConstInt(astCtx);
  mlir::Attribute highAttr = cir::IntAttr::get(condType, highValue);
  return emitCaseDefaultCascade(&s, condType,
                                builder.getArrayAttr({lowAttr, highAttr}),
                                cir::CaseOpKind::Range, buildingTopLevelCase);
}

/// Emit a `default` label as a `cir.case` of kind Default, which carries an
/// empty value list.
mlir::LogicalResult CIRGenFunction::emitDefaultStmt(const clang::DefaultStmt &s,
                                                    mlir::Type condType,
                                                    bool buildingTopLevelCase) {
  mlir::ArrayAttr noValues = builder.getArrayAttr({});
  return emitCaseDefaultCascade(&s, condType, noValues,
                                cir::CaseOpKind::Default, buildingTopLevelCase);
}

/// Dispatch a SwitchCase to the case- or default-specific emitter.
///
/// Uses the innermost entry of `condTypeStack` as the condition type, so it
/// must only be called while a switch body is being emitted.
mlir::LogicalResult CIRGenFunction::emitSwitchCase(const SwitchCase &s,
                                                   bool buildingTopLevelCase) {
  assert(!condTypeStack.empty() &&
         "build switch case without specifying the type of the condition");

  if (const auto *caseLabel = dyn_cast<CaseStmt>(&s))
    return emitCaseStmt(*caseLabel, condTypeStack.back(),
                        buildingTopLevelCase);

  if (const auto *defaultLabel = dyn_cast<DefaultStmt>(&s))
    return emitDefaultStmt(*defaultLabel, condTypeStack.back(),
                           buildingTopLevelCase);

  llvm_unreachable("expect case or default stmt");
}

/// Emit a C++ range-based `for` statement.
///
/// The optional init-statement and the desugared range/begin/end declarations
/// are emitted once, ahead of the loop, inside a new `cir.scope`. A `cir.for`
/// then evaluates the desugared condition, declares the loop variable and runs
/// the body in the body region, and emits the increment in the step region.
///
/// \returns failure if any piece of the statement failed to emit.
mlir::LogicalResult
CIRGenFunction::emitCXXForRangeStmt(const CXXForRangeStmt &s,
                                    ArrayRef<const Attr *> forAttrs) {
  cir::ForOp forOp;

  // TODO(cir): pass in array of attributes.
  auto buildLoop = [&]() -> mlir::LogicalResult {
    mlir::LogicalResult status = mlir::success();

    // Pieces evaluated once, before the loop, in the enclosing scope.
    if (const Stmt *init = s.getInit())
      if (emitStmt(init, /*useCurrentScope=*/true).failed())
        return mlir::failure();
    const Stmt *rangeDecls[] = {s.getRangeStmt(), s.getBeginStmt(),
                                s.getEndStmt()};
    for (const Stmt *decl : rangeDecls)
      if (emitStmt(decl, /*useCurrentScope=*/true).failed())
        return mlir::failure();

    assert(!cir::MissingFeatures::loopInfoStack());
    // From LLVM: if there are any cleanups between here and the loop-exit
    // scope, create a block to stage a loop exit along.
    // We probably already do the right thing because of ScopeOp, but make
    // sure we handle all cases.
    assert(!cir::MissingFeatures::requiresCleanups());

    auto buildCond = [&](mlir::OpBuilder &b, mlir::Location loc) {
      assert(!cir::MissingFeatures::createProfileWeightsForLoop());
      assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
      mlir::Value keepGoing = evaluateExprAsBool(s.getCond());
      builder.createCondition(keepGoing);
    };

    auto buildBody = [&](mlir::OpBuilder &b, mlir::Location loc) {
      // https://en.cppreference.com/w/cpp/language/range-for
      // The loop variable declaration and the loop statement share a single
      // block scope, so both go into the current scope.
      if (emitStmt(s.getLoopVarStmt(), /*useCurrentScope=*/true).failed())
        status = mlir::failure();
      if (emitStmt(s.getBody(), /*useCurrentScope=*/true).failed())
        status = mlir::failure();
      emitStopPoint(&s);
    };

    auto buildStep = [&](mlir::OpBuilder &b, mlir::Location loc) {
      if (s.getInc() &&
          emitStmt(s.getInc(), /*useCurrentScope=*/true).failed())
        status = mlir::failure();
      builder.createYield(loc);
    };

    forOp = builder.createFor(getLoc(s.getSourceRange()), buildCond, buildBody,
                              buildStep);
    return status;
  };

  mlir::LogicalResult res = mlir::success();
  cir::ScopeOp::create(builder, getLoc(s.getSourceRange()), /*scopeBuilder=*/
                       [&](mlir::OpBuilder &b, mlir::Location loc) {
                         // Cleanup scope for the loop's declarations; the
                         // counterpart of classic codegen's
                         // LexicalScope ConditionScope(*this, ...).
                         LexicalScope lexScope{*this, loc,
                                               builder.getInsertionBlock()};
                         res = buildLoop();
                       });

  if (mlir::succeeded(res))
    terminateBody(builder, forOp.getBody(), getLoc(s.getEndLoc()));
  return res;
}

/// Emit a C-style `for` statement as a `cir.for` nested in a `cir.scope`.
///
/// The init-statement is emitted once, inside the scope. The condition region
/// declares the optional condition variable and tests the condition; a missing
/// condition is a constant `true`. The body is emitted in its own nested
/// scope, and the step region runs the increment expression.
///
/// \returns failure if any piece of the statement failed to emit.
mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
  cir::ForOp forOp;

  // TODO: pass in an array of attributes.
  auto buildLoop = [&]() -> mlir::LogicalResult {
    mlir::LogicalResult status = mlir::success();

    if (const Stmt *init = s.getInit())
      if (emitStmt(init, /*useCurrentScope=*/true).failed())
        return mlir::failure();

    assert(!cir::MissingFeatures::loopInfoStack());
    // Classic codegen stages the loop exit in a dedicated block when cleanups
    // lie between here and the loop-exit scope. The enclosing ScopeOp is
    // expected to cover this, but more testing is needed to be sure.
    assert(!cir::MissingFeatures::requiresCleanups());

    auto buildCond = [&](mlir::OpBuilder &b, mlir::Location loc) {
      assert(!cir::MissingFeatures::createProfileWeightsForLoop());
      assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
      mlir::Value keepGoing;
      if (auto *condExpr = s.getCond()) {
        // A condition variable is declared inside the condition region,
        // ahead of the test.
        if (auto *condVar = s.getConditionVariable())
          emitDecl(*condVar);
        // C99 6.8.5p2/p4: the body runs while the scalar condition compares
        // unequal to 0.
        keepGoing = evaluateExprAsBool(condExpr);
      } else {
        // `for (;;)`: no condition means "always true".
        keepGoing = cir::ConstantOp::create(b, loc, builder.getTrueAttr());
      }
      builder.createCondition(keepGoing);
    };

    auto buildBody = [&](mlir::OpBuilder &b, mlir::Location loc) {
      // The body's scope is nested inside the scope of the init-statement
      // and condition.
      if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
        status = mlir::failure();
      emitStopPoint(&s);
    };

    auto buildStep = [&](mlir::OpBuilder &b, mlir::Location loc) {
      if (s.getInc() &&
          emitStmt(s.getInc(), /*useCurrentScope=*/true).failed())
        status = mlir::failure();
      builder.createYield(loc);
    };

    forOp = builder.createFor(getLoc(s.getSourceRange()), buildCond, buildBody,
                              buildStep);
    return status;
  };

  mlir::LogicalResult res = mlir::success();
  cir::ScopeOp::create(builder, getLoc(s.getSourceRange()), /*scopeBuilder=*/
                       [&](mlir::OpBuilder &b, mlir::Location loc) {
                         LexicalScope lexScope{*this, loc,
                                               builder.getInsertionBlock()};
                         res = buildLoop();
                       });

  if (mlir::succeeded(res))
    terminateBody(builder, forOp.getBody(), getLoc(s.getEndLoc()));
  return res;
}

/// Emit a `do ... while (cond);` statement as a `cir.do` nested in a
/// `cir.scope`. The body is emitted in its own nested scope, and the
/// condition is evaluated in the condition region.
///
/// \returns failure if the body failed to emit.
mlir::LogicalResult CIRGenFunction::emitDoStmt(const DoStmt &s) {
  cir::DoWhileOp doWhileOp;

  // TODO: pass in array of attributes.
  auto buildLoop = [&]() -> mlir::LogicalResult {
    mlir::LogicalResult status = mlir::success();
    assert(!cir::MissingFeatures::loopInfoStack());
    // From LLVM: if there are any cleanups between here and the loop-exit
    // scope, create a block to stage a loop exit along.
    // We probably already do the right thing because of ScopeOp, but make
    // sure we handle all cases.
    assert(!cir::MissingFeatures::requiresCleanups());

    auto buildCond = [&](mlir::OpBuilder &b, mlir::Location loc) {
      assert(!cir::MissingFeatures::createProfileWeightsForLoop());
      assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
      // C99 6.8.5p2/p4: the loop repeats while the scalar condition compares
      // unequal to 0.
      mlir::Value keepGoing = evaluateExprAsBool(s.getCond());
      builder.createCondition(keepGoing);
    };

    auto buildBody = [&](mlir::OpBuilder &b, mlir::Location loc) {
      // The loop body gets a nested scope of its own.
      if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
        status = mlir::failure();
      emitStopPoint(&s);
    };

    doWhileOp =
        builder.createDoWhile(getLoc(s.getSourceRange()), buildCond, buildBody);
    return status;
  };

  mlir::LogicalResult res = mlir::success();
  cir::ScopeOp::create(builder, getLoc(s.getSourceRange()), /*scopeBuilder=*/
                       [&](mlir::OpBuilder &b, mlir::Location loc) {
                         LexicalScope lexScope{*this, loc,
                                               builder.getInsertionBlock()};
                         res = buildLoop();
                       });

  if (mlir::succeeded(res))
    terminateBody(builder, doWhileOp.getBody(), getLoc(s.getEndLoc()));
  return res;
}

/// Emit a `while (cond) body` statement as a `cir.while` nested in a
/// `cir.scope`.
///
/// The condition region declares the optional condition variable before
/// testing the condition. The body is emitted in its own nested scope.
///
/// \returns failure if the body failed to emit.
mlir::LogicalResult CIRGenFunction::emitWhileStmt(const WhileStmt &s) {
  cir::WhileOp whileOp;

  // TODO: pass in array of attributes.
  auto buildLoop = [&]() -> mlir::LogicalResult {
    mlir::LogicalResult status = mlir::success();
    assert(!cir::MissingFeatures::loopInfoStack());
    // From LLVM: if there are any cleanups between here and the loop-exit
    // scope, create a block to stage a loop exit along.
    // We probably already do the right thing because of ScopeOp, but make
    // sure we handle all cases.
    assert(!cir::MissingFeatures::requiresCleanups());

    auto buildCond = [&](mlir::OpBuilder &b, mlir::Location loc) {
      assert(!cir::MissingFeatures::createProfileWeightsForLoop());
      assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
      // `while (T x = ...)`: the condition variable is declared inside the
      // condition region, ahead of the test.
      if (auto *condVar = s.getConditionVariable())
        emitDecl(*condVar);
      // C99 6.8.5p2/p4: the body runs while the scalar condition compares
      // unequal to 0.
      mlir::Value keepGoing = evaluateExprAsBool(s.getCond());
      builder.createCondition(keepGoing);
    };

    auto buildBody = [&](mlir::OpBuilder &b, mlir::Location loc) {
      // The loop body gets a nested scope of its own.
      if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
        status = mlir::failure();
      emitStopPoint(&s);
    };

    whileOp =
        builder.createWhile(getLoc(s.getSourceRange()), buildCond, buildBody);
    return status;
  };

  mlir::LogicalResult res = mlir::success();
  cir::ScopeOp::create(builder, getLoc(s.getSourceRange()), /*scopeBuilder=*/
                       [&](mlir::OpBuilder &b, mlir::Location loc) {
                         LexicalScope lexScope{*this, loc,
                                               builder.getInsertionBlock()};
                         res = buildLoop();
                       });

  if (mlir::succeeded(res))
    terminateBody(builder, whileOp.getBody(), getLoc(s.getEndLoc()));
  return res;
}

/// Emit the body of a switch statement into the current (switch) region.
///
/// A non-compound body is rare but legal, for example:
///
///  switch(a)
///    while(...) {
///      case1
///      ...
///      case2
///      ...
///    }
///
/// Such a body is emitted as an ordinary statement in the current scope.
///
/// For a compound body, each top-level case/default label is emitted at the
/// end of the switch block as a new case op. Statements between labels are
/// emitted at whatever insertion point the previous label left behind.
mlir::LogicalResult CIRGenFunction::emitSwitchBody(const Stmt *s) {
  const auto *compound = dyn_cast<CompoundStmt>(s);
  if (!compound)
    return emitStmt(s, /*useCurrentScope=*/true);

  mlir::Block *switchBlock = builder.getBlock();
  for (const Stmt *child : compound->body()) {
    const auto *label = dyn_cast<SwitchCase>(child);
    if (!label) {
      // Not a label: append to the region of the most recently built case.
      // Compound statements get their own scope.
      if (mlir::failed(
              emitStmt(child, /*useCurrentScope=*/!isa<CompoundStmt>(child))))
        return mlir::failure();
      continue;
    }

    // New top-level label: start from the end of the switch block. With
    // buildingTopLevelCase=true the label emitter may leave the insertion
    // point inside the new case, so following statements attach to it and
    // the resulting `cir.switch` stays in simple form.
    builder.setInsertionPointToEnd(switchBlock);
    if (mlir::failed(emitSwitchCase(*label, /*buildingTopLevelCase=*/true)))
      return mlir::failure();
  }

  return mlir::success();
}

/// Emit a `switch` statement as a `cir.switch` nested in a `cir.scope`.
///
/// Inside the scope, this emits the optional init-statement, the optional
/// condition variable and the condition value, then builds the `cir.switch`.
/// While the switch body is emitted, the condition type is pushed on
/// `condTypeStack` so case labels can build their values with it. On success,
/// every collected case region and the switch body are terminated, and the
/// op records whether all enum cases are covered.
///
/// \returns failure if any part of the statement failed to emit. The switch
///          op is left untouched in that case; it may not exist at all if the
///          init-statement failed.
mlir::LogicalResult CIRGenFunction::emitSwitchStmt(const clang::SwitchStmt &s) {
  // TODO: LLVM codegen does some early optimization to fold the condition and
  // only emit live cases. CIR should use MLIR to achieve similar things,
  // nothing to be done here.
  // if (ConstantFoldsToSimpleInteger(S.getCond(), ConstantCondValue))...
  assert(!cir::MissingFeatures::constantFoldSwitchStatement());

  // Stays null if switchStmtBuilder bails out before creating the op.
  SwitchOp swop;
  auto switchStmtBuilder = [&]() -> mlir::LogicalResult {
    if (s.getInit())
      if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed())
        return mlir::failure();

    if (s.getConditionVariable())
      emitDecl(*s.getConditionVariable(), /*evaluateConditionDecl=*/true);

    mlir::Value condV = emitScalarExpr(s.getCond());

    // TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
    assert(!cir::MissingFeatures::pgoUse());
    assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
    // TODO: if the switch has a condition wrapped by __builtin_unpredictable?
    assert(!cir::MissingFeatures::insertBuiltinUnpredictable());

    mlir::LogicalResult res = mlir::success();
    swop = SwitchOp::create(
        builder, getLoc(s.getBeginLoc()), condV,
        /*switchBuilder=*/
        [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
          curLexScope->setAsSwitch();

          // Case labels read the innermost condition type from this stack.
          condTypeStack.push_back(condV.getType());

          res = emitSwitchBody(s.getBody());

          condTypeStack.pop_back();
        });

    return res;
  };

  // The switch scope contains the full source range for SwitchStmt.
  mlir::Location scopeLoc = getLoc(s.getSourceRange());
  mlir::LogicalResult res = mlir::success();
  cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/
                       [&](mlir::OpBuilder &b, mlir::Location loc) {
                         LexicalScope lexScope{*this, loc,
                                               builder.getInsertionBlock()};
                         res = switchStmtBuilder();
                       });

  // If the init-statement failed, `swop` was never created and is null, so
  // touching it below would dereference a null operation. Bail out on any
  // failure, as the loop emitters do.
  if (res.failed())
    return res;

  llvm::SmallVector<CaseOp> cases;
  swop.collectCases(cases);
  for (auto caseOp : cases)
    terminateBody(builder, caseOp.getCaseRegion(), caseOp.getLoc());
  terminateBody(builder, swop.getBody(), swop.getLoc());

  swop.setAllEnumCasesCovered(s.isAllEnumCasesCovered());

  return mlir::success();
}

/// Store `rv` into the function's return slot (`returnValue`) and emit a
/// `cir.return` of the value reloaded from the return alloca.
///
/// Scalars are stored directly. Aggregates are copied with
/// `emitAggregateCopy`. Complex values are not implemented yet and report an
/// NYI error, but the reload and return are still emitted.
void CIRGenFunction::emitReturnOfRValue(mlir::Location loc, RValue rv,
                                        QualType ty) {
  if (rv.isAggregate()) {
    LValue destLV = makeAddrLValue(returnValue, ty);
    LValue srcLV = makeAddrLValue(rv.getAggregateAddress(), ty);
    emitAggregateCopy(destLV, srcLV, ty, getOverlapForReturnValue());
  } else if (rv.isScalar()) {
    builder.createStore(loc, rv.getValue(), returnValue);
  } else {
    cgm.errorNYI(loc, "emitReturnOfRValue: complex return type");
  }

  // Classic codegen branches through any cleanups to a shared return block.
  // CIR resolves branches through cleanups during CFG flattening, so the
  // return can be emitted right here.
  // TODO(cir): Eliminate this redundant load and the store above when we can.
  mlir::Value retSlot = *fnRetAlloca;
  auto retAllocaOp = mlir::cast<cir::AllocaOp>(retSlot.getDefiningOp());
  auto reloaded = cir::LoadOp::create(builder, loc,
                                      retAllocaOp.getAllocaType(), retSlot);
  cir::ReturnOp::create(builder, loc, {reloaded});
}
