summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-08-15 00:47:43 -0400
committerGitHub <noreply@github.com>2023-08-14 21:47:43 -0700
commit113a257aafe4403c3ab905098d0560635ca94286 (patch)
treecab658382229c357d59960bbe56d31c60f031dd1
parentb05b126e0975f84e3505b2271e06d567e1c13692 (diff)
Add auto-diff support for `IRSwizzleStore` (#3102)
* Add auto-diff support for `IRSwizzleStore` - Lower IRSwizzleStore to multiple stores during AD preprocess. - Fix typo in `transcribeNonDiffInst` * Remove unnecessary file & add more robust check for 'local' addresses * Fix. * Update slang-ir-autodiff-fwd.cpp --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp62
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--tests/autodiff/swizzled-store.slang39
-rw-r--r--tests/autodiff/swizzled-store.slang.expected.txt5
6 files changed, 107 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index efaaec906..25051cb6d 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1617,6 +1617,51 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
}
}
+bool isLocalPointer(IRInst* ptrInst)
+{
+ // If it's not a local var or a function parameter, then it's probably
+ // referencing something outside the function scope.
+ //
+ auto addr = getRootAddr(ptrInst);
+ return as<IRVar>(addr) || as<IRParam>(addr);
+}
+
+void lowerSwizzledStores(IRModule* module, IRFunc* func)
+{
+ List<IRInst*> instsToRemove;
+
+ IRBuilder builder(module);
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (auto swizzledStore = as<IRSwizzledStore>(inst))
+ {
+ if (!isLocalPointer(swizzledStore->getDest()))
+ continue;
+
+ builder.setInsertBefore(inst);
+ for (UIndex ii = 0; ii < swizzledStore->getElementCount(); ii++)
+ {
+ auto indexVal = swizzledStore->getElementIndex(ii);
+ auto indexedPtr = builder.emitElementAddress(swizzledStore->getDest(), indexVal);
+ builder.emitStore(
+ indexedPtr,
+ builder.emitElementExtract(
+ swizzledStore->getSource(),
+ builder.getIntValue(builder.getIntType(), ii)));
+ }
+ instsToRemove.add(inst);
+ }
+ }
+ }
+
+ for (auto inst : instsToRemove)
+ {
+ inst->removeAndDeallocate();
+ }
+}
+
SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
{
insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
@@ -1626,6 +1671,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func);
+ lowerSwizzledStores(autoDiffSharedContext->moduleInst->getModule(), func);
+
auto result = eliminateAddressInsts(func, sink);
if (SLANG_SUCCEEDED(result))
@@ -1846,6 +1893,17 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_undefined:
return transcribeUndefined(builder, origInst);
+ // Differentiable insts that should have been lowered in a previous pass.
+ case kIROp_SwizzledStore:
+ {
+ // If we have a non-null dest ptr, then we error out because something went wrong
+ // when lowering swizzle-stores to regular stores
+ //
+ auto swizzledStore = as<IRSwizzledStore>(origInst);
+ SLANG_RELEASE_ASSERT(lookupDiffInst(swizzledStore->getDest(), nullptr) == nullptr);
+ return transcribeNonDiffInst(builder, swizzledStore);
+ }
+
// Known non-differentiable insts.
case kIROp_Not:
case kIROp_BitAnd:
@@ -1875,13 +1933,13 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_DetachDerivative:
case kIROp_GetSequentialID:
case kIROp_GetStringHash:
- return trascribeNonDiffInst(builder, origInst);
+ return transcribeNonDiffInst(builder, origInst);
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
// so we treat this inst as non differentiable.
// We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
case kIROp_CreateExistentialObject:
- return trascribeNonDiffInst(builder, origInst);
+ return transcribeNonDiffInst(builder, origInst);
case kIROp_StructKey:
return InstPair(origInst, nullptr);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 3f0036b06..fc8979551 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -218,7 +218,7 @@ namespace Slang
case kIROp_WrapExistential:
case kIROp_MakeExistential:
case kIROp_MakeExistentialWithRTTI:
- return trascribeNonDiffInst(builder, origInst);
+ return transcribeNonDiffInst(builder, origInst);
case kIROp_StructKey:
return InstPair(origInst, nullptr);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 0d39f879e..0a9ff51a4 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -872,7 +872,7 @@ InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBloc
return InstPair(diffBlock, diffBlock);
}
-InstPair AutoDiffTranscriberBase::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
+InstPair AutoDiffTranscriberBase::transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
{
auto primal = cloneInst(&cloneEnv, builder, origInst);
return InstPair(primal, nullptr);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index d5ad29610..d6b2ea9ff 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -114,7 +114,7 @@ struct AutoDiffTranscriberBase
IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
- InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
+ InstPair transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn);
diff --git a/tests/autodiff/swizzled-store.slang b/tests/autodiff/swizzled-store.slang
new file mode 100644
index 000000000..58980616b
--- /dev/null
+++ b/tests/autodiff/swizzled-store.slang
@@ -0,0 +1,39 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float2> dpfloat2;
+typedef DifferentialPair<float3> dpfloat3;
+typedef DifferentialPair<float4> dpfloat4;
+
+[Differentiable]
+float2 f(float3 x)
+{
+ float3 u;
+ u.zy = x.yx;
+ return u.zy;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ {
+ float3 a = float3(2.0, 2.0, 2.0);
+ float3 da = float3(1.0, 0.5, 1.0);
+
+ outputBuffer[0] = fwd_diff(f)(dpfloat3(a, da)).d.x;
+ }
+
+ {
+ float3 a = float3(2.0, 2.0, 2.0);
+ var dpa = diffPair(a);
+
+ bwd_diff(f)(dpa, float2(0.5, 1.0));
+
+ outputBuffer[1] = dpa.d.x; // 1.0
+ outputBuffer[2] = dpa.d.y; // 0.5
+ outputBuffer[3] = dpa.d.z; // 0.0
+ }
+}
diff --git a/tests/autodiff/swizzled-store.slang.expected.txt b/tests/autodiff/swizzled-store.slang.expected.txt
new file mode 100644
index 000000000..8fe60c0db
--- /dev/null
+++ b/tests/autodiff/swizzled-store.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+0.500000
+1.000000
+0.500000
+0.000000