From bbd9c2e6d7b57f5acc3238083ab2f7c7b140df5e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 12 Jul 2023 18:02:36 -0400 Subject: Extend `no_diff` to support subscript operations on resources and array variables… (#2981) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Extend `no_diff` to support subscript operations on resources and array variables * Update autodiff.slang.expected --- source/slang/slang-check-expr.cpp | 2 +- source/slang/slang-diagnostic-defs.h | 2 +- source/slang/slang-ir-insts.h | 1 + source/slang/slang-ir.cpp | 11 ++++++ source/slang/slang-lower-to-ir.cpp | 42 ++++++++++++++++++++-- tests/autodiff/no-diff-array-access.slang | 23 ++++++++++++ .../no-diff-array-access.slang.expected.txt | 6 ++++ tests/diagnostics/autodiff.slang.expected | 6 ++-- 8 files changed, 85 insertions(+), 8 deletions(-) create mode 100644 tests/autodiff/no-diff-array-access.slang create mode 100644 tests/autodiff/no-diff-array-access.slang.expected.txt diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 6b050aa89..9af8fd867 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2712,7 +2712,7 @@ namespace Slang { innerExpr = parenExpr->base; } - if (!as(innerExpr)) + if (!as(innerExpr) && !as(innerExpr)) { getSink()->diagnose(expr, Diagnostics::invalidUseOfNoDiff); } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 1832a3b46..463c6f525 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -535,7 +535,7 @@ DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argu DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") DIAGNOSTIC(38029, Error, typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") -DIAGNOSTIC(38031, Error, invalidUseOfNoDiff, "'no_diff' can only be used to decorate a call.") +DIAGNOSTIC(38031, Error, invalidUseOfNoDiff, "'no_diff' can only be used to decorate a call or a subscript operation") DIAGNOSTIC(38032, Error, useOfNoDiffOnDifferentiableFunc, "use 'no_diff' on a call to a differentiable function has no meaning.") DIAGNOSTIC(38033, Error, cannotUseNoDiffInNonDifferentiableFunc, "cannot use 'no_diff' in a non-differentiable function.") diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4eb3982d3..a4563c254 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3069,6 +3069,7 @@ public: IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn); IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn); + IRInst* emitDetachDerivative(IRType* type, IRInst* value); IRInst* emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs); IRInst* emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 84730c913..a44667a79 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3236,6 +3236,17 @@ namespace Slang return inst; } + IRInst *IRBuilder::emitDetachDerivative(IRType *type, IRInst *value) + { + auto inst = createInst( + this, + kIROp_DetachDerivative, + type, + value); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn) { auto inst = createInst( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d869bf60e..a1c7a2b8e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3333,9 +3333,45 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) { auto baseVal = lowerSubExpr(expr->innerExpr); - SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration); - return baseVal; + + IRInst* innerInst = nullptr; + if (baseVal.flavor != LoweredValInfo::Flavor::Simple) + { + if (!isLValueContext()) + { + auto materializedVal = materialize(context, baseVal); + + // TODO(Sai): We might be missing the case where a single materialize could create + // multiple calls (multiple index operations?). Not quite sure what the right way + // to handle that case might be. + // + if (as(materializedVal.val)) + getBuilder()->addDecoration(materializedVal.val, kIROp_TreatAsDifferentiableDecoration); + + innerInst = getSimpleVal(context, materializedVal); + + // We'll special case handle 'loads' here in order to allow TreatAsDifferentiable to be + // used on array index operations. (This is to avoid a discrepancy between using no_diff + // on local variable indexing vs. resource indexing.) + // + if (as(innerInst)) + innerInst = getBuilder()->emitDetachDerivative(innerInst->getDataType(), innerInst); + } + else + { + SLANG_ASSERT("TreatAsDifferentiableExpr on non-simple l-values not properly defined."); + } + } + else + { + if (as(baseVal.val)) + getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration); + innerInst = baseVal.val; + } + + SLANG_ASSERT(innerInst); + + return LoweredValInfo::simple(innerInst); } // Emit IR to denote the forward-mode derivative diff --git a/tests/autodiff/no-diff-array-access.slang b/tests/autodiff/no-diff-array-access.slang new file mode 100644 index 000000000..df8c8faa0 --- /dev/null +++ b/tests/autodiff/no-diff-array-access.slang @@ -0,0 +1,23 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; +typedef DifferentialPair dpfloat3; + +[ForwardDifferentiable] +float f(float x, float[3] y) +{ + return x * no_diff(outputBuffer[4]) + y[2] * x * no_diff(y[1]); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var dpx = fwd_diff(f)(DifferentialPair(1.0f, 1.0f), DifferentialPair( { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f })); + outputBuffer[0] = dpx.p; // Expect: 6.0 + outputBuffer[1] = dpx.d; // Expect: 12.0 +} diff --git a/tests/autodiff/no-diff-array-access.slang.expected.txt b/tests/autodiff/no-diff-array-access.slang.expected.txt new file mode 100644 index 000000000..504eaec3e --- /dev/null +++ b/tests/autodiff/no-diff-array-access.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +6.000000 +12.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/diagnostics/autodiff.slang.expected b/tests/diagnostics/autodiff.slang.expected index d075b9406..95da2a6d3 100644 --- a/tests/diagnostics/autodiff.slang.expected +++ b/tests/diagnostics/autodiff.slang.expected @@ -1,12 +1,12 @@ result code = -1 standard error = { -tests/diagnostics/autodiff.slang(35): error 38031: 'no_diff' can only be used to decorate a call. +tests/diagnostics/autodiff.slang(30): error 38031: 'no_diff' can only be used to decorate a call or a subscript operation float x1 = no_diff x; // invalid use of no_diff here. ^~~~~~~ -tests/diagnostics/autodiff.slang(36): error 38032: use 'no_diff' on a call to a differentiable function has no meaning. +tests/diagnostics/autodiff.slang(31): error 38032: use 'no_diff' on a call to a differentiable function has no meaning. return no_diff f(x); // no_diff on a differentiable call has no meaning. ^~~~~~~ -tests/diagnostics/autodiff.slang(41): error 38033: cannot use 'no_diff' in a non-differentiable function. +tests/diagnostics/autodiff.slang(36): error 38033: cannot use 'no_diff' in a non-differentiable function. return no_diff nonDiff(x); // no_diff in a non-differentiable function ^~~~~~~ } -- cgit v1.2.3