summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-20 14:42:50 -0800
committerGitHub <noreply@github.com>2023-02-20 14:42:50 -0800
commit47715e625337d489f3c0131bbc2b849378b48a5a (patch)
treebc737c8f03ef537b2ac39860bbb922c7600edc43 /source
parent8b05df4187117d61491f2fdbeb7d744146ad73f7 (diff)
Miscellaneous backward autodiff fixes. (#2665)
* Fix differentiable type registration * Fix use of non-differentiable return value in a differentiable func. * Fix use of primal inst that does not dominate the diff block. * Fix primal inst hoisting, and add missing type legalization logic. * Make `detach` defined on all differentiable T. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang24
-rw-r--r--source/slang/hlsl.meta.slang20
-rw-r--r--source/slang/slang-check-expr.cpp17
-rw-r--r--source/slang/slang-check-stmt.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp8
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp118
-rw-r--r--source/slang/slang-ir-autodiff-rev.h2
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp55
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h34
-rw-r--r--source/slang/slang-ir-autodiff.cpp24
-rw-r--r--source/slang/slang-ir-autodiff.h2
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp11
-rw-r--r--source/slang/slang-ir-legalize-types.cpp125
-rw-r--r--source/slang/slang-ir-util.cpp30
-rw-r--r--source/slang/slang-ir-util.h6
-rw-r--r--source/slang/slang-ir.cpp14
16 files changed, 366 insertions, 135 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index a60a77cc3..859b8a488 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -319,7 +319,13 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
// Detach and set derivatives to zero
-__generic<T : __BuiltinFloatingPointType>
+__generic<T : IDifferentiable>
+T detach(T x)
+{
+ return x;
+}
+
+__generic<T : IDifferentiable>
[ForwardDerivativeOf(detach)]
DifferentialPair<T> __d_detach(DifferentialPair<T> dpx)
{
@@ -329,27 +335,13 @@ DifferentialPair<T> __d_detach(DifferentialPair<T> dpx)
);
}
-__generic<T : __BuiltinFloatingPointType, let N : int>
-[ForwardDerivativeOf(detach)]
-DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> dpx)
-{
- VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx);
-}
-
-__generic<T : __BuiltinFloatingPointType>
+__generic<T : IDifferentiable>
[BackwardDerivativeOf(detach)]
void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut)
{
dpx = diffPair(dpx.p, T.dzero());
}
-__generic<T : __BuiltinFloatingPointType, let N : int>
-[BackwardDerivativeOf(detach)]
-void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
-{
- dpx = diffPair(dpx.p, vector<T, N>.dzero());
-}
-
// Natural Exponent
__generic<T : __BuiltinFloatingPointType>
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 1d2b327d2..7e75d06b3 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -770,26 +770,6 @@ struct TriangleStream
// Try to terminate the current draw or dispatch call (HLSL SM 4.0)
void abort();
-// Detach and set derivatives to zero
-
-__generic<T : __BuiltinFloatingPointType>
-T detach(T x)
-{
- return x;
-}
-
-__generic<T : __BuiltinFloatingPointType, let N : int>
-vector<T, N> detach(vector<T, N> x)
-{
- return x;
-}
-
-__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
-matrix<T, N, M> detach(matrix<T, N, M> x)
-{
- return x;
-}
-
// Absolute value (HLSL SM 1.0)
__generic<T : __BuiltinIntegerType>
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 75a3c2ff1..3567e2593 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1024,6 +1024,23 @@ namespace Slang
});
}
}
+ for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst = subst->outer)
+ {
+ if (auto genSubst = as<GenericSubstitution>(subst))
+ {
+ for (auto arg : genSubst->getArgs())
+ {
+ if (auto typeArg = as<Type>(arg))
+ {
+ maybeRegisterDifferentiableTypeRecursive(m_astBuilder, typeArg, workingSet);
+ }
+ }
+ }
+ else if (auto thisSubst = as<ThisTypeSubstitution>(subst))
+ {
+ maybeRegisterDifferentiableTypeRecursive(m_astBuilder, thisSubst->witness->sub, workingSet);
+ }
+ }
return;
}
}
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 519ca91ff..bc89dc94e 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -585,6 +585,7 @@ namespace Slang
void SemanticsStmtVisitor::checkLoopInDifferentiableFunc(Stmt* stmt)
{
+ SLANG_UNUSED(stmt);
if (getParentDifferentiableAttribute())
{
if (!getParentFunc())
@@ -601,16 +602,6 @@ namespace Slang
return;
if (getParentFunc()->findModifier<BackwardDerivativeAttribute>())
return;
-
- // For all ordinary differentiable functions, we require either a `[MaxIters]` attribute,
- // or a `[ForceUnroll]` attribet on loops.
- if (stmt->hasModifier<MaxItersAttribute>() || stmt->hasModifier<ForceUnrollAttribute>() || stmt->hasModifier<InferredMaxItersAttribute>())
- {
- }
- else
- {
- getSink()->diagnose(stmt, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters);
- }
}
}
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 26d84720f..564c33268 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -488,7 +488,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (!diffReturnType)
{
- diffReturnType = argBuilder.getVoidType();
+ diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());
}
auto callInst = argBuilder.emitCallInst(
@@ -501,7 +501,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
*builder = afterBuilder;
- if (diffReturnType->getOp() != kIROp_VoidType)
+ if (diffReturnType->getOp() == kIROp_DifferentialPairType)
{
IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst);
auto diffType = differentiateType(&afterBuilder, origCall->getFullType());
@@ -510,8 +510,8 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
else
{
- // Return the inst itself if the return value is void.
- // This is fine since these values should never actually be used anywhere.
+ // Return the inst itself if the return value is non-differentiable.
+ // This is fine since these values should only be used by non-differentiable code.
//
return InstPair(callInst, callInst);
}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index ef6178976..55c0ee46d 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -13,6 +13,7 @@
#include "slang-ir-eliminate-multilevel-break.h"
#include "slang-ir-init-local-var.h"
#include "slang-ir-redundancy-removal.h"
+#include "slang-ir-dominators.h"
namespace Slang
{
@@ -674,6 +675,106 @@ namespace Slang
return fwdDiffFunc;
}
+ void BackwardDiffTranscriberBase::insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc)
+ {
+ RefPtr<IRDominatorTree> domTree = computeDominatorTree(diffPropFunc);
+ auto firstBlock = diffPropFunc->getFirstBlock();
+ if (!firstBlock)
+ return;
+ Dictionary<IRInst*, IRVar*> instVars;
+ Dictionary<IRBlock*, IRCloneEnv> cloneEnvs;
+ auto storeInstAsLocalVar = [&](IRInst* inst)
+ {
+ IRVar* var = nullptr;
+ if (instVars.TryGetValue(inst, var))
+ return var;
+ IRBuilder builder(diffPropFunc);
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+ var = builder.emitVar(inst->getDataType());
+ builder.emitStore(var, builder.emitDefaultConstruct(inst->getDataType()));
+
+ setInsertAfterOrdinaryInst(&builder, inst);
+ builder.emitStore(var, inst);
+ instVars[inst] = var;
+ return var;
+ };
+
+ IRBuilder builder(diffPropFunc);
+ List<IRInst*> workList;
+ for (auto block : diffPropFunc->getBlocks())
+ {
+ if (!block->findDecoration<IRDifferentialInstDecoration>())
+ continue;
+ cloneEnvs[block] = IRCloneEnv();
+ for (auto inst : block->getChildren())
+ {
+ workList.add(inst);
+ }
+ }
+
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto inst = workList[i];
+ for (UInt j = 0; j < inst->getOperandCount(); j++)
+ {
+ auto operand = inst->getOperand(j);
+ if (operand->getOp() == kIROp_Block)
+ continue;
+ auto operandParent = inst->getOperand(j)->getParent();
+ if (!operandParent)
+ continue;
+ if (operandParent->parent != diffPropFunc)
+ continue;
+ if (domTree->dominates(operandParent, inst->parent))
+ continue;
+
+ // The def site of the operand does not dominate the use.
+ // We need to insert a local variable to store this var.
+
+ IRInst* operandReplacement = nullptr;
+ if (canInstBeStored(operand))
+ {
+ auto var = storeInstAsLocalVar(operand);
+ builder.setInsertBefore(inst);
+ operandReplacement = builder.emitLoad(var);
+ }
+ else if (operand->getOp() == kIROp_Var)
+ {
+ // Var can just be hoisted to first block.
+ operand->insertBefore(firstBlock->getFirstOrdinaryInst());
+ }
+ else
+ {
+ // For all other insts, we need to copy it to right before this inst.
+ // Before actually copying it, check if we have already copied it to
+ // any blocks that dominates this block.
+ auto dom = as<IRBlock>(inst->getParent());
+ while (dom)
+ {
+ auto subCloneEnv = cloneEnvs.TryGetValue(dom);
+ if (!subCloneEnv) break;
+ if (subCloneEnv->mapOldValToNew.TryGetValue(operand, operandReplacement))
+ {
+ break;
+ }
+ dom = domTree->getImmediateDominator(dom);
+ }
+ // We have not found an existing clone in dominators, so we need to copy it
+ // to this block.
+ if (!operandReplacement)
+ {
+ auto subCloneEnv = cloneEnvs.TryGetValue(as<IRBlock>(inst->getParent()));
+ builder.setInsertBefore(inst);
+ operandReplacement = cloneInst(subCloneEnv, &builder, operand);
+ workList.add(operandReplacement);
+ }
+ }
+ if (operandReplacement)
+ builder.replaceOperand(inst->getOperands() + j, operandReplacement);
+ }
+ }
+ }
+
InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
{
SLANG_UNUSED(primalType);
@@ -838,6 +939,8 @@ namespace Slang
initializeLocalVariables(builder->getModule(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric)));
initializeLocalVariables(builder->getModule(), diffPropagateFunc);
+ insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
+ stripTempDecorations(diffPropagateFunc);
}
ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock(
@@ -1151,16 +1254,15 @@ namespace Slang
while (refUse)
{
auto nextUse = refUse->nextUse;
- switch (refUse->getUser()->getOp())
+ // Is this use the dest operand of a store inst?
+ // If so, replace it with writeRefReplacement, otherwise, refReplacement.
+ if (refUse->getUser()->getOp() == kIROp_Store && refUse == refUse->getUser()->getOperands())
{
- case kIROp_Load:
- refUse->set(diffRefReplacement);
- break;
- case kIROp_Store:
+ SLANG_RELEASE_ASSERT(diffWriteRefReplacement);
refUse->set(diffWriteRefReplacement);
- break;
- default:
- SLANG_RELEASE_ASSERT(!diffWriteRefReplacement);
+ }
+ else
+ {
refUse->set(diffRefReplacement);
}
refUse = nextUse;
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index a638b873c..94bc1ef81 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -114,6 +114,8 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc);
+ void insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc);
+
void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc);
InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 614559c9f..4e7539b48 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -137,39 +137,8 @@ struct ExtractPrimalFuncContext
return false;
}
- // Only store allowed types.
- if (isScalarIntegerType(inst->getDataType()))
- {
- }
- else if (as<IRResourceTypeBase>(inst->getDataType()))
- {
- }
- else
- {
- switch (inst->getDataType()->getOp())
- {
- case kIROp_StructType:
- case kIROp_OptionalType:
- case kIROp_TupleType:
- case kIROp_ArrayType:
- case kIROp_DifferentialPairType:
- case kIROp_InterfaceType:
- case kIROp_AnyValueType:
- case kIROp_ClassType:
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- case kIROp_VectorType:
- case kIROp_MatrixType:
- case kIROp_BoolType:
- case kIROp_Param:
- case kIROp_Specialize:
- case kIROp_LookupWitness:
- break;
- default:
- return false;
- }
- }
+ if (!canInstBeStored(inst))
+ return false;
// Never store certain opcodes.
switch (inst->getOp())
@@ -507,11 +476,19 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
else
{
// Orindary value.
- auto val = builder.emitFieldExtract(
- inst->getFullType(),
- intermediateVar,
- structKeyDecor->getStructKey());
- inst->replaceUsesWith(val);
+ // We insert a fieldExtract at each use site instead of before `inst`,
+ // since at this stage of autodiff pass, `inst` does not necessarily
+ // dominate all the use sites if `inst` is defined in partial branch
+ // in a primal block.
+ while (auto iuse = inst->firstUse)
+ {
+ builder.setInsertBefore(iuse->getUser());
+ auto val = builder.emitFieldExtract(
+ inst->getFullType(),
+ intermediateVar,
+ structKeyDecor->getStructKey());
+ iuse->set(val);
+ }
}
instsToRemove.add(inst);
}
@@ -529,8 +506,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
{
inst->removeAndDeallocate();
}
-
- stripTempDecorations(func);
return primalFunc;
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 2d5261b63..944df2c81 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -498,34 +498,6 @@ struct DiffUnzipPass
}
}
- void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst)
- {
- if (as<IRParam>(inst))
- {
- SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent()));
- auto lastParam = as<IRBlock>(inst->getParent())->getLastParam();
- builder->setInsertAfter(lastParam);
- }
- else
- {
- builder->setInsertBefore(inst);
- }
- }
-
- void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst)
- {
- if (as<IRParam>(inst))
- {
- SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent()));
- auto lastParam = as<IRBlock>(inst->getParent())->getLastParam();
- builder->setInsertAfter(lastParam);
- }
- else
- {
- builder->setInsertAfter(inst);
- }
- }
-
void processIndexedFwdBlock(IRBlock* fwdBlock)
{
if (!isBlockIndexed(fwdBlock))
@@ -794,6 +766,7 @@ struct DiffUnzipPass
auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs);
primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar);
+ primalBuilder->markInstAsPrimal(primalVal);
SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount());
@@ -1377,6 +1350,11 @@ struct DiffUnzipPass
// Remove insts that were split.
for (auto inst : splitInsts)
{
+ if (!isDifferentiableType(diffTypeContext, inst->getDataType()))
+ {
+ inst->replaceUsesWith(lookupPrimalInst(inst));
+ }
+
// Consistency check.
for (auto use = inst->firstUse; use; use = use->nextUse)
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 97cdb644e..b630b798d 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -563,6 +563,30 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst*
return false;
}
+bool canInstBeStored(IRInst* inst)
+{
+ if (as<IRBasicType>(inst->getDataType()))
+ return true;
+
+ switch (inst->getDataType()->getOp())
+ {
+ case kIROp_StructType:
+ case kIROp_OptionalType:
+ case kIROp_TupleType:
+ case kIROp_ArrayType:
+ case kIROp_DifferentialPairType:
+ case kIROp_InterfaceType:
+ case kIROp_AnyValueType:
+ case kIROp_ClassType:
+ case kIROp_FloatType:
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ return true;
+ default:
+ return false;
+ }
+}
+
struct AutoDiffPass : public InstPassBase
{
DiagnosticSink* getSink()
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index f57fb2974..fa01d50ae 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -303,6 +303,8 @@ bool isBackwardDifferentiableFunc(IRInst* func);
bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst);
+bool canInstBeStored(IRInst* inst);
+
inline bool isRelevantDifferentialPair(IRType* type)
{
if (as<IRDifferentialPairType>(type))
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index c750b2d3d..1ee94e67e 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -353,6 +353,17 @@ public:
auto loop = as<IRLoop>(block->getTerminator());
if (!loop)
continue;
+ bool hasBackEdge = false;
+ for (auto use = loop->getTargetBlock()->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser() != loop)
+ {
+ hasBackEdge = true;
+ break;
+ }
+ }
+ if (!hasBackEdge)
+ continue;
if (loop->findDecoration<IRLoopMaxItersDecoration>() || loop->findDecoration<IRForceUnrollDecoration>())
{
// We are good.
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 96ffb9bb2..7660c9526 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -701,6 +701,61 @@ static LegalVal legalizeRetVal(
return LegalVal();
}
+static void _addVal(List<IRInst*>& rs, const LegalVal& legalVal)
+{
+ switch (legalVal.flavor)
+ {
+ case LegalVal::Flavor::simple:
+ rs.add(legalVal.getSimple());
+ break;
+ case LegalVal::Flavor::tuple:
+ for (auto element : legalVal.getTuple()->elements)
+ _addVal(rs, element.val);
+ break;
+ case LegalVal::Flavor::pair:
+ _addVal(rs, legalVal.getPair()->ordinaryVal);
+ _addVal(rs, legalVal.getPair()->specialVal);
+ break;
+ case LegalVal::Flavor::none:
+ break;
+ default:
+ SLANG_UNEXPECTED("unhandled legalized val flavor");
+ }
+}
+
+static LegalVal legalizeUnconditionalBranch(
+ IRTypeLegalizationContext* context,
+ ArrayView<LegalVal> args,
+ IRUnconditionalBranch* branchInst)
+{
+ List<IRInst*> newArgs;
+ for (auto arg : args)
+ {
+ switch (arg.flavor)
+ {
+ case LegalVal::Flavor::none:
+ break;
+ case LegalVal::Flavor::simple:
+ newArgs.add(arg.getSimple());
+ break;
+ case LegalVal::Flavor::pair:
+ _addVal(newArgs, arg.getPair()->ordinaryVal);
+ _addVal(newArgs, arg.getPair()->specialVal);
+ break;
+ case LegalVal::Flavor::tuple:
+ for (auto element : arg.getTuple()->elements)
+ {
+ _addVal(newArgs, element.val);
+ }
+ break;
+ default:
+ SLANG_UNIMPLEMENTED_X("Unknown legalized val flavor.");
+ }
+ }
+ context->builder->emitBranch(branchInst->getTargetBlock(), newArgs.getCount() - 1, newArgs.getBuffer() + 1);
+ return LegalVal();
+}
+
static LegalVal legalizeLoad(
IRTypeLegalizationContext* context,
LegalVal legalPtrVal)
@@ -1610,11 +1665,69 @@ static LegalVal legalizeMakeStruct(
}
}
+static LegalVal legalizeDefaultConstruct(
+ IRTypeLegalizationContext* context,
+ LegalType legalType)
+{
+ auto builder = context->builder;
+
+ switch (legalType.flavor)
+ {
+ case LegalType::Flavor::none:
+ return LegalVal();
+
+ case LegalType::Flavor::simple:
+ {
+ return LegalVal::simple(
+ builder->emitDefaultConstruct(legalType.getSimple()));
+ }
+
+ case LegalType::Flavor::pair:
+ {
+ auto pairType = legalType.getPair();
+ auto pairInfo = pairType->pairInfo;
+ LegalType ordinaryType = pairType->ordinaryType;
+ LegalType specialType = pairType->specialType;
+
+ LegalVal ordinaryVal = legalizeDefaultConstruct(
+ context,
+ ordinaryType);
+
+ LegalVal specialVal = legalizeDefaultConstruct(
+ context,
+ specialType);
+
+ return LegalVal::pair(ordinaryVal, specialVal, pairInfo);
+ }
+ break;
+
+ case LegalType::Flavor::tuple:
+ {
+ auto tupleType = legalType.getTuple();
+
+ RefPtr<TuplePseudoVal> resTupleInfo = new TuplePseudoVal();
+ for (auto typeElem : tupleType->elements)
+ {
+ auto elemKey = typeElem.key;
+ TuplePseudoVal::Element resElem;
+ resElem.key = elemKey;
+ resElem.val = legalizeDefaultConstruct(context, typeElem.type);
+ resTupleInfo->elements.add(resElem);
+ }
+ return LegalVal::tuple(resTupleInfo);
+ }
+
+ default:
+ SLANG_UNEXPECTED("unhandled");
+ UNREACHABLE_RETURN(LegalVal());
+ }
+}
+
static LegalVal legalizeInst(
IRTypeLegalizationContext* context,
IRInst* inst,
LegalType type,
- LegalVal const* args)
+ ArrayView<LegalVal> args)
{
switch (inst->getOp())
{
@@ -1647,8 +1760,14 @@ static LegalVal legalizeInst(
return legalizeMakeStruct(
context,
type,
- args,
+ args.getBuffer(),
inst->getOperandCount());
+ case kIROp_DefaultConstruct:
+ return legalizeDefaultConstruct(
+ context,
+ type);
+ case kIROp_unconditionalBranch:
+ return legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst);
case kIROp_undefined:
return LegalVal();
case kIROp_GpuForeach:
@@ -1896,7 +2015,7 @@ static LegalVal legalizeInst(
context,
inst,
legalType,
- legalArgs.getBuffer());
+ legalArgs.getArrayView());
if (legalVal.flavor == LegalVal::Flavor::simple)
{
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index f3c4c2c82..3db036a8d 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -224,7 +224,7 @@ String dumpIRToString(IRInst* root)
StringBuilder sb;
StringWriter writer(&sb, Slang::WriterFlag::AutoFlush);
IRDumpOptions options = {};
-#if 0
+#if 1
options.flags = IRDumpOptions::Flag::DumpDebugIds;
#endif
dumpIR(root, options, nullptr, &writer);
@@ -487,6 +487,34 @@ IROp getSwapSideComparisonOp(IROp op)
}
}
+void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst)
+{
+ if (as<IRParam>(inst))
+ {
+ SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent()));
+ auto lastParam = as<IRBlock>(inst->getParent())->getLastParam();
+ builder->setInsertAfter(lastParam);
+ }
+ else
+ {
+ builder->setInsertBefore(inst);
+ }
+}
+
+void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst)
+{
+ if (as<IRParam>(inst))
+ {
+ SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent()));
+ auto lastParam = as<IRBlock>(inst->getParent())->getLastParam();
+ builder->setInsertAfter(lastParam);
+ }
+ else
+ {
+ builder->setInsertAfter(inst);
+ }
+}
+
bool isPureFunctionalCall(IRCall* call)
{
auto callee = getResolvedInstForDecorations(call->getCallee());
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 0fb26f791..8a12ab895 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -173,6 +173,12 @@ IRInst* getUndefInst(IRBuilder builder, IRModule* module);
// The the equivalent op of (a op b) in (b op' a). For example, a > b is equivalent to b < a. So (<) ==> (>).
IROp getSwapSideComparisonOp(IROp op);
+// Set IRBuilder to insert before `inst`. If `inst` is a param, it will insert after the last param.
+void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst);
+
+// Set IRBuilder to insert after `inst`. If `inst` is a param, it will insert after the last param.
+void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst);
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index dc61a45af..accefc0c9 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5662,9 +5662,9 @@ namespace Slang
#if SLANG_ENABLE_IR_BREAK_ALLOC
if (context->options.flags & IRDumpOptions::Flag::DumpDebugIds)
{
- dump(context, "[#");
+ dump(context, "{");
dump(context, String(inst->_debugUID));
- dump(context, "]");
+ dump(context, "}\t");
}
#else
SLANG_UNUSED(context);
@@ -5691,7 +5691,6 @@ namespace Slang
{
dump(context, "_");
}
- dumpDebugID(context, inst);
}
static void dumpEncodeString(
@@ -5819,6 +5818,7 @@ namespace Slang
IRBlock* block)
{
context->indent--;
+ dumpDebugID(context, block);
dump(context, "block ");
dumpID(context, block);
@@ -6050,7 +6050,6 @@ namespace Slang
}
dump(context, opInfo.name);
- dumpDebugID(context, inst);
dumpInstOperandList(context, inst);
}
@@ -6068,6 +6067,8 @@ namespace Slang
dumpIRDecorations(context, inst);
+ dumpDebugID(context, inst);
+
// There are several ops we want to special-case here,
// so that they will be more pleasant to look at.
//
@@ -6204,7 +6205,10 @@ namespace Slang
context.options = options;
context.sourceManager = sourceManager;
- dumpInst(&context, globalVal);
+ if (globalVal->getOp() == kIROp_Module)
+ dumpIRModule(&context, globalVal->getModule());
+ else
+ dumpInst(&context, globalVal);
writer->write(sb.getBuffer(), sb.getLength());
writer->flush();