summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp10
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h32
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h4
-rw-r--r--source/slang/slang-ir-peephole.cpp6
-rw-r--r--tests/autodiff/bool-return-val-bwd.slang28
-rw-r--r--tests/autodiff/bool-return-val-bwd.slang.expected.txt5
-rw-r--r--tests/autodiff/bsdf/bsdf-auto-rev.slang11
-rw-r--r--tests/autodiff/reverse-conditional-out-assign.slang40
-rw-r--r--tests/autodiff/reverse-conditional-out-assign.slang.expected.txt6
10 files changed, 133 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 7782bd39c..b5d3dba10 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1514,7 +1514,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam
auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType());
auto diff = builder->emitVar(diffType);
- builder->markInstAsDifferential(diff, ptrInnerPairType->getValueType());
+ builder->markInstAsDifferential(
+ diff, builder->getPtrType(ptrInnerPairType->getValueType()));
IRInst* primalInitVal = nullptr;
IRInst* diffInitVal = nullptr;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 20090ca42..ff8ece76c 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -539,6 +539,10 @@ namespace Slang
{
builder.emitStore(tempVar, builder.emitLoad(param));
}
+ else
+ {
+ builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType()));
+ }
}
for (auto block : func->getBlocks())
@@ -589,6 +593,7 @@ namespace Slang
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &diffTypeContext;
auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink);
+
if (SLANG_SUCCEEDED(result))
{
simplifyFunc(func);
@@ -824,6 +829,7 @@ namespace Slang
moveInstChildren(existingPrimalHeader, primalFuncGeneric);
primalFuncGeneric->replaceUsesWith(existingPrimalHeader);
primalFuncGeneric->removeAndDeallocate();
+ primalFuncGeneric = existingPrimalHeader;
}
else
{
@@ -831,7 +837,7 @@ namespace Slang
builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc);
}
- initializeLocalVariables(builder->getSharedBuilder(), primalFunc);
+ initializeLocalVariables(builder->getSharedBuilder(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric)));
initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc);
}
@@ -957,7 +963,6 @@ namespace Slang
// after transposition.
auto tempVar = nextBlockBuilder.emitVar(diffType);
copyNameHintDecoration(tempVar, fwdParam);
- nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType);
// Initialize the var with input diff param at start.
// Note that we insert the store in the primal block so it won't get transposed.
@@ -1088,7 +1093,6 @@ namespace Slang
auto diffVar = nextBlockBuilder.emitVar(diffType);
copyNameHintDecoration(diffVar, fwdParam);
result.propagateFuncSpecificPrimalInsts.add(diffVar);
- diffBuilder.markInstAsDifferential(diffVar, diffPairType);
diffRefReplacement = diffVar;
// Clear the diff read var to zero at start of the function.
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 2953c6206..4e1532153 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -506,6 +506,7 @@ struct DiffTransposePass
List<IRBlock*> traverseWorkList;
HashSet<IRBlock*> traverseSet;
traverseWorkList.add(revDiffFunc->getFirstBlock());
+
traverseSet.Add(revDiffFunc->getFirstBlock());
for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock())
{
@@ -517,9 +518,13 @@ struct DiffTransposePass
// or entirely with differential insts.
continue;
}
+
workList.add(block);
}
+ if (!workList.getCount())
+ return;
+
// Reverse the order of the blocks.
workList.reverse();
@@ -533,7 +538,32 @@ struct DiffTransposePass
// Keep track of first diff block, since this is where
// we'll emit temporary vars to hold per-block derivatives.
//
- firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]];
+ auto firstRevDiffBlock = revBlockMap[terminalDiffBlocks[0]].GetValue();
+ firstRevDiffBlockMap[revDiffFunc] = firstRevDiffBlock;
+
+ // Move all diff vars to first block, and initialize them with zero.
+ builder.setInsertInto(firstRevDiffBlock);
+ for (auto block : workList)
+ {
+ for (auto inst = block->getFirstInst(); inst;)
+ {
+ auto nextInst = inst->getNextInst();
+ if (auto varInst = as<IRVar>(inst))
+ {
+ if (auto diffDecor = varInst->findDecoration<IRDifferentialInstDecoration>())
+ {
+ if (auto ptrPrimalType = as<IRPtrTypeBase>(diffDecor->getPrimalType()))
+ {
+ varInst->insertAtEnd(firstRevDiffBlock);
+
+ auto dzero = emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType());
+ builder.emitStore(varInst, dzero);
+ }
+ }
+ }
+ inst = nextInst;
+ }
+ }
for (auto block : workList)
{
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 1a85ea6a4..d83ff57e4 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -839,7 +839,7 @@ struct DiffUnzipPass
auto diffType = (IRType*) diffTypeContext.getDifferentialForType(primalBuilder, primalType);
auto primalVar = primalBuilder->emitVar(primalType);
auto diffVar = diffBuilder->emitVar(diffType);
- diffBuilder->markInstAsDifferential(diffVar, primalType);
+ diffBuilder->markInstAsDifferential(diffVar, diffBuilder->getPtrType(primalType));
return InstPair(primalVar, diffVar);
}
@@ -874,7 +874,7 @@ struct DiffUnzipPass
// If return value is not differentiable, just turn it into a trivial branch.
auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
primalBuilder->addBackwardDerivativePrimalReturnDecoration(
- primalBranch, primalBuilder->getVoidValue());
+ primalBranch, mixedReturn->getVal());
auto returnInst = diffBuilder->emitReturn();
diffBuilder->markInstAsDifferential(returnInst, nullptr);
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 4dbe6d2cb..87c31ffb7 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -427,7 +427,7 @@ struct PeepholeContext : InstPassBase
else
break;
if (i == (IRIntegerValue)constIndex->getValue())
- arg = inst->getOperand(2);
+ arg = updateInst->getElementValue();
args.add(arg);
}
if (args.getCount() == arraySize->getValue())
@@ -456,8 +456,8 @@ struct PeepholeContext : InstPassBase
IRInst* arg = nullptr;
if (i < oldVal->getOperandCount())
arg = oldVal->getOperand(i);
- if (field->getKey() == inst->getOperand(1))
- arg = inst->getOperand(2);
+ if (field->getKey() == key)
+ arg = updateInst->getElementValue();
if (arg)
{
args.add(arg);
diff --git a/tests/autodiff/bool-return-val-bwd.slang b/tests/autodiff/bool-return-val-bwd.slang
new file mode 100644
index 000000000..40eb4810a
--- /dev/null
+++ b/tests/autodiff/bool-return-val-bwd.slang
@@ -0,0 +1,28 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct NonDiff
+{
+ float a;
+}
+
+[BackwardDifferentiable]
+bool myFunc(NonDiff fIn, inout float x)
+{
+ x = pow(x, fIn.a);
+ return x > 100.f;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ float a = 3.0;
+ NonDiff fIn = { a };
+ DifferentialPair<float> dpx = DifferentialPair<float>(4.f, 1.f);
+ __bwd_diff(myFunc)(fIn, dpx);
+
+ outputBuffer[0] = dpx.d;
+} \ No newline at end of file
diff --git a/tests/autodiff/bool-return-val-bwd.slang.expected.txt b/tests/autodiff/bool-return-val-bwd.slang.expected.txt
new file mode 100644
index 000000000..255e4b2f4
--- /dev/null
+++ b/tests/autodiff/bool-return-val-bwd.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+48.000000
+0.000000
+0.000000
+0.000000 \ No newline at end of file
diff --git a/tests/autodiff/bsdf/bsdf-auto-rev.slang b/tests/autodiff/bsdf/bsdf-auto-rev.slang
index c2ba434f8..7fae5e993 100644
--- a/tests/autodiff/bsdf/bsdf-auto-rev.slang
+++ b/tests/autodiff/bsdf/bsdf-auto-rev.slang
@@ -32,11 +32,16 @@ struct Auto_Bwd_BSDFParameters : IDifferentiable
};
[BackwardDifferentiable]
-void bsdfGGXSample(in ShadingData sd, in Auto_Bwd_BSDFParameters params, out Auto_Bwd_ScatterSample result)
+bool bsdfGGXSample(in ShadingData sd, in Auto_Bwd_BSDFParameters params, out Auto_Bwd_ScatterSample result)
{
float3 wiLocal = no_diff(sd.toLocal(sd.V));
float2 u = float2(0.8, 0.3);
+ if (wiLocal.z < 1e-6)
+ {
+ return false;
+ }
+
// Taken from Rendering.Materials.Microfacet. Follows the Walter et al. EGSR07 BTDF paper
float alphaSqr = params.roughness * params.roughness;
float phi = u.y * (2 * 3.1415926);
@@ -52,6 +57,8 @@ void bsdfGGXSample(in ShadingData sd, in Auto_Bwd_BSDFParameters params, out Aut
result.wo = no_diff(sd.fromLocal(woLocal)); // wo to world.
result.pdf = detach(pdf);
result.weight = evalGGXDivByPDF(wiLocal, woLocal, hLocal, params) * pdf / detach(pdf);
+
+ return woLocal.z > 1e-6;
}
[BackwardDifferentiable]
@@ -95,4 +102,4 @@ float bsdfGGXPDF(in float3 hLocal, in Auto_Bwd_BSDFParameters params)
float d = ((cosTheta * a2 - cosTheta) * cosTheta + 1);
return (a2 / (d * d * 3.1415926)) * cosTheta;
-}
+} \ No newline at end of file
diff --git a/tests/autodiff/reverse-conditional-out-assign.slang b/tests/autodiff/reverse-conditional-out-assign.slang
new file mode 100644
index 000000000..4f6b105f1
--- /dev/null
+++ b/tests/autodiff/reverse-conditional-out-assign.slang
@@ -0,0 +1,40 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDifferentiable]
+bool test_single_branch(float y, out float o)
+{
+ if (y > 0.5)
+ {
+ o = y * 2.0f;
+ return true;
+ }
+
+ o = y + 1.0f;
+
+ return false;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(1.0, 0.0);
+
+ __bwd_diff(test_single_branch)(dpa, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 2.0
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_single_branch)(dpa, 1.0f);
+ outputBuffer[1] = dpa.d; // Expect: 1.0
+ }
+}
diff --git a/tests/autodiff/reverse-conditional-out-assign.slang.expected.txt b/tests/autodiff/reverse-conditional-out-assign.slang.expected.txt
new file mode 100644
index 000000000..86aa47f11
--- /dev/null
+++ b/tests/autodiff/reverse-conditional-out-assign.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+2.000000
+1.000000
+0.000000
+0.000000
+0.000000