diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 62 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 2 |
4 files changed, 63 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index efaaec906..25051cb6d 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1617,6 +1617,51 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func) } } +bool isLocalPointer(IRInst* ptrInst) +{ + // If it's not a local var or a function parameter, then it's probably + // referencing something outside the function scope. + // + auto addr = getRootAddr(ptrInst); + return as<IRVar>(addr) || as<IRParam>(addr); +} + +void lowerSwizzledStores(IRModule* module, IRFunc* func) +{ + List<IRInst*> instsToRemove; + + IRBuilder builder(module); + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto swizzledStore = as<IRSwizzledStore>(inst)) + { + if (!isLocalPointer(swizzledStore->getDest())) + continue; + + builder.setInsertBefore(inst); + for (UIndex ii = 0; ii < swizzledStore->getElementCount(); ii++) + { + auto indexVal = swizzledStore->getElementIndex(ii); + auto indexedPtr = builder.emitElementAddress(swizzledStore->getDest(), indexVal); + builder.emitStore( + indexedPtr, + builder.emitElementExtract( + swizzledStore->getSource(), + builder.getIntValue(builder.getIntType(), ii))); + } + instsToRemove.add(inst); + } + } + } + + for (auto inst : instsToRemove) + { + inst->removeAndDeallocate(); + } +} + SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) { insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); @@ -1626,6 +1671,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func); + lowerSwizzledStores(autoDiffSharedContext->moduleInst->getModule(), func); + auto result = eliminateAddressInsts(func, sink); if (SLANG_SUCCEEDED(result)) @@ -1846,6 +1893,17 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_undefined: return transcribeUndefined(builder, origInst); + // Differentiable insts that should have been lowered in a previous pass. + case kIROp_SwizzledStore: + { + // If we have a non-null dest ptr, then we error out because something went wrong + // when lowering swizzle-stores to regular stores + // + auto swizzledStore = as<IRSwizzledStore>(origInst); + SLANG_RELEASE_ASSERT(lookupDiffInst(swizzledStore->getDest(), nullptr) == nullptr); + return transcribeNonDiffInst(builder, swizzledStore); + } + // Known non-differentiable insts. case kIROp_Not: case kIROp_BitAnd: @@ -1875,13 +1933,13 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_DetachDerivative: case kIROp_GetSequentialID: case kIROp_GetStringHash: - return trascribeNonDiffInst(builder, origInst); + return transcribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, // so we treat this inst as non differentiable. // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value. case kIROp_CreateExistentialObject: - return trascribeNonDiffInst(builder, origInst); + return transcribeNonDiffInst(builder, origInst); case kIROp_StructKey: return InstPair(origInst, nullptr); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 3f0036b06..fc8979551 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -218,7 +218,7 @@ namespace Slang case kIROp_WrapExistential: case kIROp_MakeExistential: case kIROp_MakeExistentialWithRTTI: - return trascribeNonDiffInst(builder, origInst); + return transcribeNonDiffInst(builder, origInst); case kIROp_StructKey: return InstPair(origInst, nullptr); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 0d39f879e..0a9ff51a4 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -872,7 +872,7 @@ InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBloc return InstPair(diffBlock, diffBlock); } -InstPair AutoDiffTranscriberBase::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst) +InstPair AutoDiffTranscriberBase::transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst) { auto primal = cloneInst(&cloneEnv, builder, origInst); return InstPair(primal, nullptr); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index d5ad29610..d6b2ea9ff 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -114,7 +114,7 @@ struct AutoDiffTranscriberBase IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType); - InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst); + InstPair transcribeNonDiffInst(IRBuilder* builder, IRInst* origInst); InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn); |
