From e93cb8a4d1bb7d835bc3762ce25ced422e75e97a Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 20 Dec 2024 00:53:49 -0800 Subject: Check subscript/property accessor for differentiability. (#5922) --- source/slang/slang-check-decl.cpp | 44 ++++++++++++++++++++++++--------------- source/slang/slang-syntax.cpp | 14 ++++++++++++- tests/autodiff/subscript.slang | 33 +++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 18 deletions(-) create mode 100644 tests/autodiff/subscript.slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e2af70fa9..31e7b70bf 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -143,6 +143,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase, void visitAssocTypeDecl(AssocTypeDecl* decl); + void checkDifferentiableCallableCommon(CallableDecl* decl); + void checkCallableDeclCommon(CallableDecl* decl); void visitFuncDecl(FuncDecl* funcDecl); @@ -9109,24 +9111,8 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl( } } -void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) +void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* decl) { - for (auto paramDecl : decl->getParameters()) - { - ensureDecl(paramDecl, DeclCheckState::ReadyForReference); - } - - auto errorType = decl->errorType; - if (errorType.exp) - { - errorType = CheckProperType(errorType); - } - else - { - errorType = TypeExp(m_astBuilder->getBottomType()); - } - decl->errorType = errorType; - if (auto interfaceDecl = findParentInterfaceDecl(decl)) { bool isDiffFunc = false; @@ -9248,6 +9234,27 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) } } } +} + +void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) +{ + for (auto paramDecl : decl->getParameters()) + { + ensureDecl(paramDecl, DeclCheckState::ReadyForReference); + } + + auto errorType = decl->errorType; + if (errorType.exp) + { + errorType = CheckProperType(errorType); + } + else + { + errorType = TypeExp(m_astBuilder->getBottomType()); + } + decl->errorType = errorType; + + checkDifferentiableCallableCommon(decl); // If this method is intended to be a CUDA kernel, verify that the return type is void. if (decl->findModifier()) @@ -9709,6 +9716,8 @@ void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl) // for `GetterDecl`s. // decl->returnType.type = _getAccessorStorageType(decl); + + checkDifferentiableCallableCommon(decl); } void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) @@ -9799,6 +9808,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) newValueType); } } + checkDifferentiableCallableCommon(decl); } GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl) diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 9e448a5fa..1d3763299 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -851,7 +851,7 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef const& declR List paramTypes; auto resultType = getResultType(astBuilder, declRef); auto errorType = getErrorCodeType(astBuilder, declRef); - for (auto paramDeclRef : getParameters(astBuilder, declRef)) + auto visitParamDecl = [&](DeclRef paramDeclRef) { auto paramDecl = paramDeclRef.getDecl(); auto paramType = getParamType(astBuilder, paramDeclRef); @@ -875,6 +875,18 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef const& declR } } paramTypes.add(paramType); + }; + auto parent = declRef.getParent(); + if (as(parent) || as(parent)) + { + for (auto paramDeclRef : getParameters(astBuilder, parent.as())) + { + visitParamDecl(paramDeclRef); + } + } + for (auto paramDeclRef : getParameters(astBuilder, declRef)) + { + visitParamDecl(paramDeclRef); } FuncType* funcType = diff --git a/tests/autodiff/subscript.slang b/tests/autodiff/subscript.slang new file mode 100644 index 000000000..2b16597c0 --- /dev/null +++ b/tests/autodiff/subscript.slang @@ -0,0 +1,33 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHK): -output-using-type + +interface ITest +{ + __subscript(int i) -> float + { + [BackwardDifferentiable] get; + } +} +struct Test : ITest +{ + __subscript(int i) -> float + { + [BackwardDifferentiable] get { return 5.0f * i; } + } +} + +[Differentiable] +float test(ITest arg) +{ + return arg[1]; +} + +//TEST_INPUT:set output = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer output; + +[numthreads(1,1,1)] +void computeMain() +{ + Test t = {}; + output[0] = test(t); + // CHK: 5.0 +} \ No newline at end of file -- cgit v1.2.3