diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-29 18:17:33 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-29 18:17:33 -0800 |
| commit | f52b4de3b29ee27213b7d60fb620a0d5d50b49f9 (patch) | |
| tree | d4570c53045bca8e9411e884b0905d9384430a58 | |
| parent | f5581786a1891cedb165adb1afe71fe34f26e030 (diff) | |
Allow `no_diff` modifier on parameters (#2538)
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 26 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 35 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 21 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 4 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-param.slang | 23 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-param.slang.expected.txt | 5 |
21 files changed, 209 insertions, 25 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index d8fec88ce..623a9161b 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -408,6 +408,11 @@ Val* ASTBuilder::getSNormModifierVal() return getOrCreate<SNormModifierVal>(); } +Val* ASTBuilder::getNoDiffModifierVal() +{ + return getOrCreate<NoDiffModifierVal>(); +} + TypeType* ASTBuilder::getTypeType(Type* type) { return getOrCreate<TypeType>(type); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index f011feae8..bdc03dda5 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -358,6 +358,7 @@ public: } Val* getUNormModifierVal(); Val* getSNormModifierVal(); + Val* getNoDiffModifierVal(); TypeType* getTypeType(Type* type); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index f9a3fc393..2adbcf6c6 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1210,6 +1210,10 @@ class SNormModifier : public ResourceElementFormatModifier SLANG_AST_CLASS(SNormModifier) }; +class NoDiffModifier : public TypeModifier +{ + SLANG_AST_CLASS(NoDiffModifier) +}; } // namespace Slang diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index f19b71b56..0e7614dd6 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -850,6 +850,15 @@ class ModifiedType : public Type Type* base; List<Val*> modifiers; + template<typename T> + T* findModifier() + { + for (auto v : modifiers) + if (auto rs = as<T>(v)) + return rs; + return nullptr; + } + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); bool _equalsImplOverride(Type* type); diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index a0f0552c6..e60c963a8 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -780,6 +780,20 @@ Val* SNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut return this; } +// NoDiffModifierVal +void NoDiffModifierVal::_toTextOverride(StringBuilder& out) +{ + out.append("no_diff"); +} + +Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + SLANG_UNUSED(astBuilder); + SLANG_UNUSED(subst); + SLANG_UNUSED(ioDiff); + return this; +} + // PolynomialIntVal bool PolynomialIntVal::_equalsValOverride(Val* val) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 31b74a499..503d63a76 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -490,6 +490,14 @@ class SNormModifierVal : public ResourceFormatModifierVal Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +class NoDiffModifierVal : public TypeModifierVal +{ + SLANG_AST_CLASS(NoDiffModifierVal) + + void _toTextOverride(StringBuilder& out); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); +}; + /// Represents the result of differentiating a function. class DifferentiateVal : public Val { diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 5e84e170b..c6daf5e86 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -579,6 +579,7 @@ namespace Slang case ASTNodeType::UNormModifierVal: case ASTNodeType::SNormModifierVal: + case ASTNodeType::NoDiffModifierVal: return true; } } @@ -597,6 +598,7 @@ namespace Slang case ASTNodeType::UNormModifierVal: case ASTNodeType::SNormModifierVal: + case ASTNodeType::NoDiffModifierVal: return true; } } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 1f0e1a2dc..4b2d490b7 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2051,6 +2051,12 @@ namespace Slang Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) { + if (auto modifiedType = as<ModifiedType>(primalType)) + { + if (modifiedType->findModifier<NoDiffModifierVal>()) + return modifiedType->base; + } + // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = m_astBuilder->getDifferentiableInterface(); @@ -3386,6 +3392,10 @@ namespace Slang // TODO: validate that `type` is either `float` or a vector of `float`s return m_astBuilder->getSNormModifierVal(); } + else if (auto noDiffModifier = as<NoDiffModifier>(modifier)) + { + return m_astBuilder->getNoDiffModifierVal(); + } else { // TODO: more complete error message here diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index a84e40768..3a64f3c8f 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -583,6 +583,17 @@ namespace Slang return varDecl->getName(); } + Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& paramDeclRef) + { + auto paramType = getType(astBuilder, paramDeclRef); + if (paramDeclRef.getDecl()->findModifier<NoDiffModifier>()) + { + auto modifierVal = static_cast<Val*>(astBuilder->getOrCreate<NoDiffModifierVal>()); + paramType = astBuilder->getModifiedType(paramType, 1, &modifierVal); + } + return paramType; + } + void Module::_collectShaderParams() { auto moduleDecl = m_moduleDecl; diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 0ad9ce87c..c9b186c8a 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -136,14 +136,14 @@ IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* b newParameterTypes.add(origType); } - // Transcribe return type to a pair. - // This will be void if the primal return type is non-differentiable. - // - auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType()); - if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) - diffReturnType = returnPairType; - else - diffReturnType = origResultType; + // Transcribe return type to a pair. + // This will be void if the primal return type is non-differentiable. + // + auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType()); + if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) + diffReturnType = returnPairType; + else + diffReturnType = origResultType; return builder->getFuncType(newParameterTypes, diffReturnType); } @@ -354,9 +354,14 @@ InstPair ForwardDerivativeTranscriber::transcribeParam(IRBuilder* builder, IRPar } } - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); + auto primalInst = cloneInst(&cloneEnv, builder, origParam); + if (auto primalParam = as<IRParam>(primalInst)) + { + SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); + primalParam->removeFromParent(); + builder->getInsertLoc().getBlock()->addParam(primalParam); + } + return InstPair(primalInst, nullptr); } else { @@ -368,7 +373,6 @@ InstPair ForwardDerivativeTranscriber::transcribeParam(IRBuilder* builder, IRPar } return InstPair(primal, diff); } - } // Returns "d<var-name>" to use as a name hint for variables and parameters. diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 522c995b0..daf45e1ef 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -90,17 +90,33 @@ struct BackwardDiffTranscriber for (UIndex i = 0; i < funcType->getParamCount(); i++) { + bool noDiff = false; auto origType = funcType->getParamType(i); - if (auto diffPairType = tryGetDiffPairType(builder, origType)) + if (auto attrType = as<IRAttributedType>(origType)) { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - newParameterTypes.add(inoutDiffPairType); + if (attrType->findAttr<IRNoDiffAttr>()) + { + noDiff = true; + origType = attrType->getBaseType(); + } } - else + if (noDiff) + { newParameterTypes.add(origType); + } + else + { + if (auto diffPairType = tryGetDiffPairType(builder, origType)) + { + auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); + newParameterTypes.add(inoutDiffPairType); + } + else + newParameterTypes.add(origType); + } } - newParameterTypes.add(funcType->getResultType()); + newParameterTypes.add(differentiateType(builder, funcType->getResultType())); diffReturnType = builder->getVoidType(); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index b0dbf62fa..4373cf44b 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -413,6 +413,36 @@ void stripAutoDiffDecorations(IRModule* module) stripAutoDiffDecorationsFromChildren(module->getModuleInst()); } +struct StripNoDiffTypeAttributePass : InstPassBase +{ + StripNoDiffTypeAttributePass(IRModule* module) : + InstPassBase(module) + { + } + void processModule() + { + processInstsOfType<IRAttributedType>(kIROp_AttributedType, [&](IRAttributedType* attrType) + { + if (attrType->getAllAttrs().getCount() == 1) + { + if (attrType->findAttr<IRNoDiffAttr>()) + { + attrType->replaceUsesWith(attrType->getBaseType()); + attrType->removeAndDeallocate(); + } + } + }); + sharedBuilderStorage.init(module); + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + } +}; + +void stripNoDiffTypeAttribute(IRModule* module) +{ + StripNoDiffTypeAttributePass pass(module); + pass.processModule(); +} + bool processAutodiffCalls( IRModule* module, DiagnosticSink* sink, @@ -452,11 +482,14 @@ bool processAutodiffCalls( // modified |= processPairTypes(&autodiffContext); + stripNoDiffTypeAttribute(module); + // Remove auto-diff related decorations. stripAutoDiffDecorations(module); + return modified; } -}
\ No newline at end of file +} diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 4aca291f9..c07200715 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -826,6 +826,8 @@ INST_RANGE(Layout, VarLayout, EntryPointLayout) INST(CaseTypeLayoutAttr, caseLayout, 1, 0) INST(UNormAttr, unorm, 0, 0) INST(SNormAttr, snorm, 0, 0) + INST(NoDiffAttr, no_diff, 0, 0) + /* SemanticAttr */ INST(UserSemanticAttr, userSemantic, 2, 0) INST(SystemValueSemanticAttr, systemValueSemantic, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a1249aff9..c45d187f4 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -906,6 +906,11 @@ struct IRFuncThrowTypeAttr : IRAttr IRType* getErrorType() { return (IRType*)getOperand(0); } }; +struct IRNoDiffAttr : IRAttr +{ + IR_LEAF_ISA(NoDiffAttr) +}; + /// An attribute that specifies size information for a single resource kind. struct IRTypeSizeAttr : public IRLayoutResourceInfoAttr { diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 0909615af..36fab6da1 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -716,20 +716,35 @@ struct IRInst void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent); }; +inline bool isModifierInst(IROp op) +{ + switch (op) + { + case kIROp_AttributedType: + return true; + } + return false; +} + template<typename T> T* dynamicCast(IRInst* inst) { if (inst && T::isaImpl(inst->getOp())) return static_cast<T*>(inst); + if (inst) + { + if (isModifierInst(inst->getOp())) + { + return dynamicCast<T>(inst->getOperand(0)); + } + } return nullptr; } template<typename T> const T* dynamicCast(const IRInst* inst) { - if (inst && T::isaImpl(inst->getOp())) - return static_cast<const T*>(inst); - return nullptr; + return dynamicCast<T>(const_cast<IRInst*>(inst)); } // `dynamic_cast` equivalent (we just use dynamicCast) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4db9a479b..28639ae53 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2075,6 +2075,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(getBuilder()->getAttr(kIROp_SNormAttr)); } + LoweredValInfo visitNoDiffModifierVal(NoDiffModifierVal* astVal) + { + SLANG_UNUSED(astVal); + return LoweredValInfo::simple(getBuilder()->getAttr(kIROp_NoDiffAttr)); + } + // We do not expect to encounter the following types in ASTs that have // passed front-end semantic checking. #define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); } @@ -2783,7 +2789,7 @@ IRLoweringParameterInfo getParameterInfo( { IRLoweringParameterInfo info; - info.type = getType(context->astBuilder, paramDecl); + info.type = getParamType(context->astBuilder, paramDecl); info.decl = paramDecl; info.direction = getParameterDirection(paramDecl); info.isThisParam = false; diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index ab849a98b..fd0810214 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1122,6 +1122,14 @@ namespace Slang AddModifier(&modifierLink, parsedModifier); continue; } + else if (AdvanceIf(parser, "no_diff")) + { + parsedModifier = parser->astBuilder->create<NoDiffModifier>(); + parsedModifier->keywordName = nameToken.getName(); + parsedModifier->loc = nameToken.loc; + AddModifier(&modifierLink, parsedModifier); + continue; + } // If there was no match for a modifier keyword, then we // must be at the end of the modifier sequence @@ -1459,7 +1467,7 @@ namespace Slang // Allow a declaration to use the keyword `void` for a parameter list, // since that was required in ancient C, and continues to be supported - // in a bunc hof its derivatives even if it is a Bad Design Choice + // in a bunch of its derivatives even if it is a Bad Design Choice // // TODO: conditionalize this so we don't keep this around for "pure" // Slang code diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 4f05bc936..a79c48227 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1120,7 +1120,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return astBuilder->create<NamedExpressionType>(specializedDeclRef); } - FuncType* getFuncType( ASTBuilder* astBuilder, @@ -1133,7 +1132,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt for (auto paramDeclRef : getParameters(declRef)) { auto paramDecl = paramDeclRef.getDecl(); - auto paramType = getType(astBuilder, paramDeclRef); + auto paramType = getParamType(astBuilder, paramDeclRef); if( paramDecl->findModifier<RefModifier>() ) { paramType = astBuilder->getRefType(paramType); diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 2ceb7a9fd..441dcb8e7 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -106,6 +106,10 @@ namespace Slang return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } + /// same as getType, but take into account the additional type modifiers from the parameter's modifier list + /// and return a ModifiedType if such modifiers exist. + Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& paramDeclRef); + inline SubstExpr<Expr> getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->initExpr); diff --git a/tests/autodiff/no-diff-param.slang b/tests/autodiff/no-diff-param.slang new file mode 100644 index 000000000..b7c754889 --- /dev/null +++ b/tests/autodiff/no-diff-param.slang @@ -0,0 +1,23 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +[ForwardDifferentiable] +float f(float x, no_diff float y) +{ + return x * x + y * y; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + let rs = __fwd_diff(f)(dpfloat(1.5, 1.0), 2.0); + outputBuffer[0] = rs.p; // Expect: 6.25 + outputBuffer[1] = rs.d; // Expect: 3.0 + } +} diff --git a/tests/autodiff/no-diff-param.slang.expected.txt b/tests/autodiff/no-diff-param.slang.expected.txt new file mode 100644 index 000000000..b4bbbf1d4 --- /dev/null +++ b/tests/autodiff/no-diff-param.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +6.250000 +3.000000 +0.000000 +0.000000
\ No newline at end of file |
