diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 32 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 6 | ||||
| -rw-r--r-- | tests/autodiff/bool-return-val-bwd.slang | 28 | ||||
| -rw-r--r-- | tests/autodiff/bool-return-val-bwd.slang.expected.txt | 5 | ||||
| -rw-r--r-- | tests/autodiff/bsdf/bsdf-auto-rev.slang | 11 | ||||
| -rw-r--r-- | tests/autodiff/reverse-conditional-out-assign.slang | 40 | ||||
| -rw-r--r-- | tests/autodiff/reverse-conditional-out-assign.slang.expected.txt | 6 |
10 files changed, 133 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 7782bd39c..b5d3dba10 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1514,7 +1514,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType()); auto diff = builder->emitVar(diffType); - builder->markInstAsDifferential(diff, ptrInnerPairType->getValueType()); + builder->markInstAsDifferential( + diff, builder->getPtrType(ptrInnerPairType->getValueType())); IRInst* primalInitVal = nullptr; IRInst* diffInitVal = nullptr; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 20090ca42..ff8ece76c 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -539,6 +539,10 @@ namespace Slang { builder.emitStore(tempVar, builder.emitLoad(param)); } + else + { + builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType())); + } } for (auto block : func->getBlocks()) @@ -589,6 +593,7 @@ namespace Slang AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &diffTypeContext; auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink); + if (SLANG_SUCCEEDED(result)) { simplifyFunc(func); @@ -824,6 +829,7 @@ namespace Slang moveInstChildren(existingPrimalHeader, primalFuncGeneric); primalFuncGeneric->replaceUsesWith(existingPrimalHeader); primalFuncGeneric->removeAndDeallocate(); + primalFuncGeneric = existingPrimalHeader; } else { @@ -831,7 +837,7 @@ namespace Slang builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc); } - initializeLocalVariables(builder->getSharedBuilder(), primalFunc); + initializeLocalVariables(builder->getSharedBuilder(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric))); initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc); } @@ -957,7 +963,6 @@ namespace Slang // after transposition. auto tempVar = nextBlockBuilder.emitVar(diffType); copyNameHintDecoration(tempVar, fwdParam); - nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType); // Initialize the var with input diff param at start. // Note that we insert the store in the primal block so it won't get transposed. @@ -1088,7 +1093,6 @@ namespace Slang auto diffVar = nextBlockBuilder.emitVar(diffType); copyNameHintDecoration(diffVar, fwdParam); result.propagateFuncSpecificPrimalInsts.add(diffVar); - diffBuilder.markInstAsDifferential(diffVar, diffPairType); diffRefReplacement = diffVar; // Clear the diff read var to zero at start of the function. diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 2953c6206..4e1532153 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -506,6 +506,7 @@ struct DiffTransposePass List<IRBlock*> traverseWorkList; HashSet<IRBlock*> traverseSet; traverseWorkList.add(revDiffFunc->getFirstBlock()); + traverseSet.Add(revDiffFunc->getFirstBlock()); for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock()) { @@ -517,9 +518,13 @@ struct DiffTransposePass // or entirely with differential insts. continue; } + workList.add(block); } + if (!workList.getCount()) + return; + // Reverse the order of the blocks. workList.reverse(); @@ -533,7 +538,32 @@ struct DiffTransposePass // Keep track of first diff block, since this is where // we'll emit temporary vars to hold per-block derivatives. // - firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]]; + auto firstRevDiffBlock = revBlockMap[terminalDiffBlocks[0]].GetValue(); + firstRevDiffBlockMap[revDiffFunc] = firstRevDiffBlock; + + // Move all diff vars to first block, and initialize them with zero. + builder.setInsertInto(firstRevDiffBlock); + for (auto block : workList) + { + for (auto inst = block->getFirstInst(); inst;) + { + auto nextInst = inst->getNextInst(); + if (auto varInst = as<IRVar>(inst)) + { + if (auto diffDecor = varInst->findDecoration<IRDifferentialInstDecoration>()) + { + if (auto ptrPrimalType = as<IRPtrTypeBase>(diffDecor->getPrimalType())) + { + varInst->insertAtEnd(firstRevDiffBlock); + + auto dzero = emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType()); + builder.emitStore(varInst, dzero); + } + } + } + inst = nextInst; + } + } for (auto block : workList) { diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 1a85ea6a4..d83ff57e4 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -839,7 +839,7 @@ struct DiffUnzipPass auto diffType = (IRType*) diffTypeContext.getDifferentialForType(primalBuilder, primalType); auto primalVar = primalBuilder->emitVar(primalType); auto diffVar = diffBuilder->emitVar(diffType); - diffBuilder->markInstAsDifferential(diffVar, primalType); + diffBuilder->markInstAsDifferential(diffVar, diffBuilder->getPtrType(primalType)); return InstPair(primalVar, diffVar); } @@ -874,7 +874,7 @@ struct DiffUnzipPass // If return value is not differentiable, just turn it into a trivial branch. auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); primalBuilder->addBackwardDerivativePrimalReturnDecoration( - primalBranch, primalBuilder->getVoidValue()); + primalBranch, mixedReturn->getVal()); auto returnInst = diffBuilder->emitReturn(); diffBuilder->markInstAsDifferential(returnInst, nullptr); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 4dbe6d2cb..87c31ffb7 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -427,7 +427,7 @@ struct PeepholeContext : InstPassBase else break; if (i == (IRIntegerValue)constIndex->getValue()) - arg = inst->getOperand(2); + arg = updateInst->getElementValue(); args.add(arg); } if (args.getCount() == arraySize->getValue()) @@ -456,8 +456,8 @@ struct PeepholeContext : InstPassBase IRInst* arg = nullptr; if (i < oldVal->getOperandCount()) arg = oldVal->getOperand(i); - if (field->getKey() == inst->getOperand(1)) - arg = inst->getOperand(2); + if (field->getKey() == key) + arg = updateInst->getElementValue(); if (arg) { args.add(arg); diff --git a/tests/autodiff/bool-return-val-bwd.slang b/tests/autodiff/bool-return-val-bwd.slang new file mode 100644 index 000000000..40eb4810a --- /dev/null +++ b/tests/autodiff/bool-return-val-bwd.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct NonDiff +{ + float a; +} + +[BackwardDifferentiable] +bool myFunc(NonDiff fIn, inout float x) +{ + x = pow(x, fIn.a); + return x > 100.f; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + float a = 3.0; + NonDiff fIn = { a }; + DifferentialPair<float> dpx = DifferentialPair<float>(4.f, 1.f); + __bwd_diff(myFunc)(fIn, dpx); + + outputBuffer[0] = dpx.d; +}
\ No newline at end of file diff --git a/tests/autodiff/bool-return-val-bwd.slang.expected.txt b/tests/autodiff/bool-return-val-bwd.slang.expected.txt new file mode 100644 index 000000000..255e4b2f4 --- /dev/null +++ b/tests/autodiff/bool-return-val-bwd.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +48.000000 +0.000000 +0.000000 +0.000000
\ No newline at end of file diff --git a/tests/autodiff/bsdf/bsdf-auto-rev.slang b/tests/autodiff/bsdf/bsdf-auto-rev.slang index c2ba434f8..7fae5e993 100644 --- a/tests/autodiff/bsdf/bsdf-auto-rev.slang +++ b/tests/autodiff/bsdf/bsdf-auto-rev.slang @@ -32,11 +32,16 @@ struct Auto_Bwd_BSDFParameters : IDifferentiable }; [BackwardDifferentiable] -void bsdfGGXSample(in ShadingData sd, in Auto_Bwd_BSDFParameters params, out Auto_Bwd_ScatterSample result) +bool bsdfGGXSample(in ShadingData sd, in Auto_Bwd_BSDFParameters params, out Auto_Bwd_ScatterSample result) { float3 wiLocal = no_diff(sd.toLocal(sd.V)); float2 u = float2(0.8, 0.3); + if (wiLocal.z < 1e-6) + { + return false; + } + // Taken from Rendering.Materials.Microfacet. Follows the Walter et al. EGSR07 BTDF paper float alphaSqr = params.roughness * params.roughness; float phi = u.y * (2 * 3.1415926); @@ -52,6 +57,8 @@ void bsdfGGXSample(in ShadingData sd, in Auto_Bwd_BSDFParameters params, out Aut result.wo = no_diff(sd.fromLocal(woLocal)); // wo to world. result.pdf = detach(pdf); result.weight = evalGGXDivByPDF(wiLocal, woLocal, hLocal, params) * pdf / detach(pdf); + + return woLocal.z > 1e-6; } [BackwardDifferentiable] @@ -95,4 +102,4 @@ float bsdfGGXPDF(in float3 hLocal, in Auto_Bwd_BSDFParameters params) float d = ((cosTheta * a2 - cosTheta) * cosTheta + 1); return (a2 / (d * d * 3.1415926)) * cosTheta; -} +}
\ No newline at end of file diff --git a/tests/autodiff/reverse-conditional-out-assign.slang b/tests/autodiff/reverse-conditional-out-assign.slang new file mode 100644 index 000000000..4f6b105f1 --- /dev/null +++ b/tests/autodiff/reverse-conditional-out-assign.slang @@ -0,0 +1,40 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +bool test_single_branch(float y, out float o) +{ + if (y > 0.5) + { + o = y * 2.0f; + return true; + } + + o = y + 1.0f; + + return false; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_single_branch)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 2.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_single_branch)(dpa, 1.0f); + outputBuffer[1] = dpa.d; // Expect: 1.0 + } +} diff --git a/tests/autodiff/reverse-conditional-out-assign.slang.expected.txt b/tests/autodiff/reverse-conditional-out-assign.slang.expected.txt new file mode 100644 index 000000000..86aa47f11 --- /dev/null +++ b/tests/autodiff/reverse-conditional-out-assign.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +2.000000 +1.000000 +0.000000 +0.000000 +0.000000 |
