summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-16 23:46:14 -0700
committerGitHub <noreply@github.com>2023-03-16 23:46:14 -0700
commit9476d4543f4336a66308e55f722b0b0b2bd69dd2 (patch)
treeff3a0514249f5c3975177bf053c5cb038e37acc8 /source
parent77d3630eef4ea1c4b0424a46526a6be476a89230 (diff)
Fix Phi simplification bug. (#2710)
* Fix Phi simplification bug. * Fix up. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix test. * Fix test. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang88
-rw-r--r--source/slang/diff.meta.slang78
-rw-r--r--source/slang/slang-ast-synthesis.h2
-rw-r--r--source/slang/slang-check-decl.cpp5
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp20
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp5
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp17
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp1
-rw-r--r--source/slang/slang-ir-autodiff.cpp3
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp6
-rw-r--r--source/slang/slang-ir-dominators.cpp3
-rw-r--r--source/slang/slang-ir-peephole.cpp38
-rw-r--r--source/slang/slang-ir-util.cpp8
-rw-r--r--source/slang/slang-ir-util.h1
-rw-r--r--source/slang/slang-ir-validate.cpp28
-rw-r--r--source/slang/slang-ir-validate.h3
16 files changed, 196 insertions, 110 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index ad3817d9a..0a3bb885e 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -110,6 +110,14 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {}
interface __BuiltinIntegerType : __BuiltinArithmeticType
{}
+/// Modifer to mark a function for forward-mode differentiation.
+/// i.e. the compiler will automatically generate a new function
+/// that computes the jacobian-vector product of the original.
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
+
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
/// Interface to denote types as differentiable.
/// Allows for user-specified differential types as
@@ -137,6 +145,77 @@ interface IDifferentiable
static Differential dmul(This, Differential);
};
+
+/// Pair type that serves to wrap the primal and
+/// differential types of an arbitrary type T.
+
+__generic<T : IDifferentiable>
+__magic_type(DifferentialPairType)
+__intrinsic_type($(kIROp_DifferentialPairUserCodeType))
+struct DifferentialPair : IDifferentiable
+{
+ typedef DifferentialPair<T.Differential> Differential;
+ typedef T.Differential DifferentialElementType;
+
+ __intrinsic_op($(kIROp_MakeDifferentialPairUserCode))
+ __init(T _primal, T.Differential _differential);
+
+ property p : T
+ {
+ __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode))
+ get;
+ }
+
+ property v : T
+ {
+ __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode))
+ get;
+ }
+
+ property d : T.Differential
+ {
+ __intrinsic_op($(kIROp_DifferentialPairGetDifferentialUserCode))
+ get;
+ }
+
+ [__unsafeForceInlineEarly]
+ T.Differential getDifferential()
+ {
+ return d;
+ }
+
+ [__unsafeForceInlineEarly]
+ T getPrimal()
+ {
+ return p;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return Differential(T.dzero(), T.Differential.dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return Differential(
+ T.dadd(
+ a.p,
+ b.p
+ ),
+ T.Differential.dadd(a.d, b.d));
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return Differential(
+ T.dmul(a.p, b.p),
+ T.Differential.dmul(a.d, b.d));
+ }
+};
+
/// A type that can represent non-integers
[sealed]
[builtin]
@@ -371,18 +450,21 @@ ${{{{
typedef $(kBaseTypes[tt].name) Differential;
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dzero()
{
return Differential(0);
}
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dadd(Differential a, Differential b)
{
return a + b;
}
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dmul(Differential a, Differential b)
{
return a * b;
@@ -1072,18 +1154,21 @@ extension vector<T, N> : IDifferentiable
typedef vector<T, N> Differential;
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dzero()
{
return Differential(__slang_noop_cast<T>(T.dzero()));
}
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dadd(Differential a, Differential b)
{
return a + b;
}
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dmul(This a, Differential b)
{
return a * b;
@@ -1096,18 +1181,21 @@ extension matrix<T, R, C> : IDifferentiable
typedef matrix<T, R, C> Differential;
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dzero()
{
return matrix<T, R, C>(__slang_noop_cast<T>(T.dzero()));
}
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dadd(Differential a, Differential b)
{
return a + b;
}
[__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
static Differential dmul(This a, Differential b)
{
return a * b;
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index ada052cd8..a9b8209f3 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1,8 +1,3 @@
-/// Modifer to mark a function for forward-mode differentiation.
-/// i.e. the compiler will automatically generate a new function
-/// that computes the jacobian-vector product of the original.
-__attributeTarget(FunctionDeclBase)
-attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
// Custom Forward Derivative Function reference
__attributeTarget(FunctionDeclBase)
@@ -14,8 +9,6 @@ attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute;
-__attributeTarget(FunctionDeclBase)
-attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
@@ -33,77 +26,6 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
-
-/// Pair type that serves to wrap the primal and
-/// differential types of an arbitrary type T.
-
-__generic<T : IDifferentiable>
-__magic_type(DifferentialPairType)
-__intrinsic_type($(kIROp_DifferentialPairUserCodeType))
-struct DifferentialPair : IDifferentiable
-{
- typedef DifferentialPair<T.Differential> Differential;
- typedef T.Differential DifferentialElementType;
-
- __intrinsic_op($(kIROp_MakeDifferentialPairUserCode))
- __init(T _primal, T.Differential _differential);
-
- property p : T
- {
- __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode))
- get;
- }
-
- property v : T
- {
- __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode))
- get;
- }
-
- property d : T.Differential
- {
- __intrinsic_op($(kIROp_DifferentialPairGetDifferentialUserCode))
- get;
- }
-
- [__unsafeForceInlineEarly]
- T.Differential getDifferential()
- {
- return d;
- }
-
- [__unsafeForceInlineEarly]
- T getPrimal()
- {
- return p;
- }
-
- [__unsafeForceInlineEarly]
- static Differential dzero()
- {
- return Differential(T.dzero(), T.Differential.dzero());
- }
-
- [__unsafeForceInlineEarly]
- static Differential dadd(Differential a, Differential b)
- {
- return Differential(
- T.dadd(
- a.p,
- b.p
- ),
- T.Differential.dadd(a.d, b.d));
- }
-
- [__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
- {
- return Differential(
- T.dmul(a.p, b.p),
- T.Differential.dmul(a.d, b.d));
- }
-};
-
__generic<T: IDifferentiable>
__intrinsic_op($(kIROp_MakeDifferentialPairUserCode))
DifferentialPair<T> diffPair(T primal, T.Differential diff);
diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h
index 2af890d34..6568b4c83 100644
--- a/source/slang/slang-ast-synthesis.h
+++ b/source/slang/slang-ast-synthesis.h
@@ -25,6 +25,8 @@ public:
{
}
+ ASTBuilder* getBuilder() { return m_builder; }
+
Scope* getScope(ContainerDecl* decl)
{
for (auto container = decl; container; container = container->parentDecl)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index ea8bec2bb..2613e6430 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3324,6 +3324,7 @@ namespace Slang
{
VarDecl* indexVar = nullptr;
auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar);
+ addModifier(forStmt, synth.getBuilder()->create<ForceUnrollAttribute>());
auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar));
for (auto& arg : args)
{
@@ -3358,6 +3359,7 @@ namespace Slang
// We synthesize a memberwise dispatch to compute each field of `TResult`,
// resulting an implementation of the form:
// ```
+ // [BackwardDifferentiable]
// static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
// {
// TResult result;
@@ -3389,6 +3391,9 @@ namespace Slang
ThisExpr* synThis = nullptr;
auto synFunc = synthesizeMethodSignatureForRequirementWitness(
context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis);
+
+ addModifier(synFunc, m_astBuilder->create<BackwardDifferentiableAttribute>());
+
synFunc->parentDecl = context->parentDecl;
synth.pushContainerScope(synFunc);
auto blockStmt = m_astBuilder->create<BlockStmt>();
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index e9c156055..cf45a83f5 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1069,14 +1069,13 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig
// will branch into the loop body)
auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock());
- // Transcribe the break block (this is the block after the exiting the loop)
- auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
-
// Transcribe the continue block (this is the 'update' part of the loop, which will
// branch into the condition block)
auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
-
+ // Transcribe the break block (this is the block after the exiting the loop)
+ auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
+
List<IRInst*> diffLoopOperands;
diffLoopOperands.add(diffTargetBlock);
diffLoopOperands.add(diffBreakBlock);
@@ -1510,14 +1509,17 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
mapInOutParamToWriteBackValue.Clear();
- // Transcribe children from origFunc into diffFunc
+ // Create and map blocks in diff func.
for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribe(&builder, block);
+ {
+ auto diffBlock = builder.emitBlock();
+ mapPrimalInst(block, diffBlock);
+ mapDifferentialInst(block, diffBlock);
+ }
- // Some of the transcribed blocks can appear 'out-of-order'. Although this
- // shouldn't be an issue, for consistency, we put them back in order.
+ // Now actually transcribe the content of each block.
for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock())
- as<IRBlock>(lookupDiffInst(block))->insertAtEnd(diffFunc);
+ this->transcribeBlock(&builder, block);
for (auto block : diffFunc->getBlocks())
{
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 157011b7c..7c11a1286 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -556,7 +556,7 @@ namespace Slang
// reversible.
if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc)))
return diffPropagateFunc;
-
+
// Forward transcribe the clone of the original func.
ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
autoDiffSharedContext->transcriberSet.forwardTranscriber);
@@ -772,6 +772,9 @@ namespace Slang
initializeLocalVariables(builder->getModule(), diffPropagateFunc);
// insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
stripTempDecorations(diffPropagateFunc);
+
+ sortBlocksInFunc(diffPropagateFunc);
+ sortBlocksInFunc(primalFunc);
}
ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock(
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index e3ef357ee..6784391a0 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -201,6 +201,7 @@ IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* b
// Add method.
IRBuilder b = *builder;
b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType };
addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType));
b.emitBlock();
@@ -273,6 +274,7 @@ IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRI
// Add method.
IRBuilder b = *builder;
b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
IRType* paramTypes[2] = { diffArrayType, diffArrayType };
addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType));
b.emitBlock();
@@ -831,16 +833,10 @@ InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBloc
IRBuilder subBuilder = *builder;
subBuilder.setInsertLoc(builder->getInsertLoc());
- IRInst* diffBlock = subBuilder.emitBlock();
+ IRInst* diffBlock = lookupDiffInst(origBlock);
+ SLANG_RELEASE_ASSERT(diffBlock);
subBuilder.markInstAsMixedDifferential(diffBlock);
- // Note: for blocks, we setup the mapping _before_
- // processing the children since we could encounter
- // a lookup while processing the children.
- //
- mapPrimalInst(origBlock, diffBlock);
- mapDifferentialInst(origBlock, diffBlock);
-
subBuilder.setInsertInto(diffBlock);
// First transcribe every parameter in the block.
@@ -1017,9 +1013,10 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
// Transcribe children from origFunc into diffFunc.
builder.setInsertInto(diffGeneric);
+ auto bodyBlock = builder.emitBlock();
+ mapPrimalInst(origGeneric->getFirstBlock(), bodyBlock);
+ mapDifferentialInst(origGeneric->getFirstBlock(), bodyBlock);
auto transcribedBlock = transcribeBlockImpl(&builder, origGeneric->getFirstBlock(), instsToSkip);
- mapPrimalInst(origGeneric->getFirstBlock(), transcribedBlock.primal);
- mapDifferentialInst(origGeneric->getFirstBlock(), transcribedBlock.differential);
return InstPair(primalGeneric, diffGeneric);
}
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 7e01fde28..af7792748 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -474,7 +474,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
subEnv.squashChildrenMapping = true;
subEnv.parent = &cloneEnv;
auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func));
-
auto clonedPrimalsInfo = primalsInfo->applyMap(&subEnv);
// Remove [KeepAlive] decorations in clonedFunc.
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index a3a7e4b77..e9b78696e 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -559,6 +559,9 @@ void stripNoDiffTypeAttribute(IRModule* module)
bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
{
+ if (!typeInst)
+ return false;
+
if (context.isDifferentiableType((IRType*)typeInst))
return true;
// Look for equivalent types.
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 6f97ce076..14178a86c 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -194,7 +194,7 @@ public:
return false;
}
- bool instHasNonTrivialDerivative(IRInst* inst)
+ bool instHasNonTrivialDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* inst)
{
switch (inst->getOp())
{
@@ -206,7 +206,7 @@ public:
return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward);
}
default:
- return true;
+ return isDifferentiableType(diffTypeContext, inst->getDataType());
}
}
@@ -468,7 +468,7 @@ public:
if (auto storeInst = as<IRStore>(inst))
{
if (produceDiffSet.Contains(storeInst->getVal()) &&
- instHasNonTrivialDerivative(storeInst->getVal()) &&
+ instHasNonTrivialDerivative(diffTypeContext, storeInst->getVal()) &&
!canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
{
sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp
index 5f606092b..e03ae9425 100644
--- a/source/slang/slang-ir-dominators.cpp
+++ b/source/slang/slang-ir-dominators.cpp
@@ -616,6 +616,9 @@ struct DominatorTreeComputationContext
RefPtr<IRDominatorTree> createDominatorTree(IRGlobalValueWithCode* code)
{
+ if (code->getFirstBlock() == nullptr)
+ return nullptr;
+
// We first run the Cooper et al. algorithm to compute the `doms` array
// which encodes immediate dominators.
//
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 5d5a41726..65b5d2f45 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -1,6 +1,8 @@
#include "slang-ir-peephole.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-sccp.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -290,6 +292,9 @@ struct PeepholeContext : InstPassBase
return false;
}
+ RefPtr<IRDominatorTree> domTree;
+ IRGlobalValueWithCode* domTreeFunc = nullptr;
+
void processInst(IRInst* inst)
{
if (as<IRGlobalValueWithCode>(inst))
@@ -679,9 +684,35 @@ struct PeepholeContext : InstPassBase
{
if (inst->hasUses())
{
- inst->replaceUsesWith(argValue);
- // Never remove param inst.
- changed = true;
+ // Is argValue a global constant?
+ if (isChildInstOf(inst, argValue->getParent()))
+ {
+ inst->replaceUsesWith(argValue);
+ // Never remove param inst.
+ changed = true;
+ }
+ else
+ {
+ // If argValue is defined locally,
+ // we can replace only if argVal dominates inst.
+ auto parentFunc = getParentFunc(inst);
+ if (!parentFunc)
+ break;
+ if (domTreeFunc != parentFunc)
+ {
+ domTree = computeDominatorTree(parentFunc);
+ domTreeFunc = parentFunc;
+ }
+ if (!domTree)
+ break;
+
+ if (domTree->dominates(argValue, inst))
+ {
+ inst->replaceUsesWith(argValue);
+ // Never remove param inst.
+ changed = true;
+ }
+ }
}
}
}
@@ -694,7 +725,6 @@ struct PeepholeContext : InstPassBase
bool processFunc(IRInst* func)
{
bool result = false;
-
for (;;)
{
changed = false;
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index a27afee8a..bff80392f 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2,6 +2,7 @@
#include "slang-ir-insts.h"
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
+#include "slang-ir-dominators.h"
namespace Slang
{
@@ -574,6 +575,13 @@ IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IR
return loopParam;
}
+void sortBlocksInFunc(IRGlobalValueWithCode* func)
+{
+ auto order = getReversePostorder(func);
+ for (auto block : order)
+ block->insertAtEnd(func);
+}
+
void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst)
{
if (as<IRParam>(inst))
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 0989dee33..b2f49c24b 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -186,6 +186,7 @@ void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst);
// Returns the loop counter `IRParam`.
IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IRBlock*& loopBodyBlock, IRBlock*& loopBreakBlock);
+void sortBlocksInFunc(IRGlobalValueWithCode* func);
}
#endif
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index 55e0f0168..3c720e929 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -3,6 +3,7 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
+#include "slang-ir-dominators.h"
namespace Slang
{
@@ -11,6 +12,8 @@ namespace Slang
// The IR module we are validating.
IRModule* module;
+ RefPtr<IRDominatorTree> domTree;
+
// A diagnostic sink to send errors to if anything is invalid.
DiagnosticSink* sink;
@@ -165,9 +168,14 @@ namespace Slang
// the same function (or another value with code). We need
// to validate that `operandParentBlock` dominates `instParentBlock`.
//
- // TODO: implement this validation once we compute dominator trees.
- //
- // validate(context, operandParentBlock->dominates(instParentBlock), inst, "def must dominate use");
+ if (context && context->domTree)
+ {
+ validate(
+ context,
+ context->domTree->dominates(operandParentBlock, instParentBlock),
+ inst,
+ "def must dominate use");
+ }
return;
}
}
@@ -327,10 +335,24 @@ namespace Slang
if (auto code = as<IRGlobalValueWithCode>(inst))
{
+ context->domTree = computeDominatorTree(code);
validateCodeBody(context, code);
+ context->domTree = nullptr;
}
}
+ void validateIRInst(IRInst* inst)
+ {
+ IRValidateContext contextStorage;
+ IRValidateContext* context = &contextStorage;
+ DiagnosticSink sink;
+ context->module = inst->getModule();
+ context->sink = &sink;
+ if (auto func = as<IRFunc>(inst))
+ context->domTree = computeDominatorTree(func);
+ validateIRInst(context, inst);
+ }
+
void validateIRModule(IRModule* module, DiagnosticSink* sink)
{
IRValidateContext contextStorage;
diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h
index a1a9eb4f4..6dfacc158 100644
--- a/source/slang/slang-ir-validate.h
+++ b/source/slang/slang-ir-validate.h
@@ -7,7 +7,7 @@ namespace Slang
class CompileRequestBase;
class DiagnosticSink;
struct IRModule;
-
+ struct IRInst;
// Validate that an IR module obeys the invariants we need to enforce.
// For example:
@@ -27,6 +27,7 @@ namespace Slang
//
// * Confirm that all the parameters of a block come before any "ordinary" instructions.
void validateIRModule(IRModule* module, DiagnosticSink* sink);
+ void validateIRInst(IRInst* inst);
// A wrapper that calls `validateIRModule` only when IR validation is enabled
// for the given compile request.