summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-decl.cpp44
-rw-r--r--source/slang/slang-syntax.cpp14
-rw-r--r--tests/autodiff/subscript.slang33
3 files changed, 73 insertions, 18 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index e2af70fa9..31e7b70bf 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -143,6 +143,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase,
void visitAssocTypeDecl(AssocTypeDecl* decl);
+ void checkDifferentiableCallableCommon(CallableDecl* decl);
+
void checkCallableDeclCommon(CallableDecl* decl);
void visitFuncDecl(FuncDecl* funcDecl);
@@ -9109,24 +9111,8 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(
}
}
-void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
+void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* decl)
{
- for (auto paramDecl : decl->getParameters())
- {
- ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
- }
-
- auto errorType = decl->errorType;
- if (errorType.exp)
- {
- errorType = CheckProperType(errorType);
- }
- else
- {
- errorType = TypeExp(m_astBuilder->getBottomType());
- }
- decl->errorType = errorType;
-
if (auto interfaceDecl = findParentInterfaceDecl(decl))
{
bool isDiffFunc = false;
@@ -9248,6 +9234,27 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
}
}
}
+}
+
+void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
+{
+ for (auto paramDecl : decl->getParameters())
+ {
+ ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
+ }
+
+ auto errorType = decl->errorType;
+ if (errorType.exp)
+ {
+ errorType = CheckProperType(errorType);
+ }
+ else
+ {
+ errorType = TypeExp(m_astBuilder->getBottomType());
+ }
+ decl->errorType = errorType;
+
+ checkDifferentiableCallableCommon(decl);
// If this method is intended to be a CUDA kernel, verify that the return type is void.
if (decl->findModifier<CudaKernelAttribute>())
@@ -9709,6 +9716,8 @@ void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl)
// for `GetterDecl`s.
//
decl->returnType.type = _getAccessorStorageType(decl);
+
+ checkDifferentiableCallableCommon(decl);
}
void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl)
@@ -9799,6 +9808,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl)
newValueType);
}
}
+ checkDifferentiableCallableCommon(decl);
}
GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl)
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 9e448a5fa..1d3763299 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -851,7 +851,7 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR
List<Type*> paramTypes;
auto resultType = getResultType(astBuilder, declRef);
auto errorType = getErrorCodeType(astBuilder, declRef);
- for (auto paramDeclRef : getParameters(astBuilder, declRef))
+ auto visitParamDecl = [&](DeclRef<ParamDecl> paramDeclRef)
{
auto paramDecl = paramDeclRef.getDecl();
auto paramType = getParamType(astBuilder, paramDeclRef);
@@ -875,6 +875,18 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR
}
}
paramTypes.add(paramType);
+ };
+ auto parent = declRef.getParent();
+ if (as<SubscriptDecl>(parent) || as<PropertyDecl>(parent))
+ {
+ for (auto paramDeclRef : getParameters(astBuilder, parent.as<CallableDecl>()))
+ {
+ visitParamDecl(paramDeclRef);
+ }
+ }
+ for (auto paramDeclRef : getParameters(astBuilder, declRef))
+ {
+ visitParamDecl(paramDeclRef);
}
FuncType* funcType =
diff --git a/tests/autodiff/subscript.slang b/tests/autodiff/subscript.slang
new file mode 100644
index 000000000..2b16597c0
--- /dev/null
+++ b/tests/autodiff/subscript.slang
@@ -0,0 +1,33 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHK): -output-using-type
+
+interface ITest
+{
+ __subscript(int i) -> float
+ {
+ [BackwardDifferentiable] get;
+ }
+}
+struct Test : ITest
+{
+ __subscript(int i) -> float
+ {
+ [BackwardDifferentiable] get { return 5.0f * i; }
+ }
+}
+
+[Differentiable]
+float test(ITest arg)
+{
+ return arg[1];
+}
+
+//TEST_INPUT:set output = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float> output;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ Test t = {};
+ output[0] = test(t);
+ // CHK: 5.0
+} \ No newline at end of file