summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-29 18:17:33 -0800
committerGitHub <noreply@github.com>2022-11-29 18:17:33 -0800
commitf52b4de3b29ee27213b7d60fb620a0d5d50b49f9 (patch)
treed4570c53045bca8e9411e884b0905d9384430a58
parentf5581786a1891cedb165adb1afe71fe34f26e030 (diff)
Allow `no_diff` modifier on parameters (#2538)
-rw-r--r--source/slang/slang-ast-builder.cpp5
-rw-r--r--source/slang/slang-ast-builder.h1
-rw-r--r--source/slang/slang-ast-modifier.h4
-rw-r--r--source/slang/slang-ast-type.h9
-rw-r--r--source/slang/slang-ast-val.cpp14
-rw-r--r--source/slang/slang-ast-val.h8
-rw-r--r--source/slang/slang-check-conversion.cpp2
-rw-r--r--source/slang/slang-check-expr.cpp10
-rw-r--r--source/slang/slang-check-shader.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp28
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp26
-rw-r--r--source/slang/slang-ir-autodiff.cpp35
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h5
-rw-r--r--source/slang/slang-ir.h21
-rw-r--r--source/slang/slang-lower-to-ir.cpp8
-rw-r--r--source/slang/slang-parser.cpp10
-rw-r--r--source/slang/slang-syntax.cpp3
-rw-r--r--source/slang/slang-syntax.h4
-rw-r--r--tests/autodiff/no-diff-param.slang23
-rw-r--r--tests/autodiff/no-diff-param.slang.expected.txt5
21 files changed, 209 insertions, 25 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index d8fec88ce..623a9161b 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -408,6 +408,11 @@ Val* ASTBuilder::getSNormModifierVal()
return getOrCreate<SNormModifierVal>();
}
+Val* ASTBuilder::getNoDiffModifierVal()
+{
+ return getOrCreate<NoDiffModifierVal>();
+}
+
TypeType* ASTBuilder::getTypeType(Type* type)
{
return getOrCreate<TypeType>(type);
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index f011feae8..bdc03dda5 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -358,6 +358,7 @@ public:
}
Val* getUNormModifierVal();
Val* getSNormModifierVal();
+ Val* getNoDiffModifierVal();
TypeType* getTypeType(Type* type);
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index f9a3fc393..2adbcf6c6 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1210,6 +1210,10 @@ class SNormModifier : public ResourceElementFormatModifier
SLANG_AST_CLASS(SNormModifier)
};
+class NoDiffModifier : public TypeModifier
+{
+ SLANG_AST_CLASS(NoDiffModifier)
+};
} // namespace Slang
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index f19b71b56..0e7614dd6 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -850,6 +850,15 @@ class ModifiedType : public Type
Type* base;
List<Val*> modifiers;
+ template<typename T>
+ T* findModifier()
+ {
+ for (auto v : modifiers)
+ if (auto rs = as<T>(v))
+ return rs;
+ return nullptr;
+ }
+
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
bool _equalsImplOverride(Type* type);
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index a0f0552c6..e60c963a8 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -780,6 +780,20 @@ Val* SNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
return this;
}
+// NoDiffModifierVal
+void NoDiffModifierVal::_toTextOverride(StringBuilder& out)
+{
+ out.append("no_diff");
+}
+
+Val* NoDiffModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ SLANG_UNUSED(astBuilder);
+ SLANG_UNUSED(subst);
+ SLANG_UNUSED(ioDiff);
+ return this;
+}
+
// PolynomialIntVal
bool PolynomialIntVal::_equalsValOverride(Val* val)
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 31b74a499..503d63a76 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -490,6 +490,14 @@ class SNormModifierVal : public ResourceFormatModifierVal
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
+class NoDiffModifierVal : public TypeModifierVal
+{
+ SLANG_AST_CLASS(NoDiffModifierVal)
+
+ void _toTextOverride(StringBuilder& out);
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+};
+
/// Represents the result of differentiating a function.
class DifferentiateVal : public Val
{
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 5e84e170b..c6daf5e86 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -579,6 +579,7 @@ namespace Slang
case ASTNodeType::UNormModifierVal:
case ASTNodeType::SNormModifierVal:
+ case ASTNodeType::NoDiffModifierVal:
return true;
}
}
@@ -597,6 +598,7 @@ namespace Slang
case ASTNodeType::UNormModifierVal:
case ASTNodeType::SNormModifierVal:
+ case ASTNodeType::NoDiffModifierVal:
return true;
}
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 1f0e1a2dc..4b2d490b7 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2051,6 +2051,12 @@ namespace Slang
Type* SemanticsVisitor::getDifferentialPairType(Type* primalType)
{
+ if (auto modifiedType = as<ModifiedType>(primalType))
+ {
+ if (modifiedType->findModifier<NoDiffModifierVal>())
+ return modifiedType->base;
+ }
+
// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = m_astBuilder->getDifferentiableInterface();
@@ -3386,6 +3392,10 @@ namespace Slang
// TODO: validate that `type` is either `float` or a vector of `float`s
return m_astBuilder->getSNormModifierVal();
}
+ else if (auto noDiffModifier = as<NoDiffModifier>(modifier))
+ {
+ return m_astBuilder->getNoDiffModifierVal();
+ }
else
{
// TODO: more complete error message here
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index a84e40768..3a64f3c8f 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -583,6 +583,17 @@ namespace Slang
return varDecl->getName();
}
+ Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& paramDeclRef)
+ {
+ auto paramType = getType(astBuilder, paramDeclRef);
+ if (paramDeclRef.getDecl()->findModifier<NoDiffModifier>())
+ {
+ auto modifierVal = static_cast<Val*>(astBuilder->getOrCreate<NoDiffModifierVal>());
+ paramType = astBuilder->getModifiedType(paramType, 1, &modifierVal);
+ }
+ return paramType;
+ }
+
void Module::_collectShaderParams()
{
auto moduleDecl = m_moduleDecl;
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 0ad9ce87c..c9b186c8a 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -136,14 +136,14 @@ IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* b
newParameterTypes.add(origType);
}
- // Transcribe return type to a pair.
- // This will be void if the primal return type is non-differentiable.
- //
- auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType());
- if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
- diffReturnType = returnPairType;
- else
- diffReturnType = origResultType;
+ // Transcribe return type to a pair.
+ // This will be void if the primal return type is non-differentiable.
+ //
+ auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType());
+ if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
+ diffReturnType = returnPairType;
+ else
+ diffReturnType = origResultType;
return builder->getFuncType(newParameterTypes, diffReturnType);
}
@@ -354,9 +354,14 @@ InstPair ForwardDerivativeTranscriber::transcribeParam(IRBuilder* builder, IRPar
}
}
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
+ auto primalInst = cloneInst(&cloneEnv, builder, origParam);
+ if (auto primalParam = as<IRParam>(primalInst))
+ {
+ SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
+ primalParam->removeFromParent();
+ builder->getInsertLoc().getBlock()->addParam(primalParam);
+ }
+ return InstPair(primalInst, nullptr);
}
else
{
@@ -368,7 +373,6 @@ InstPair ForwardDerivativeTranscriber::transcribeParam(IRBuilder* builder, IRPar
}
return InstPair(primal, diff);
}
-
}
// Returns "d<var-name>" to use as a name hint for variables and parameters.
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 522c995b0..daf45e1ef 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -90,17 +90,33 @@ struct BackwardDiffTranscriber
for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
+ bool noDiff = false;
auto origType = funcType->getParamType(i);
- if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ if (auto attrType = as<IRAttributedType>(origType))
{
- auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
- newParameterTypes.add(inoutDiffPairType);
+ if (attrType->findAttr<IRNoDiffAttr>())
+ {
+ noDiff = true;
+ origType = attrType->getBaseType();
+ }
}
- else
+ if (noDiff)
+ {
newParameterTypes.add(origType);
+ }
+ else
+ {
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ newParameterTypes.add(inoutDiffPairType);
+ }
+ else
+ newParameterTypes.add(origType);
+ }
}
- newParameterTypes.add(funcType->getResultType());
+ newParameterTypes.add(differentiateType(builder, funcType->getResultType()));
diffReturnType = builder->getVoidType();
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index b0dbf62fa..4373cf44b 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -413,6 +413,36 @@ void stripAutoDiffDecorations(IRModule* module)
stripAutoDiffDecorationsFromChildren(module->getModuleInst());
}
+struct StripNoDiffTypeAttributePass : InstPassBase
+{
+ StripNoDiffTypeAttributePass(IRModule* module) :
+ InstPassBase(module)
+ {
+ }
+ void processModule()
+ {
+ processInstsOfType<IRAttributedType>(kIROp_AttributedType, [&](IRAttributedType* attrType)
+ {
+ if (attrType->getAllAttrs().getCount() == 1)
+ {
+ if (attrType->findAttr<IRNoDiffAttr>())
+ {
+ attrType->replaceUsesWith(attrType->getBaseType());
+ attrType->removeAndDeallocate();
+ }
+ }
+ });
+ sharedBuilderStorage.init(module);
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ }
+};
+
+void stripNoDiffTypeAttribute(IRModule* module)
+{
+ StripNoDiffTypeAttributePass pass(module);
+ pass.processModule();
+}
+
bool processAutodiffCalls(
IRModule* module,
DiagnosticSink* sink,
@@ -452,11 +482,14 @@ bool processAutodiffCalls(
//
modified |= processPairTypes(&autodiffContext);
+ stripNoDiffTypeAttribute(module);
+
// Remove auto-diff related decorations.
stripAutoDiffDecorations(module);
+
return modified;
}
-} \ No newline at end of file
+}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 4aca291f9..c07200715 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -826,6 +826,8 @@ INST_RANGE(Layout, VarLayout, EntryPointLayout)
INST(CaseTypeLayoutAttr, caseLayout, 1, 0)
INST(UNormAttr, unorm, 0, 0)
INST(SNormAttr, snorm, 0, 0)
+ INST(NoDiffAttr, no_diff, 0, 0)
+
/* SemanticAttr */
INST(UserSemanticAttr, userSemantic, 2, 0)
INST(SystemValueSemanticAttr, systemValueSemantic, 2, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index a1249aff9..c45d187f4 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -906,6 +906,11 @@ struct IRFuncThrowTypeAttr : IRAttr
IRType* getErrorType() { return (IRType*)getOperand(0); }
};
+struct IRNoDiffAttr : IRAttr
+{
+ IR_LEAF_ISA(NoDiffAttr)
+};
+
/// An attribute that specifies size information for a single resource kind.
struct IRTypeSizeAttr : public IRLayoutResourceInfoAttr
{
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 0909615af..36fab6da1 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -716,20 +716,35 @@ struct IRInst
void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent);
};
+inline bool isModifierInst(IROp op)
+{
+ switch (op)
+ {
+ case kIROp_AttributedType:
+ return true;
+ }
+ return false;
+}
+
template<typename T>
T* dynamicCast(IRInst* inst)
{
if (inst && T::isaImpl(inst->getOp()))
return static_cast<T*>(inst);
+ if (inst)
+ {
+ if (isModifierInst(inst->getOp()))
+ {
+ return dynamicCast<T>(inst->getOperand(0));
+ }
+ }
return nullptr;
}
template<typename T>
const T* dynamicCast(const IRInst* inst)
{
- if (inst && T::isaImpl(inst->getOp()))
- return static_cast<const T*>(inst);
- return nullptr;
+ return dynamicCast<T>(const_cast<IRInst*>(inst));
}
// `dynamic_cast` equivalent (we just use dynamicCast)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 4db9a479b..28639ae53 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2075,6 +2075,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(getBuilder()->getAttr(kIROp_SNormAttr));
}
+ LoweredValInfo visitNoDiffModifierVal(NoDiffModifierVal* astVal)
+ {
+ SLANG_UNUSED(astVal);
+ return LoweredValInfo::simple(getBuilder()->getAttr(kIROp_NoDiffAttr));
+ }
+
// We do not expect to encounter the following types in ASTs that have
// passed front-end semantic checking.
#define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); }
@@ -2783,7 +2789,7 @@ IRLoweringParameterInfo getParameterInfo(
{
IRLoweringParameterInfo info;
- info.type = getType(context->astBuilder, paramDecl);
+ info.type = getParamType(context->astBuilder, paramDecl);
info.decl = paramDecl;
info.direction = getParameterDirection(paramDecl);
info.isThisParam = false;
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index ab849a98b..fd0810214 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1122,6 +1122,14 @@ namespace Slang
AddModifier(&modifierLink, parsedModifier);
continue;
}
+ else if (AdvanceIf(parser, "no_diff"))
+ {
+ parsedModifier = parser->astBuilder->create<NoDiffModifier>();
+ parsedModifier->keywordName = nameToken.getName();
+ parsedModifier->loc = nameToken.loc;
+ AddModifier(&modifierLink, parsedModifier);
+ continue;
+ }
// If there was no match for a modifier keyword, then we
// must be at the end of the modifier sequence
@@ -1459,7 +1467,7 @@ namespace Slang
// Allow a declaration to use the keyword `void` for a parameter list,
// since that was required in ancient C, and continues to be supported
- // in a bunc hof its derivatives even if it is a Bad Design Choice
+ // in a bunch of its derivatives even if it is a Bad Design Choice
//
// TODO: conditionalize this so we don't keep this around for "pure"
// Slang code
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 4f05bc936..a79c48227 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -1120,7 +1120,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
return astBuilder->create<NamedExpressionType>(specializedDeclRef);
}
-
FuncType* getFuncType(
ASTBuilder* astBuilder,
@@ -1133,7 +1132,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
for (auto paramDeclRef : getParameters(declRef))
{
auto paramDecl = paramDeclRef.getDecl();
- auto paramType = getType(astBuilder, paramDeclRef);
+ auto paramType = getParamType(astBuilder, paramDeclRef);
if( paramDecl->findModifier<RefModifier>() )
{
paramType = astBuilder->getRefType(paramType);
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index 2ceb7a9fd..441dcb8e7 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -106,6 +106,10 @@ namespace Slang
return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr());
}
+ /// same as getType, but take into account the additional type modifiers from the parameter's modifier list
+ /// and return a ModifiedType if such modifiers exist.
+ Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& paramDeclRef);
+
inline SubstExpr<Expr> getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> const& declRef)
{
return declRef.substitute(astBuilder, declRef.getDecl()->initExpr);
diff --git a/tests/autodiff/no-diff-param.slang b/tests/autodiff/no-diff-param.slang
new file mode 100644
index 000000000..b7c754889
--- /dev/null
+++ b/tests/autodiff/no-diff-param.slang
@@ -0,0 +1,23 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+
+[ForwardDifferentiable]
+float f(float x, no_diff float y)
+{
+ return x * x + y * y;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ let rs = __fwd_diff(f)(dpfloat(1.5, 1.0), 2.0);
+ outputBuffer[0] = rs.p; // Expect: 6.25
+ outputBuffer[1] = rs.d; // Expect: 3.0
+ }
+}
diff --git a/tests/autodiff/no-diff-param.slang.expected.txt b/tests/autodiff/no-diff-param.slang.expected.txt
new file mode 100644
index 000000000..b4bbbf1d4
--- /dev/null
+++ b/tests/autodiff/no-diff-param.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+6.250000
+3.000000
+0.000000
+0.000000 \ No newline at end of file