summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp62
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
4 files changed, 63 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index efaaec906..25051cb6d 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1617,6 +1617,51 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
}
}
+bool isLocalPointer(IRInst* ptrInst)
+{
+ // If it's not a local var or a function parameter, then it's probably
+ // referencing something outside the function scope.
+ //
+ auto addr = getRootAddr(ptrInst);
+ return as<IRVar>(addr) || as<IRParam>(addr);
+}
+
+void lowerSwizzledStores(IRModule* module, IRFunc* func)
+{
+ List<IRInst*> instsToRemove;
+
+ IRBuilder builder(module);
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (auto swizzledStore = as<IRSwizzledStore>(inst))
+ {
+ if (!isLocalPointer(swizzledStore->getDest()))
+ continue;
+
+ builder.setInsertBefore(inst);
+ for (UIndex ii = 0; ii < swizzledStore->getElementCount(); ii++)
+ {
+ auto indexVal = swizzledStore->getElementIndex(ii);
+ auto indexedPtr = builder.emitElementAddress(swizzledStore->getDest(), indexVal);
+ builder.emitStore(
+ indexedPtr,
+ builder.emitElementExtract(
+ swizzledStore->getSource(),
+ builder.getIntValue(builder.getIntType(), ii)));
+ }
+ instsToRemove.add(inst);
+ }
+ }
+ }
+
+ for (auto inst : instsToRemove)
+ {
+ inst->removeAndDeallocate();
+ }
+}
+
SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
{
insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
@@ -1626,6 +1671,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func);
+ lowerSwizzledStores(autoDiffSharedContext->moduleInst->getModule(), func);
+
auto result = eliminateAddressInsts(func, sink);
if (SLANG_SUCCEEDED(result))
@@ -1846,6 +1893,17 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_undefined:
return transcribeUndefined(builder, origInst);
+ // Differentiable insts that should have been lowered in a previous pass.
+ case kIROp_SwizzledStore:
+ {
+ // If we have a non-null dest ptr, then we error out because something went wrong
+ // when lowering swizzle-stores to regular stores
+ //
+ auto swizzledStore = as<IRSwizzledStore>(origInst);
+ SLANG_RELEASE_ASSERT(lookupDiffInst(swizzledStore->getDest(), nullptr) == nullptr);
+ return transcribeNonDiffInst(builder, swizzledStore);
+ }
+
// Known non-differentiable insts.
case kIROp_Not:
case kIROp_BitAnd:
@@ -1875,13 +1933,13 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_DetachDerivative:
case kIROp_GetSequentialID:
case kIROp_GetStringHash:
- return trascribeNonDiffInst(builder, origInst);
+ return transcribeNonDiffInst(builder, origInst);
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
// so we treat this inst as non differentiable.
// We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
case kIROp_CreateExistentialObject:
- return trascribeNonDiffInst(builder, origInst);
+ return transcribeNonDiffInst(builder, origInst);
case kIROp_StructKey:
return InstPair(origInst, nullptr);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 3f0036b06..fc8979551 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -218,7 +218,7 @@ namespace Slang
case kIROp_WrapExistential:
case kIROp_MakeExistential:
case kIROp_MakeExistentialWithRTTI:
- return trascribeNonDiffInst(builder, origInst);
+ return transcribeNonDiffInst(builder, origInst);
case kIROp_StructKey:
return InstPair(origInst, nullptr);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 0d39f879e..0a9ff51a4 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -872,7 +872,7 @@ InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBloc
return InstPair(diffBlock, diffBlock);
}
-InstPair AutoDiffTranscriberBase::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
+InstPair AutoDiffTranscriberBase::transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
{
auto primal = cloneInst(&cloneEnv, builder, origInst);
return InstPair(primal, nullptr);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index d5ad29610..d6b2ea9ff 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -114,7 +114,7 @@ struct AutoDiffTranscriberBase
IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
- InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
+ InstPair transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn);