From 3ff257816fc8f376d9bee76378a690757f8b5377 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:08:03 -0500 Subject: Fix interface requirement lowering for generic accessors (#6123) --- source/slang/slang-lower-to-ir.cpp | 10 ++++++++++ tests/autodiff/generic-accessors.slang | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 tests/autodiff/generic-accessors.slang diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 086345719..54540a3f8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8633,6 +8633,9 @@ struct DeclLoweringVisitor : DeclVisitor UInt operandCount = 0; for (auto requirementDecl : decl->members) { + if (as(requirementDecl)) + requirementDecl = getInner(requirementDecl); + if (as(requirementDecl) || as(requirementDecl)) { for (auto accessorDecl : as(requirementDecl)->members) @@ -8782,6 +8785,13 @@ struct DeclLoweringVisitor : DeclVisitor auto requirementKey = getInterfaceRequirementKey(requirementDecl); if (!requirementKey) { + if (auto genericDecl = as(requirementDecl)) + { + // We need to form a declref into the inner decls in case of a generic + // requirement. + requirementDecl = getInner(genericDecl); + } + if (as(requirementDecl) || as(requirementDecl)) { for (auto member : as(requirementDecl)->members) diff --git a/tests/autodiff/generic-accessors.slang b/tests/autodiff/generic-accessors.slang new file mode 100644 index 000000000..2b179f256 --- /dev/null +++ b/tests/autodiff/generic-accessors.slang @@ -0,0 +1,32 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHK): -output-using-type + +interface ITest +{ + __generic + __subscript(I i) -> float + { + [BackwardDifferentiable] get; + } +} +struct Test : ITest +{ + __generic + __subscript(I i) -> float + { + [BackwardDifferentiable] get { return 5.0f * i.toInt(); } + } +} +[Differentiable] +float test(ITest arg) +{ + return arg[1]; +} +//TEST_INPUT:set output = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer output; +[numthreads(1,1,1)] +void computeMain() +{ + Test t = {}; + output[0] = test(t); + // CHK: 5.0 +} \ No newline at end of file -- cgit v1.2.3