summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-09 17:19:55 -0500
committerGitHub <noreply@github.com>2023-02-09 17:19:55 -0500
commitd911e1bed9572664b1d0554feb3c7d1a2a880518 (patch)
tree1d34d8a641c83759e44f74d64364e3bfd27d0416
parentfbe31ada800b3417d10a24f6c0481d3cb6b161e4 (diff)
Fixed derivatives for kIROp_Neg and kIROp_Div, added another test (#2639)
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h37
-rw-r--r--tests/autodiff/reverse-arithmetic.slang78
-rw-r--r--tests/autodiff/reverse-arithmetic.slang.expected.txt8
4 files changed, 134 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index a45a3abf9..04acad435 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1329,9 +1329,12 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
case kIROp_MakeTuple:
+ case kIROp_Neg:
return transcribeByPassthrough(builder, origInst);
+
case kIROp_UpdateElement:
return transcribeUpdateElement(builder, origInst);
+
case kIROp_unconditionalBranch:
return transcribeControlFlow(builder, origInst);
@@ -1346,6 +1349,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_FieldExtract:
case kIROp_FieldAddress:
return transcribeFieldExtract(builder, origInst);
+
case kIROp_GetElement:
case kIROp_GetElementPtr:
return transcribeGetElement(builder, origInst);
@@ -1361,12 +1365,15 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeDifferentialPair:
return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst));
+
case kIROp_DifferentialPairGetPrimal:
case kIROp_DifferentialPairGetDifferential:
return transcribeDifferentialPairGetElement(builder, origInst);
+
case kIROp_ExtractExistentialValue:
case kIROp_MakeExistential:
return transcribeSingleOperandInst(builder, origInst);
+
case kIROp_ExtractExistentialType:
{
IRInst* witnessTable;
@@ -1377,8 +1384,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
}
case kIROp_ExtractExistentialWitnessTable:
return transcribeExtractExistentialWitnessTable(builder, origInst);
+
case kIROp_WrapExistential:
return transcribeWrapExistential(builder, origInst);
+
case kIROp_undefined:
return transcribeUndefined(builder, origInst);
@@ -1387,8 +1396,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
// so we treat this inst as non differentiable.
// We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
return trascribeNonDiffInst(builder, origInst);
+
case kIROp_StructKey:
return InstPair(origInst, nullptr);
+
case kIROp_Unreachable:
{
auto unreachInst = builder->emitUnreachable();
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index eef820804..d9b28ea3c 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1247,6 +1247,8 @@ struct DiffTransposePass
case kIROp_Add:
case kIROp_Mul:
case kIROp_Sub:
+ case kIROp_Div:
+ case kIROp_Neg:
return transposeArithmetic(builder, fwdInst, revValue);
case kIROp_Call:
@@ -1927,6 +1929,41 @@ struct DiffTransposePass
SLANG_ASSERT_FAILURE("Neither operand of a mul instruction is a differential inst");
}
}
+ case kIROp_Div:
+ {
+ if (isDifferentialInst(fwdInst->getOperand(0)))
+ {
+ SLANG_RELEASE_ASSERT(!isDifferentialInst(fwdInst->getOperand(1)));
+
+ // (Out = dA / B) -> (dA += dOut / B)
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ fwdInst->getOperand(0),
+ builder->emitDiv(operandType, revValue, fwdInst->getOperand(1)),
+ fwdInst)));
+ }
+ {
+ SLANG_ASSERT_FAILURE("The first operand of a div inst must be a differential inst");
+ }
+ }
+ case kIROp_Neg:
+ {
+ if (isDifferentialInst(fwdInst->getOperand(0)))
+ {
+ // (Out = -dA) -> (dA += -dOut)
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ fwdInst->getOperand(0),
+ builder->emitNeg(operandType, revValue),
+ fwdInst)));
+ }
+ else
+ {
+ SLANG_ASSERT_FAILURE("Cannot transpose neg of a non-differentiable inst");
+ }
+ }
default:
SLANG_ASSERT_FAILURE("Unhandled arithmetic");
diff --git a/tests/autodiff/reverse-arithmetic.slang b/tests/autodiff/reverse-arithmetic.slang
new file mode 100644
index 000000000..4e13251e4
--- /dev/null
+++ b/tests/autodiff/reverse-arithmetic.slang
@@ -0,0 +1,78 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDifferentiable]
+float f(float x)
+{
+ return x;
+}
+
+[BackwardDifferentiable]
+float h(float x, float y)
+{
+ float m = x + y;
+ float n = x - y;
+ return m * n + 2 * x * y;
+}
+
+[BackwardDifferentiable]
+float j(float x, float y)
+{
+ float m = x / y;
+ return m * y;
+}
+
+[BackwardDifferentiable]
+float k(float x, float y)
+{
+ float m = -x;
+ return m * y;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+
+ __bwd_diff(f)(dpa, 1.0);
+
+ outputBuffer[0] = dpa.d; // Expect: 1
+ }
+
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+ dpfloat dpb = dpfloat(1.5, 1.0);
+
+ __bwd_diff(h)(dpa, dpb, 1.0);
+
+ outputBuffer[1] = dpa.d; // Expect: (2 * 2.0) + (2 * 1.5) = 7.0
+ outputBuffer[2] = dpb.d; // Expect: -(2 * 1.5) + (2 * 2.0) = 1.0
+ }
+
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+ dpfloat dpb = dpfloat(1.5, 1.0);
+
+ __bwd_diff(j)(dpa, dpb, 1.0);
+
+ outputBuffer[3] = dpa.d; // Expect: 1
+ outputBuffer[4] = dpb.d; // Expect: 0
+ }
+
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+ dpfloat dpb = dpfloat(1.5, 1.0);
+
+ __bwd_diff(k)(dpa, dpb, 1.0);
+
+ outputBuffer[5] = dpa.d; // Expect: -1.5
+ outputBuffer[6] = dpb.d; // Expect: -2.0
+ }
+}
diff --git a/tests/autodiff/reverse-arithmetic.slang.expected.txt b/tests/autodiff/reverse-arithmetic.slang.expected.txt
new file mode 100644
index 000000000..297192e6c
--- /dev/null
+++ b/tests/autodiff/reverse-arithmetic.slang.expected.txt
@@ -0,0 +1,8 @@
+type: float
+1.000000
+7.000000
+1.000000
+1.000000
+0.000000
+-1.500000
+-2.000000