diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-16 23:46:14 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-16 23:46:14 -0700 |
| commit | 9476d4543f4336a66308e55f722b0b0b2bd69dd2 (patch) | |
| tree | ff3a0514249f5c3975177bf053c5cb038e37acc8 /source | |
| parent | 77d3630eef4ea1c4b0424a46526a6be476a89230 (diff) | |
Fix Phi simplification bug. (#2710)
* Fix Phi simplification bug.
* Fix up.
* Fix.
* Fix.
* Fix.
* Fix.
* Fix.
* Fix test.
* Fix test.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 88 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 78 | ||||
| -rw-r--r-- | source/slang/slang-ast-synthesis.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 17 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-dominators.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 38 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.h | 3 |
16 files changed, 196 insertions, 110 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index ad3817d9a..0a3bb885e 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -110,6 +110,14 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} interface __BuiltinIntegerType : __BuiltinArithmeticType {} +/// Modifer to mark a function for forward-mode differentiation. +/// i.e. the compiler will automatically generate a new function +/// that computes the jacobian-vector product of the original. +__attributeTarget(FunctionDeclBase) +attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; + +__attributeTarget(FunctionDeclBase) +attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; /// Interface to denote types as differentiable. /// Allows for user-specified differential types as @@ -137,6 +145,77 @@ interface IDifferentiable static Differential dmul(This, Differential); }; + +/// Pair type that serves to wrap the primal and +/// differential types of an arbitrary type T. + +__generic<T : IDifferentiable> +__magic_type(DifferentialPairType) +__intrinsic_type($(kIROp_DifferentialPairUserCodeType)) +struct DifferentialPair : IDifferentiable +{ + typedef DifferentialPair<T.Differential> Differential; + typedef T.Differential DifferentialElementType; + + __intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) + __init(T _primal, T.Differential _differential); + + property p : T + { + __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode)) + get; + } + + property v : T + { + __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode)) + get; + } + + property d : T.Differential + { + __intrinsic_op($(kIROp_DifferentialPairGetDifferentialUserCode)) + get; + } + + [__unsafeForceInlineEarly] + T.Differential getDifferential() + { + return d; + } + + [__unsafeForceInlineEarly] + T getPrimal() + { + return p; + } + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(T.dzero(), T.Differential.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return Differential( + T.dadd( + a.p, + b.p + ), + T.Differential.dadd(a.d, b.d)); + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return Differential( + T.dmul(a.p, b.p), + T.Differential.dmul(a.d, b.d)); + } +}; + /// A type that can represent non-integers [sealed] [builtin] @@ -371,18 +450,21 @@ ${{{{ typedef $(kBaseTypes[tt].name) Differential; [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dzero() { return Differential(0); } [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dadd(Differential a, Differential b) { return a + b; } [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dmul(Differential a, Differential b) { return a * b; @@ -1072,18 +1154,21 @@ extension vector<T, N> : IDifferentiable typedef vector<T, N> Differential; [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dzero() { return Differential(__slang_noop_cast<T>(T.dzero())); } [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dadd(Differential a, Differential b) { return a + b; } [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dmul(This a, Differential b) { return a * b; @@ -1096,18 +1181,21 @@ extension matrix<T, R, C> : IDifferentiable typedef matrix<T, R, C> Differential; [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dzero() { return matrix<T, R, C>(__slang_noop_cast<T>(T.dzero())); } [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dadd(Differential a, Differential b) { return a + b; } [__unsafeForceInlineEarly] + [BackwardDifferentiable] static Differential dmul(This a, Differential b) { return a * b; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ada052cd8..a9b8209f3 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1,8 +1,3 @@ -/// Modifer to mark a function for forward-mode differentiation. -/// i.e. the compiler will automatically generate a new function -/// that computes the jacobian-vector product of the original. -__attributeTarget(FunctionDeclBase) -attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; // Custom Forward Derivative Function reference __attributeTarget(FunctionDeclBase) @@ -14,8 +9,6 @@ attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; -__attributeTarget(FunctionDeclBase) -attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; @@ -33,77 +26,6 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [NoDiffThis] : NoDiffThisAttribute; - -/// Pair type that serves to wrap the primal and -/// differential types of an arbitrary type T. - -__generic<T : IDifferentiable> -__magic_type(DifferentialPairType) -__intrinsic_type($(kIROp_DifferentialPairUserCodeType)) -struct DifferentialPair : IDifferentiable -{ - typedef DifferentialPair<T.Differential> Differential; - typedef T.Differential DifferentialElementType; - - __intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) - __init(T _primal, T.Differential _differential); - - property p : T - { - __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode)) - get; - } - - property v : T - { - __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode)) - get; - } - - property d : T.Differential - { - __intrinsic_op($(kIROp_DifferentialPairGetDifferentialUserCode)) - get; - } - - [__unsafeForceInlineEarly] - T.Differential getDifferential() - { - return d; - } - - [__unsafeForceInlineEarly] - T getPrimal() - { - return p; - } - - [__unsafeForceInlineEarly] - static Differential dzero() - { - return Differential(T.dzero(), T.Differential.dzero()); - } - - [__unsafeForceInlineEarly] - static Differential dadd(Differential a, Differential b) - { - return Differential( - T.dadd( - a.p, - b.p - ), - T.Differential.dadd(a.d, b.d)); - } - - [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) - { - return Differential( - T.dmul(a.p, b.p), - T.Differential.dmul(a.d, b.d)); - } -}; - __generic<T: IDifferentiable> __intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) DifferentialPair<T> diffPair(T primal, T.Differential diff); diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h index 2af890d34..6568b4c83 100644 --- a/source/slang/slang-ast-synthesis.h +++ b/source/slang/slang-ast-synthesis.h @@ -25,6 +25,8 @@ public: { } + ASTBuilder* getBuilder() { return m_builder; } + Scope* getScope(ContainerDecl* decl) { for (auto container = decl; container; container = container->parentDecl) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ea8bec2bb..2613e6430 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3324,6 +3324,7 @@ namespace Slang { VarDecl* indexVar = nullptr; auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); + addModifier(forStmt, synth.getBuilder()->create<ForceUnrollAttribute>()); auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); for (auto& arg : args) { @@ -3358,6 +3359,7 @@ namespace Slang // We synthesize a memberwise dispatch to compute each field of `TResult`, // resulting an implementation of the form: // ``` + // [BackwardDifferentiable] // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) // { // TResult result; @@ -3389,6 +3391,9 @@ namespace Slang ThisExpr* synThis = nullptr; auto synFunc = synthesizeMethodSignatureForRequirementWitness( context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis); + + addModifier(synFunc, m_astBuilder->create<BackwardDifferentiableAttribute>()); + synFunc->parentDecl = context->parentDecl; synth.pushContainerScope(synFunc); auto blockStmt = m_astBuilder->create<BlockStmt>(); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index e9c156055..cf45a83f5 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1069,14 +1069,13 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig // will branch into the loop body) auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock()); - // Transcribe the break block (this is the block after the exiting the loop) - auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); - // Transcribe the continue block (this is the 'update' part of the loop, which will // branch into the condition block) auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); - + // Transcribe the break block (this is the block after the exiting the loop) + auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + List<IRInst*> diffLoopOperands; diffLoopOperands.add(diffTargetBlock); diffLoopOperands.add(diffBreakBlock); @@ -1510,14 +1509,17 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr mapInOutParamToWriteBackValue.Clear(); - // Transcribe children from origFunc into diffFunc + // Create and map blocks in diff func. for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(&builder, block); + { + auto diffBlock = builder.emitBlock(); + mapPrimalInst(block, diffBlock); + mapDifferentialInst(block, diffBlock); + } - // Some of the transcribed blocks can appear 'out-of-order'. Although this - // shouldn't be an issue, for consistency, we put them back in order. + // Now actually transcribe the content of each block. for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock()) - as<IRBlock>(lookupDiffInst(block))->insertAtEnd(diffFunc); + this->transcribeBlock(&builder, block); for (auto block : diffFunc->getBlocks()) { diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 157011b7c..7c11a1286 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -556,7 +556,7 @@ namespace Slang // reversible. if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc))) return diffPropagateFunc; - + // Forward transcribe the clone of the original func. ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>( autoDiffSharedContext->transcriberSet.forwardTranscriber); @@ -772,6 +772,9 @@ namespace Slang initializeLocalVariables(builder->getModule(), diffPropagateFunc); // insertVariableForRecomputedPrimalInsts(diffPropagateFunc); stripTempDecorations(diffPropagateFunc); + + sortBlocksInFunc(diffPropagateFunc); + sortBlocksInFunc(primalFunc); } ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock( diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index e3ef357ee..6784391a0 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -201,6 +201,7 @@ IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* b // Add method. IRBuilder b = *builder; b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); b.emitBlock(); @@ -273,6 +274,7 @@ IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRI // Add method. IRBuilder b = *builder; b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); IRType* paramTypes[2] = { diffArrayType, diffArrayType }; addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); b.emitBlock(); @@ -831,16 +833,10 @@ InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBloc IRBuilder subBuilder = *builder; subBuilder.setInsertLoc(builder->getInsertLoc()); - IRInst* diffBlock = subBuilder.emitBlock(); + IRInst* diffBlock = lookupDiffInst(origBlock); + SLANG_RELEASE_ASSERT(diffBlock); subBuilder.markInstAsMixedDifferential(diffBlock); - // Note: for blocks, we setup the mapping _before_ - // processing the children since we could encounter - // a lookup while processing the children. - // - mapPrimalInst(origBlock, diffBlock); - mapDifferentialInst(origBlock, diffBlock); - subBuilder.setInsertInto(diffBlock); // First transcribe every parameter in the block. @@ -1017,9 +1013,10 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene // Transcribe children from origFunc into diffFunc. builder.setInsertInto(diffGeneric); + auto bodyBlock = builder.emitBlock(); + mapPrimalInst(origGeneric->getFirstBlock(), bodyBlock); + mapDifferentialInst(origGeneric->getFirstBlock(), bodyBlock); auto transcribedBlock = transcribeBlockImpl(&builder, origGeneric->getFirstBlock(), instsToSkip); - mapPrimalInst(origGeneric->getFirstBlock(), transcribedBlock.primal); - mapDifferentialInst(origGeneric->getFirstBlock(), transcribedBlock.differential); return InstPair(primalGeneric, diffGeneric); } diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 7e01fde28..af7792748 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -474,7 +474,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( subEnv.squashChildrenMapping = true; subEnv.parent = &cloneEnv; auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func)); - auto clonedPrimalsInfo = primalsInfo->applyMap(&subEnv); // Remove [KeepAlive] decorations in clonedFunc. diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index a3a7e4b77..e9b78696e 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -559,6 +559,9 @@ void stripNoDiffTypeAttribute(IRModule* module) bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) { + if (!typeInst) + return false; + if (context.isDifferentiableType((IRType*)typeInst)) return true; // Look for equivalent types. diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 6f97ce076..14178a86c 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -194,7 +194,7 @@ public: return false; } - bool instHasNonTrivialDerivative(IRInst* inst) + bool instHasNonTrivialDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* inst) { switch (inst->getOp()) { @@ -206,7 +206,7 @@ public: return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward); } default: - return true; + return isDifferentiableType(diffTypeContext, inst->getDataType()); } } @@ -468,7 +468,7 @@ public: if (auto storeInst = as<IRStore>(inst)) { if (produceDiffSet.Contains(storeInst->getVal()) && - instHasNonTrivialDerivative(storeInst->getVal()) && + instHasNonTrivialDerivative(diffTypeContext, storeInst->getVal()) && !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) { sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp index 5f606092b..e03ae9425 100644 --- a/source/slang/slang-ir-dominators.cpp +++ b/source/slang/slang-ir-dominators.cpp @@ -616,6 +616,9 @@ struct DominatorTreeComputationContext RefPtr<IRDominatorTree> createDominatorTree(IRGlobalValueWithCode* code) { + if (code->getFirstBlock() == nullptr) + return nullptr; + // We first run the Cooper et al. algorithm to compute the `doms` array // which encodes immediate dominators. // diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 5d5a41726..65b5d2f45 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -1,6 +1,8 @@ #include "slang-ir-peephole.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-sccp.h" +#include "slang-ir-dominators.h" +#include "slang-ir-util.h" namespace Slang { @@ -290,6 +292,9 @@ struct PeepholeContext : InstPassBase return false; } + RefPtr<IRDominatorTree> domTree; + IRGlobalValueWithCode* domTreeFunc = nullptr; + void processInst(IRInst* inst) { if (as<IRGlobalValueWithCode>(inst)) @@ -679,9 +684,35 @@ struct PeepholeContext : InstPassBase { if (inst->hasUses()) { - inst->replaceUsesWith(argValue); - // Never remove param inst. - changed = true; + // Is argValue a global constant? + if (isChildInstOf(inst, argValue->getParent())) + { + inst->replaceUsesWith(argValue); + // Never remove param inst. + changed = true; + } + else + { + // If argValue is defined locally, + // we can replace only if argVal dominates inst. + auto parentFunc = getParentFunc(inst); + if (!parentFunc) + break; + if (domTreeFunc != parentFunc) + { + domTree = computeDominatorTree(parentFunc); + domTreeFunc = parentFunc; + } + if (!domTree) + break; + + if (domTree->dominates(argValue, inst)) + { + inst->replaceUsesWith(argValue); + // Never remove param inst. + changed = true; + } + } } } } @@ -694,7 +725,6 @@ struct PeepholeContext : InstPassBase bool processFunc(IRInst* func) { bool result = false; - for (;;) { changed = false; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index a27afee8a..bff80392f 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2,6 +2,7 @@ #include "slang-ir-insts.h" #include "slang-ir-clone.h" #include "slang-ir-dce.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -574,6 +575,13 @@ IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IR return loopParam; } +void sortBlocksInFunc(IRGlobalValueWithCode* func) +{ + auto order = getReversePostorder(func); + for (auto block : order) + block->insertAtEnd(func); +} + void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) { if (as<IRParam>(inst)) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 0989dee33..b2f49c24b 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -186,6 +186,7 @@ void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst); // Returns the loop counter `IRParam`. IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IRBlock*& loopBodyBlock, IRBlock*& loopBreakBlock); +void sortBlocksInFunc(IRGlobalValueWithCode* func); } #endif diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 55e0f0168..3c720e929 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -3,6 +3,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -11,6 +12,8 @@ namespace Slang // The IR module we are validating. IRModule* module; + RefPtr<IRDominatorTree> domTree; + // A diagnostic sink to send errors to if anything is invalid. DiagnosticSink* sink; @@ -165,9 +168,14 @@ namespace Slang // the same function (or another value with code). We need // to validate that `operandParentBlock` dominates `instParentBlock`. // - // TODO: implement this validation once we compute dominator trees. - // - // validate(context, operandParentBlock->dominates(instParentBlock), inst, "def must dominate use"); + if (context && context->domTree) + { + validate( + context, + context->domTree->dominates(operandParentBlock, instParentBlock), + inst, + "def must dominate use"); + } return; } } @@ -327,10 +335,24 @@ namespace Slang if (auto code = as<IRGlobalValueWithCode>(inst)) { + context->domTree = computeDominatorTree(code); validateCodeBody(context, code); + context->domTree = nullptr; } } + void validateIRInst(IRInst* inst) + { + IRValidateContext contextStorage; + IRValidateContext* context = &contextStorage; + DiagnosticSink sink; + context->module = inst->getModule(); + context->sink = &sink; + if (auto func = as<IRFunc>(inst)) + context->domTree = computeDominatorTree(func); + validateIRInst(context, inst); + } + void validateIRModule(IRModule* module, DiagnosticSink* sink) { IRValidateContext contextStorage; diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h index a1a9eb4f4..6dfacc158 100644 --- a/source/slang/slang-ir-validate.h +++ b/source/slang/slang-ir-validate.h @@ -7,7 +7,7 @@ namespace Slang class CompileRequestBase; class DiagnosticSink; struct IRModule; - + struct IRInst; // Validate that an IR module obeys the invariants we need to enforce. // For example: @@ -27,6 +27,7 @@ namespace Slang // // * Confirm that all the parameters of a block come before any "ordinary" instructions. void validateIRModule(IRModule* module, DiagnosticSink* sink); + void validateIRInst(IRInst* inst); // A wrapper that calls `validateIRModule` only when IR validation is enabled // for the given compile request. |
