diff options
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 12 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-member.slang | 28 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-member.slang.expected.txt | 5 |
7 files changed, 71 insertions, 3 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 837dcb8eb..381efa2c7 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1291,6 +1291,16 @@ namespace Slang { checkExtensionExternVarAttribute(varDecl, extensionExternAttr); } + + // If a var decl has no_diff type, move the no_diff modifier from the type to the var. + if (auto modifiedType = as<ModifiedType>(varDecl->type.type)) + { + if (auto nodiffModifier = modifiedType->findModifier<NoDiffModifierVal>()) + { + varDecl->type.type = getRemovedModifierType(modifiedType, nodiffModifier); + addModifier(varDecl, m_astBuilder->getOrCreate<NoDiffModifier>()); + } + } } void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl) @@ -1527,6 +1537,8 @@ namespace Slang // Go through all var members. for (auto member : context->parentDecl->getMembersOfType<VarDeclBase>()) { + if (member->hasModifier<NoDiffModifier>()) + continue; auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); if (!diffType) continue; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 719706635..ccc739da3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -508,7 +508,7 @@ namespace Slang Type* TranslateTypeNode(Expr* node); TypeExp TranslateTypeNodeForced(TypeExp const& typeExp); TypeExp TranslateTypeNode(TypeExp const& typeExp); - + Type* getRemovedModifierType(ModifiedType* type, ModifierVal* modifier); DeclRefType* getExprDeclRefType(Expr * expr); /// Is `decl` usable as a static member? diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index d402dde03..1b2179144 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -84,6 +84,15 @@ namespace Slang return TranslateTypeNodeForced(typeExp); } + Type* SemanticsVisitor::getRemovedModifierType(ModifiedType* modifiedType, ModifierVal* modifier) + { + if (modifiedType->modifiers.getCount() == 1) + return modifiedType->base; + auto newModifiers = modifiedType->modifiers; + newModifiers.remove(modifier); + return m_astBuilder->getModifiedType(modifiedType->base, newModifiers); + } + Expr* SemanticsVisitor::ExpectATypeRepr(Expr* expr) { if (auto overloadedExpr = as<OverloadedExpr>(expr)) diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index c3af52d8a..eb8b09417 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -503,7 +503,8 @@ struct DiffUnzipPass if (as<IRParam>(inst)) { SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); - builder->setInsertBefore(as<IRBlock>(inst->getParent())->getFirstOrdinaryInst()); + auto lastParam = as<IRBlock>(inst->getParent())->getLastParam(); + builder->setInsertAfter(lastParam); } else { @@ -516,7 +517,8 @@ struct DiffUnzipPass if (as<IRParam>(inst)) { SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); - builder->setInsertBefore(as<IRBlock>(inst->getParent())->getFirstOrdinaryInst()); + auto lastParam = as<IRBlock>(inst->getParent())->getLastParam(); + builder->setInsertAfter(lastParam); } else { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f78dd39e5..aa2dc4efb 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3320,6 +3320,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo getSimpleDefaultVal(IRType* type) { + type = (IRType*)unwrapAttributedType(type); if(auto basicType = as<IRBasicType>(type)) { switch( basicType->getBaseType() ) @@ -3355,8 +3356,18 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + Type* getOriginalTypeFromModifiedType(Type* type) + { + auto innerType = type; + while (auto modifiedType = as<ModifiedType>(innerType)) + innerType = modifiedType->base; + return innerType; + } + LoweredValInfo getDefaultVal(Type* type) { + type = getOriginalTypeFromModifiedType(type); + auto irType = lowerType(context, type); if (auto basicType = as<BasicExpressionType>(type)) { @@ -7909,6 +7920,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> bool isClassType(IRType* type) { + type = (IRType*)unwrapAttributedType(type); if (auto specialize = as<IRSpecialize>(type)) { return findSpecializeReturnVal(specialize)->getOp() == kIROp_ClassType; diff --git a/tests/autodiff/no-diff-member.slang b/tests/autodiff/no-diff-member.slang new file mode 100644 index 000000000..a91e84bf4 --- /dev/null +++ b/tests/autodiff/no-diff-member.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +struct A : IDifferentiable +{ + float x; + no_diff float y; +} + +[BackwardDifferentiable] +float f(A obj) +{ + return obj.y * obj.x * obj.x + obj.y * obj.y; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = {2.0, 3.0}; + var p = diffPair(a); + let rs = __bwd_diff(f)(p, 1.0); + outputBuffer[0] = p.d.x; +} diff --git a/tests/autodiff/no-diff-member.slang.expected.txt b/tests/autodiff/no-diff-member.slang.expected.txt new file mode 100644 index 000000000..31af8a224 --- /dev/null +++ b/tests/autodiff/no-diff-member.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +12.000000 +0.000000 +0.000000 +0.000000
\ No newline at end of file |
