diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-30 13:24:39 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-30 13:24:39 -0800 |
| commit | 09684224d5ab63f530d66c0be65fa50e6fc5290b (patch) | |
| tree | 292d0f257b3d5a5e027892a5a1e046d60166aadd | |
| parent | f52b4de3b29ee27213b7d60fb620a0d5d50b49f9 (diff) | |
Support `no_diff` on existential typed params. (#2540)
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-check-conformance.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 41 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-pairs.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 37 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 17 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-param-2.slang | 38 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-param-2.slang.expected.txt | 5 |
12 files changed, 112 insertions, 62 deletions
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index d2335efbf..4d983b746 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -87,8 +87,10 @@ namespace Slang // that `subType` has been proven to be *equal* // to `superTypeDeclRef`. // - SLANG_UNEXPECTED("reflexive type witness"); - UNREACHABLE_RETURN(nullptr); + auto witness = m_astBuilder->create<TypeEqualityWitness>(); + witness->sub = subType; + witness->sup = subType; + return witness; } // We might have one or more steps in the breadcrumb trail, e.g.: diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5e6c6eedf..d36e6286d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4715,7 +4715,8 @@ namespace Slang maybeRegisterDifferentiableType(m_astBuilder, decl->returnType.type); if (as<ConstructorDecl>(decl) || !isEffectivelyStatic(decl)) { - auto thisType = calcThisType(makeDeclRef(decl)); + auto parentDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl->parentDecl)); + auto thisType = calcThisType(parentDeclRef); maybeRegisterDifferentiableType(m_astBuilder, thisType); } m_parentDifferentiableAttr = oldAttr; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ca55a68bc..508402736 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -358,33 +358,36 @@ Result linkAndOptimizeIR( // perform specialization of functions based on parameter // values that need to be compile-time constants. // + // Specialization passes and auto-diff passes runs in an iterative loop + // since each pass can enable the other pass to progress further. + for (;;) + { + bool changed = false; - dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE"); - if (!codeGenContext->isSpecializationDisabled()) - specializeModule(irModule); - dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE"); - - applySparseConditionalConstantPropagation(irModule); - eliminateDeadCode(irModule); + dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE"); + if (!codeGenContext->isSpecializationDisabled()) + changed |= specializeModule(irModule); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE"); - lowerReinterpret(targetRequest, irModule, sink); - - validateIRModuleIfEnabled(codeGenContext, irModule); + validateIRModuleIfEnabled(codeGenContext, irModule); - // Inline calls to any functions marked with [__unsafeInlineEarly] again, - // since we may be missing out cases prevented by the functions that we just specialzied. - performMandatoryEarlyInlining(irModule); + // Inline calls to any functions marked with [__unsafeInlineEarly] again, + // since we may be missing out cases prevented by the functions that we just specialzied. + performMandatoryEarlyInlining(irModule); - dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); - - processAutodiffCalls(irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); + changed |= processAutodiffCalls(irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); - dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); + if (!changed) + break; + } + + lowerReinterpret(targetRequest, irModule, sink); validateIRModuleIfEnabled(codeGenContext, irModule); - applySparseConditionalConstantPropagation(irModule); - eliminateDeadCode(irModule); + simplifyIR(irModule); // For targets that supports dynamic dispatch, we need to lower the // generics / interface types to ordinary functions and types using diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index c9b186c8a..d45dd0c10 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -691,7 +691,8 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), primalCallee); } - else + + if (!diffCallee) { // The callee is non differentiable, just return primal value with null diff value. IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); @@ -1614,8 +1615,8 @@ struct ForwardDerivativePass : public InstPassBase // bool processReferencedFunctions(IRBuilder* builder) { + bool changed = false; List<IRInst*> autoDiffWorkList; - for (;;) { // Collect all `ForwardDifferentiate` insts from the module. @@ -1669,6 +1670,7 @@ struct ForwardDerivativePass : public InstPassBase differentiateInst->replaceUsesWith(diffFunc); differentiateInst->removeAndDeallocate(); } + changed = true; } } // Actually synthesize the derivatives. @@ -1689,7 +1691,7 @@ struct ForwardDerivativePass : public InstPassBase SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0); } - return true; + return changed; } // Checks decorators to see if the function should diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index 1dbb1bd7c..b9b4a8b66 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -133,12 +133,10 @@ struct DiffPairLoweringPass : InstPassBase case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: lowerPairAccess(builder, inst); - modified = true; break; case kIROp_MakeDifferentialPair: lowerMakePair(builder, inst); - modified = true; break; default: @@ -152,6 +150,7 @@ struct DiffPairLoweringPass : InstPassBase { inst->replaceUsesWith(loweredType); inst->removeAndDeallocate(); + modified = true; } }); return modified; @@ -179,4 +178,4 @@ bool processPairTypes(AutoDiffSharedContext* context) return pairLoweringPass.processModule(); } -}
\ No newline at end of file +} diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index daf45e1ef..8ec8f581c 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -871,6 +871,8 @@ struct ReverseDerivativePass : public InstPassBase // bool processReferencedFunctions(IRBuilder* builder) { + bool changed = false; + List<IRInst*> autoDiffWorkList; for (;;) @@ -922,6 +924,7 @@ struct ReverseDerivativePass : public InstPassBase SLANG_ASSERT(diffFunc); differentiateInst->replaceUsesWith(diffFunc); differentiateInst->removeAndDeallocate(); + changed = true; } else { @@ -950,7 +953,7 @@ struct ReverseDerivativePass : public InstPassBase SLANG_RELEASE_ASSERT(backwardTranscriberStorage.followUpFunctionsToTranscribe.getCount() == 0); } - return true; + return changed; } // Checks decorators to see if the function should diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 4373cf44b..5b5832073 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -448,12 +448,6 @@ bool processAutodiffCalls( DiagnosticSink* sink, IRAutodiffPassOptions const&) { - // Simplify module to remove dead code. - IRDeadCodeEliminationOptions dceOptions; - dceOptions.keepExportsAlive = true; - dceOptions.keepLayoutsAlive = true; - eliminateDeadCode(module, dceOptions); - bool modified = false; // Create shared context for all auto-diff related passes @@ -487,7 +481,6 @@ bool processAutodiffCalls( // Remove auto-diff related decorations. stripAutoDiffDecorations(module); - return modified; } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 406e5157c..74caa30ae 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -44,6 +44,8 @@ struct SpecializationContext // we are specializing. IRModule* module; + bool changed = false; + // We know that we can only perform generic specialization when all // of the arguments to a generic are also fully specialized. // The "is fully specialized" condition is something we @@ -793,8 +795,6 @@ struct SpecializationContext SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); - bool changed = true; - // Read specialization dictionary from module if it is defined. // This prevents us from generating duplicated specializations // when this pass is invoked iteratively. @@ -839,9 +839,9 @@ struct SpecializationContext // We start out simple by putting the root instruction for the // module onto our work list. // - while (changed) + for (;;) { - changed = false; + bool iterChanged = false; addToWorkList(module->getModuleInst()); while (workList.Count() != 0) @@ -868,7 +868,7 @@ struct SpecializationContext // specialization opportunities (generic specialization, // existential specialization, simplifications, etc.) // - changed |= maybeSpecializeInst(inst); + iterChanged |= maybeSpecializeInst(inst); // Finally, we need to make our logic recurse through // the whole IR module, so we want to add the children @@ -896,8 +896,15 @@ struct SpecializationContext addDirtyInstsToWorkListRec(module->getModuleInst()); } - if (changed) + if (iterChanged) + { simplifyIR(module); + this->changed = true; + } + else + { + break; + } } // Once the work list has gone dry, we should have the invariant @@ -1776,6 +1783,11 @@ struct SpecializationContext type = sbType->getElementType(); goto top; } + else if (auto attributedType = as<IRAttributedType>(type)) + { + type = attributedType->getBaseType(); + goto top; + } else if( auto structType = as<IRStructType>(type) ) { UInt count = 0; @@ -2070,6 +2082,11 @@ struct SpecializationContext type = sbType->getElementType(); goto top; } + else if (auto attributedType = as<IRAttributedType>(type)) + { + type = attributedType->getBaseType(); + goto top; + } else if( auto structType = as<IRStructType>(type) ) { UInt count = 0; @@ -2114,7 +2131,8 @@ struct SpecializationContext } else if( as<IRPointerLikeType>(baseType) || as<IRHLSLStructuredBufferTypeBase>(baseType) || - as<IRArrayTypeBase>(baseType)) + as<IRArrayTypeBase>(baseType) || + as<IRAttributedType>(baseType) ) { // A `BindExistentials<P<T>, ...>` can be simplified to // `P<BindExistentials<T, ...>>` when `P` is a pointer-like @@ -2127,6 +2145,8 @@ struct SpecializationContext baseElementType = arrayType->getElementType(); else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType)) baseElementType = baseSBType->getElementType(); + else if (auto baseAttrType = as<IRAttributedType>(baseType)) + baseElementType = baseAttrType->getBaseType(); IRInst* wrappedElementType = builder.getBindExistentialsType( baseElementType, @@ -2283,12 +2303,13 @@ struct SpecializationContext } }; -void specializeModule( +bool specializeModule( IRModule* module) { SpecializationContext context; context.module = module; context.processModule(); + return context.changed; } diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h index 9c2c19785..1503c238e 100644 --- a/source/slang/slang-ir-specialize.h +++ b/source/slang/slang-ir-specialize.h @@ -6,7 +6,7 @@ namespace Slang struct IRModule; /// Specialize generic and interface-based code to use concrete types. -void specializeModule( +bool specializeModule( IRModule* module); } diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 36fab6da1..56a33c02b 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -716,28 +716,11 @@ struct IRInst void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent); }; -inline bool isModifierInst(IROp op) -{ - switch (op) - { - case kIROp_AttributedType: - return true; - } - return false; -} - template<typename T> T* dynamicCast(IRInst* inst) { if (inst && T::isaImpl(inst->getOp())) return static_cast<T*>(inst); - if (inst) - { - if (isModifierInst(inst->getOp())) - { - return dynamicCast<T>(inst->getOperand(0)); - } - } return nullptr; } diff --git a/tests/autodiff/no-diff-param-2.slang b/tests/autodiff/no-diff-param-2.slang new file mode 100644 index 000000000..d29928d69 --- /dev/null +++ b/tests/autodiff/no-diff-param-2.slang @@ -0,0 +1,38 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +interface IFoo : IDifferentiable +{ + [ForwardDifferentiable] + float getVal(); +} + +struct A : IFoo +{ + float x; + [ForwardDifferentiable] + float getVal(){return x;} +} + +[ForwardDifferentiable] +float f(float x, no_diff IFoo y) +{ + return x * x + y.getVal(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + A a; + a.x = 2.0; + let rs = __fwd_diff(f)(dpfloat(1.5, 1.0), a); + outputBuffer[0] = rs.p; // Expect: 6.25 + outputBuffer[1] = rs.d; // Expect: 3.0 + } +} diff --git a/tests/autodiff/no-diff-param-2.slang.expected.txt b/tests/autodiff/no-diff-param-2.slang.expected.txt new file mode 100644 index 000000000..18066089d --- /dev/null +++ b/tests/autodiff/no-diff-param-2.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +4.250000 +3.000000 +0.000000 +0.000000
\ No newline at end of file |
