diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-08-15 00:47:43 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-14 21:47:43 -0700 |
| commit | 113a257aafe4403c3ab905098d0560635ca94286 (patch) | |
| tree | cab658382229c357d59960bbe56d31c60f031dd1 | |
| parent | b05b126e0975f84e3505b2271e06d567e1c13692 (diff) | |
Add auto-diff support for `IRSwizzleStore` (#3102)
* Add auto-diff support for `IRSwizzleStore`
- Lower IRSwizzleStore to multiple stores during AD preprocess.
- Fix typo in `transcribeNonDiffInst`
* Remove unnecessary file & add more robust check for 'local' addresses
* Fix.
* Update slang-ir-autodiff-fwd.cpp
---------
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 62 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 2 | ||||
| -rw-r--r-- | tests/autodiff/swizzled-store.slang | 39 | ||||
| -rw-r--r-- | tests/autodiff/swizzled-store.slang.expected.txt | 5 |
6 files changed, 107 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index efaaec906..25051cb6d 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1617,6 +1617,51 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func) } } +bool isLocalPointer(IRInst* ptrInst) +{ + // If it's not a local var or a function parameter, then it's probably + // referencing something outside the function scope. + // + auto addr = getRootAddr(ptrInst); + return as<IRVar>(addr) || as<IRParam>(addr); +} + +void lowerSwizzledStores(IRModule* module, IRFunc* func) +{ + List<IRInst*> instsToRemove; + + IRBuilder builder(module); + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto swizzledStore = as<IRSwizzledStore>(inst)) + { + if (!isLocalPointer(swizzledStore->getDest())) + continue; + + builder.setInsertBefore(inst); + for (UIndex ii = 0; ii < swizzledStore->getElementCount(); ii++) + { + auto indexVal = swizzledStore->getElementIndex(ii); + auto indexedPtr = builder.emitElementAddress(swizzledStore->getDest(), indexVal); + builder.emitStore( + indexedPtr, + builder.emitElementExtract( + swizzledStore->getSource(), + builder.getIntValue(builder.getIntType(), ii))); + } + instsToRemove.add(inst); + } + } + } + + for (auto inst : instsToRemove) + { + inst->removeAndDeallocate(); + } +} + SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) { insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); @@ -1626,6 +1671,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func); + lowerSwizzledStores(autoDiffSharedContext->moduleInst->getModule(), func); + auto result = eliminateAddressInsts(func, sink); if (SLANG_SUCCEEDED(result)) @@ -1846,6 +1893,17 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_undefined: return transcribeUndefined(builder, origInst); + // Differentiable insts that should have been lowered in a previous pass. + case kIROp_SwizzledStore: + { + // If we have a non-null dest ptr, then we error out because something went wrong + // when lowering swizzle-stores to regular stores + // + auto swizzledStore = as<IRSwizzledStore>(origInst); + SLANG_RELEASE_ASSERT(lookupDiffInst(swizzledStore->getDest(), nullptr) == nullptr); + return transcribeNonDiffInst(builder, swizzledStore); + } + // Known non-differentiable insts. case kIROp_Not: case kIROp_BitAnd: @@ -1875,13 +1933,13 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_DetachDerivative: case kIROp_GetSequentialID: case kIROp_GetStringHash: - return trascribeNonDiffInst(builder, origInst); + return transcribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, // so we treat this inst as non differentiable. // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value. case kIROp_CreateExistentialObject: - return trascribeNonDiffInst(builder, origInst); + return transcribeNonDiffInst(builder, origInst); case kIROp_StructKey: return InstPair(origInst, nullptr); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 3f0036b06..fc8979551 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -218,7 +218,7 @@ namespace Slang case kIROp_WrapExistential: case kIROp_MakeExistential: case kIROp_MakeExistentialWithRTTI: - return trascribeNonDiffInst(builder, origInst); + return transcribeNonDiffInst(builder, origInst); case kIROp_StructKey: return InstPair(origInst, nullptr); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 0d39f879e..0a9ff51a4 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -872,7 +872,7 @@ InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBloc return InstPair(diffBlock, diffBlock); } -InstPair AutoDiffTranscriberBase::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst) +InstPair AutoDiffTranscriberBase::transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst) { auto primal = cloneInst(&cloneEnv, builder, origInst); return InstPair(primal, nullptr); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index d5ad29610..d6b2ea9ff 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -114,7 +114,7 @@ struct AutoDiffTranscriberBase IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType); - InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst); + InstPair transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst); InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn); diff --git a/tests/autodiff/swizzled-store.slang b/tests/autodiff/swizzled-store.slang new file mode 100644 index 000000000..58980616b --- /dev/null +++ b/tests/autodiff/swizzled-store.slang @@ -0,0 +1,39 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float2> dpfloat2; +typedef DifferentialPair<float3> dpfloat3; +typedef DifferentialPair<float4> dpfloat4; + +[Differentiable] +float2 f(float3 x) +{ + float3 u; + u.zy = x.yx; + return u.zy; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + float3 a = float3(2.0, 2.0, 2.0); + float3 da = float3(1.0, 0.5, 1.0); + + outputBuffer[0] = fwd_diff(f)(dpfloat3(a, da)).d.x; + } + + { + float3 a = float3(2.0, 2.0, 2.0); + var dpa = diffPair(a); + + bwd_diff(f)(dpa, float2(0.5, 1.0)); + + outputBuffer[1] = dpa.d.x; // 1.0 + outputBuffer[2] = dpa.d.y; // 0.5 + outputBuffer[3] = dpa.d.z; // 0.0 + } +} diff --git a/tests/autodiff/swizzled-store.slang.expected.txt b/tests/autodiff/swizzled-store.slang.expected.txt new file mode 100644 index 000000000..8fe60c0db --- /dev/null +++ b/tests/autodiff/swizzled-store.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +0.500000 +1.000000 +0.500000 +0.000000 |
