diff options
| author | Yong He <yonghe@outlook.com> | 2024-12-20 00:53:49 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-20 00:53:49 -0800 |
| commit | e93cb8a4d1bb7d835bc3762ce25ced422e75e97a (patch) | |
| tree | 47b482d99fdbb12f8bc3940d7cbf39444c47c671 | |
| parent | 5c9f011fa4948d1f70689b03ddcd203cb2525b2b (diff) | |
Check subscript/property accessor for differentiability. (#5922)
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 14 | ||||
| -rw-r--r-- | tests/autodiff/subscript.slang | 33 |
3 files changed, 73 insertions, 18 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e2af70fa9..31e7b70bf 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -143,6 +143,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase, void visitAssocTypeDecl(AssocTypeDecl* decl); + void checkDifferentiableCallableCommon(CallableDecl* decl); + void checkCallableDeclCommon(CallableDecl* decl); void visitFuncDecl(FuncDecl* funcDecl); @@ -9109,24 +9111,8 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl( } } -void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) +void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* decl) { - for (auto paramDecl : decl->getParameters()) - { - ensureDecl(paramDecl, DeclCheckState::ReadyForReference); - } - - auto errorType = decl->errorType; - if (errorType.exp) - { - errorType = CheckProperType(errorType); - } - else - { - errorType = TypeExp(m_astBuilder->getBottomType()); - } - decl->errorType = errorType; - if (auto interfaceDecl = findParentInterfaceDecl(decl)) { bool isDiffFunc = false; @@ -9248,6 +9234,27 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) } } } +} + +void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) +{ + for (auto paramDecl : decl->getParameters()) + { + ensureDecl(paramDecl, DeclCheckState::ReadyForReference); + } + + auto errorType = decl->errorType; + if (errorType.exp) + { + errorType = CheckProperType(errorType); + } + else + { + errorType = TypeExp(m_astBuilder->getBottomType()); + } + decl->errorType = errorType; + + checkDifferentiableCallableCommon(decl); // If this method is intended to be a CUDA kernel, verify that the return type is void. if (decl->findModifier<CudaKernelAttribute>()) @@ -9709,6 +9716,8 @@ void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl) // for `GetterDecl`s. // decl->returnType.type = _getAccessorStorageType(decl); + + checkDifferentiableCallableCommon(decl); } void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) @@ -9799,6 +9808,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) newValueType); } } + checkDifferentiableCallableCommon(decl); } GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl) diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 9e448a5fa..1d3763299 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -851,7 +851,7 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR List<Type*> paramTypes; auto resultType = getResultType(astBuilder, declRef); auto errorType = getErrorCodeType(astBuilder, declRef); - for (auto paramDeclRef : getParameters(astBuilder, declRef)) + auto visitParamDecl = [&](DeclRef<ParamDecl> paramDeclRef) { auto paramDecl = paramDeclRef.getDecl(); auto paramType = getParamType(astBuilder, paramDeclRef); @@ -875,6 +875,18 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR } } paramTypes.add(paramType); + }; + auto parent = declRef.getParent(); + if (as<SubscriptDecl>(parent) || as<PropertyDecl>(parent)) + { + for (auto paramDeclRef : getParameters(astBuilder, parent.as<CallableDecl>())) + { + visitParamDecl(paramDeclRef); + } + } + for (auto paramDeclRef : getParameters(astBuilder, declRef)) + { + visitParamDecl(paramDeclRef); } FuncType* funcType = diff --git a/tests/autodiff/subscript.slang b/tests/autodiff/subscript.slang new file mode 100644 index 000000000..2b16597c0 --- /dev/null +++ b/tests/autodiff/subscript.slang @@ -0,0 +1,33 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHK): -output-using-type + +interface ITest +{ + __subscript(int i) -> float + { + [BackwardDifferentiable] get; + } +} +struct Test : ITest +{ + __subscript(int i) -> float + { + [BackwardDifferentiable] get { return 5.0f * i; } + } +} + +[Differentiable] +float test(ITest arg) +{ + return arg[1]; +} + +//TEST_INPUT:set output = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<float> output; + +[numthreads(1,1,1)] +void computeMain() +{ + Test t = {}; + output[0] = test(t); + // CHK: 5.0 +}
\ No newline at end of file |
