diff options
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 65 | ||||
| -rw-r--r-- | source/slang/slang-ir-addr-inst-elimination.cpp | 61 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 88 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 44 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-constexpr.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 52 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 190 | ||||
| -rw-r--r-- | source/slang/slang-ir-redundancy-removal.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 85 | ||||
| -rw-r--r-- | tests/autodiff/reverse-addr-eliminate.slang | 14 | ||||
| -rw-r--r-- | tests/autodiff/reverse-addr-eliminate.slang.expected.txt | 2 |
17 files changed, 363 insertions, 327 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index ffb469b9d..160585e26 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1087,7 +1087,6 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) // Never fold these, because their result cannot be computed // as a sub-expression (they must be emitted as a declaration // or statement). - case kIROp_UpdateField: case kIROp_UpdateElement: case kIROp_DefaultConstruct: return false; @@ -2487,43 +2486,45 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) auto ii = (IRUpdateElement*)inst; auto subscriptOuter = getInfo(EmitOp::General); auto subscriptPrec = getInfo(EmitOp::Postfix); - auto arraySize = as<IRIntLit>(as<IRArrayType>(inst->getDataType())->getElementCount()); - SLANG_RELEASE_ASSERT(arraySize); emitInstResultDecl(inst); - m_writer->emit("{"); - for (UInt i = 0; i < (UInt)arraySize->getValue(); i++) + if (auto arrayType = as<IRArrayType>(inst->getDataType())) { - if (i > 0) - m_writer->emit(", "); - emitOperand(ii->getOldValue(), leftSide(subscriptOuter, subscriptPrec)); - m_writer->emit("["); - m_writer->emit(i); - m_writer->emit("]"); + auto arraySize = as<IRIntLit>(arrayType->getElementCount()); + SLANG_RELEASE_ASSERT(arraySize); + m_writer->emit("{"); + for (UInt i = 0; i < (UInt)arraySize->getValue(); i++) + { + if (i > 0) + m_writer->emit(", "); + emitOperand(ii->getOldValue(), leftSide(subscriptOuter, subscriptPrec)); + m_writer->emit("["); + m_writer->emit(i); + m_writer->emit("]"); + } + m_writer->emit("}"); + } + else + { + emitOperand(ii->getOldValue(), getInfo(EmitOp::General)); } - - m_writer->emit("}"); - m_writer->emit(";\n"); - - emitOperand(ii, leftSide(subscriptOuter, subscriptPrec)); - m_writer->emit("["); - emitOperand(ii->getIndex(), getInfo(EmitOp::General)); - m_writer->emit("] = "); - emitOperand(ii->getElementValue(), getInfo(EmitOp::General)); - m_writer->emit(";\n"); - } - break; - case kIROp_UpdateField: - { - auto ii = (IRUpdateField*)inst; - emitInstResultDecl(inst); - emitOperand(ii->getOldValue(), getInfo(EmitOp::General)); m_writer->emit(";\n"); - auto subscriptOuter = getInfo(EmitOp::General); - auto subscriptPrec = getInfo(EmitOp::Postfix); emitOperand(ii, leftSide(subscriptOuter, subscriptPrec)); - m_writer->emit("."); - m_writer->emit(getName(ii->getFieldKey())); + for (UInt i = 0; i < ii->getAccessKeyCount(); i++) + { + auto key = ii->getAccessKey(i); + if (as<IRStructKey>(key)) + { + m_writer->emit("."); + m_writer->emit(getName(key)); + } + else + { + m_writer->emit("["); + emitOperand(key, getInfo(EmitOp::General)); + m_writer->emit("]"); + } + } m_writer->emit(" = "); emitOperand(ii->getElementValue(), getInfo(EmitOp::General)); m_writer->emit(";\n"); diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 877be1406..a5e0e0a4e 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -36,72 +36,29 @@ struct AddressInstEliminationContext void storeValue(IRBuilder& builder, IRInst* addr, IRInst* val) { - List<IRInst*> baseAddrs; + List<IRInst*> accessChain; for (auto inst = addr; inst;) { switch (inst->getOp()) { default: - baseAddrs.add(inst); + accessChain.add(inst); goto endLoop; case kIROp_GetElementPtr: case kIROp_FieldAddress: - baseAddrs.add(inst); + accessChain.add(inst->getOperand(1)); inst = inst->getOperand(0); break; } } endLoop:; - List<IRInst*> values; - values.setCount(baseAddrs.getCount()); - if (values.getCount() > 1) - { - IRInst* currentVal = builder.emitLoad(baseAddrs.getLast()); - values.getLast() = currentVal; - for (Index i = baseAddrs.getCount() - 2; i >= 1; i--) - { - auto inst = baseAddrs[i]; - switch (inst->getOp()) - { - default: - sink->diagnose(inst->sourceLoc, Diagnostics::unsupportedUseOfLValueForAutoDiff); - return; - case kIROp_GetElementPtr: - case kIROp_FieldAddress: - { - IRInst* args[] = { currentVal, inst->getOperand(1) }; - currentVal = builder.emitIntrinsicInst( - cast<IRPtrTypeBase>(inst->getFullType())->getValueType(), - (inst->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract), - 2, - args); - values[i] = currentVal; - } - break; - } - } - } - values[0] = val; - for (Index i = 1; i < values.getCount(); i++) - { - auto inst = baseAddrs[i - 1]; - switch (inst->getOp()) - { - case kIROp_GetElementPtr: - case kIROp_FieldAddress: - { - IRInst* args[] = {values[i], inst->getOperand(1), values[i - 1]}; - values[i] = builder.emitIntrinsicInst( - values[i]->getFullType(), - (inst->getOp() == kIROp_GetElementPtr ? kIROp_UpdateElement : kIROp_UpdateField), - 3, - args); - } - break; - } - } - builder.emitStore(baseAddrs.getLast(), values.getLast()); + auto lastAddr = accessChain.getLast(); + auto lastVal = builder.emitLoad(lastAddr); + accessChain.removeLast(); + accessChain.reverse(); + auto update = builder.emitUpdateElement(lastVal, accessChain, val); + builder.emitStore(lastAddr, update); } void transformLoadAddr(IRUse* use) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index f5fa17fae..58c8aae93 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -760,71 +760,42 @@ InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst return InstPair(primalGetElementPtr, diffGetElementPtr); } -InstPair ForwardDiffTranscriber::transcribeUpdateField(IRBuilder* builder, IRInst* originalInst) +InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst) { - auto updateInst = as<IRUpdateField>(originalInst); + auto updateInst = as<IRUpdateElement>(originalInst); IRInst* origBase = updateInst->getOldValue(); auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto field = updateInst->getFieldKey(); - auto primalVal = findOrTranscribePrimalInst(builder, updateInst->getElementValue()); - auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>(); - auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType()); - - IRInst* primalOperands[] = { primalBase, field, primalVal }; - IRInst* primalUpdateField = builder->emitIntrinsicInst( - primalType, - originalInst->getOp(), - 3, - primalOperands); - - if (!derivativeRefDecor) + List<IRInst*> primalAccessChain; + for (UInt i = 0; i < updateInst->getAccessKeyCount(); i++) { - return InstPair(primalUpdateField, nullptr); + auto originalKey = updateInst->getAccessKey(i); + auto primalKey = findOrTranscribePrimalInst(builder, originalKey); + primalAccessChain.add(primalKey); } + auto origVal = updateInst->getElementValue(); + auto primalVal = findOrTranscribePrimalInst(builder, origVal); - IRInst* diffUpdateField = nullptr; + IRInst* primalUpdateField = + builder->emitUpdateElement(primalBase, primalAccessChain, primalVal); - if (auto diffType = differentiateType(builder, originalInst->getDataType())) + IRInst* diffUpdateElement = nullptr; + List<IRInst*> diffAccessChain; + for (auto key : primalAccessChain) { - if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + if (as<IRStructKey>(key)) { - if (auto diffVal = findOrTranscribeDiffInst(builder, updateInst->getElementValue())) - { - auto primalElementType = primalVal->getDataType(); - - IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey(), diffVal, primalElementType }; - diffUpdateField = builder->emitIntrinsicInst( - diffType, - originalInst->getOp(), - 4, - diffOperands); - } + auto decor = key->findDecoration<IRDerivativeMemberDecoration>(); + if (decor) + diffAccessChain.add(decor->getDerivativeMemberStructKey()); + else + return InstPair(primalUpdateField, nullptr); + } + else + { + diffAccessChain.add(key); } } - return InstPair(primalUpdateField, diffUpdateField); -} - -InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst) -{ - auto updateInst = as<IRUpdateElement>(originalInst); - - IRInst* origBase = updateInst->getOldValue(); - auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto primalIndex = findOrTranscribePrimalInst(builder, updateInst->getIndex()); - auto origVal = updateInst->getElementValue(); - auto primalVal = findOrTranscribePrimalInst(builder, origVal); - auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType()); - - IRInst* primalOperands[] = { primalBase, primalIndex, primalVal }; - IRInst* primalUpdateField = builder->emitIntrinsicInst( - primalType, - originalInst->getOp(), - 3, - primalOperands); - - IRInst* diffUpdateElement = nullptr; - if (auto diffType = differentiateType(builder, originalInst->getDataType())) { if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) @@ -833,12 +804,9 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI { auto primalElementType = primalVal->getDataType(); - IRInst* diffOperands[] = { diffBase, primalIndex, diffVal, primalElementType }; - diffUpdateElement = builder->emitIntrinsicInst( - diffType, - originalInst->getOp(), - 4, - diffOperands); + diffUpdateElement = builder->emitUpdateElement( + diffBase, diffAccessChain, diffVal); + builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); } } } @@ -1249,8 +1217,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return transcribeByPassthrough(builder, origInst); case kIROp_UpdateElement: return transcribeUpdateElement(builder, origInst); - case kIROp_UpdateField: - return transcribeUpdateField(builder, origInst); case kIROp_unconditionalBranch: return transcribeControlFlow(builder, origInst); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 53577f40e..e595191a3 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -63,8 +63,6 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr); - InstPair transcribeUpdateField(IRBuilder* builder, IRInst* originalInst); - InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst); InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 000921c7e..fce2043eb 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -7,6 +7,9 @@ #include "slang-ir-inst-pass-base.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-autodiff-fwd.h" +#include "slang-ir-single-return.h" +#include "slang-ir-addr-inst-elimination.h" +#include "slang-ir-eliminate-multilevel-break.h" namespace Slang { @@ -483,6 +486,39 @@ namespace Slang builder.emitBranch(firstBlock); } + struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy + { + DifferentiableTypeConformanceContext* diffTypeContext; + + virtual bool shouldConvertAddrInst(IRInst* addrInst) override + { + if (isDifferentiableType(*diffTypeContext, addrInst->getDataType())) + return true; + return false; + } + }; + + SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func) + { + DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext); + diffTypeContext.setFunc(func); + + if (!isSingleReturnFunc(func)) + { + convertFuncToSingleReturnForm(func->getModule(), func); + } + eliminateMultiLevelBreakForFunc(func->getModule(), func); + + AutoDiffAddressConversionPolicy cvtPolicty; + cvtPolicty.diffTypeContext = &diffTypeContext; + auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink); + if (SLANG_SUCCEEDED(result)) + { + simplifyFunc(func); + } + return result; + } + // Create a copy of originalFunc's forward derivative in the same generic context (if any) of // `diffPropagateFunc`. IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc( @@ -501,8 +537,10 @@ namespace Slang stripDerivativeDecorations(primalFunc); eliminateDeadCode(primalOuterParent); - // Perform simplification. - simplifyFunc(primalFunc); + // Perform required transformations and simplifications on the original func to make it + // reversible. + if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc))) + return diffPropagateFunc; // Forward transcribe the clone of the original func. ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>( diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 7aa6c2441..f789089b0 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -85,6 +85,8 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); + SlangResult prepareFuncForBackwardDiff(IRFunc* func); + IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc); void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 51dcd9f45..0d45c6a84 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1065,9 +1065,6 @@ struct DiffTransposePass case kIROp_UpdateElement: return transposeUpdateElement(builder, fwdInst, revValue); - case kIROp_UpdateField: - return transposeUpdateField(builder, fwdInst, revValue); - case kIROp_Specialize: case kIROp_unconditionalBranch: case kIROp_conditionalBranch: @@ -1312,20 +1309,22 @@ struct DiffTransposePass auto updateInst = as<IRUpdateElement>(fwdUpdate); List<RevGradient> gradients; - auto arrayType = cast<IRArrayType>(fwdUpdate->getFullType()); - auto revElement = builder->emitElementExtract(arrayType->getElementType(), revValue, updateInst->getIndex()); + auto accessChain = updateInst->getAccessChain(); + auto revElement = builder->emitElementExtract(revValue, accessChain.getArrayView()); gradients.add(RevGradient( RevGradient::Flavor::Simple, updateInst->getElementValue(), revElement, fwdUpdate)); - auto primalElementType = updateInst->getPrimalElementType(); - auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementType); + auto primalElementTypeDecor = updateInst->findDecoration<IRPrimalElementTypeDecoration>(); + SLANG_RELEASE_ASSERT(primalElementTypeDecor); + + auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementTypeDecor->getPrimalElementType()); SLANG_ASSERT(diffZero); auto revRest = builder->emitUpdateElement( revValue, - updateInst->getIndex(), + accessChain, diffZero); gradients.add(RevGradient( RevGradient::Flavor::Simple, @@ -1336,35 +1335,6 @@ struct DiffTransposePass return TranspositionResult(gradients); } - TranspositionResult transposeUpdateField(IRBuilder* builder, IRInst* fwdUpdate, IRInst* revValue) - { - auto updateInst = as<IRUpdateField>(fwdUpdate); - - List<RevGradient> gradients; - IRType* fieldType = updateInst->getElementValue()->getFullType(); - auto revElement = builder->emitFieldExtract(fieldType, revValue, updateInst->getFieldKey()); - gradients.add(RevGradient( - RevGradient::Flavor::Simple, - updateInst->getElementValue(), - revElement, - fwdUpdate)); - - auto primalElementType = updateInst->getPrimalElementType(); - auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementType); - SLANG_ASSERT(diffZero); - auto revRest = builder->emitUpdateField( - revValue, - updateInst->getFieldKey(), - diffZero); - gradients.add(RevGradient( - RevGradient::Flavor::Simple, - updateInst->getOldValue(), - revRest, - fwdUpdate)); - // (A = UpdateField(s, fieldKey, V)) -> [(dV += dA.fieldKey, d_s += UpdateField(revValue, fieldKey, 0)] - return TranspositionResult(gradients); - } - // Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr. // void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad) diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 8d9a01b75..44cb2aa09 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -196,7 +196,6 @@ struct ExtractPrimalFuncContext case kIROp_FieldExtract: case kIROp_swizzle: case kIROp_UpdateElement: - case kIROp_UpdateField: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: case kIROp_MatrixReshape: diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 8cefa6a04..ce3e563f5 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -2,8 +2,6 @@ #include "slang-ir-autodiff.h" #include "slang-ir-inst-pass-base.h" -#include "slang-ir-single-return.h" -#include "slang-ir-addr-inst-elimination.h" namespace Slang { @@ -177,29 +175,6 @@ public: return false; } - struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy - { - DifferentiableTypeConformanceContext* diffTypeContext; - - virtual bool shouldConvertAddrInst(IRInst* addrInst) override - { - if (isDifferentiableType(*diffTypeContext, addrInst->getDataType())) - return true; - return false; - } - }; - - SlangResult prepareFuncForAutoDiff(DifferentiableTypeConformanceContext& diffTypeContext, IRFunc* func) - { - if (!isSingleReturnFunc(func)) - { - convertFuncToSingleReturnForm(func->getModule(), func); - } - AutoDiffAddressConversionPolicy cvtPolicty; - cvtPolicty.diffTypeContext = &diffTypeContext; - return eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink); - } - void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) @@ -209,14 +184,6 @@ public: DifferentiableTypeConformanceContext diffTypeContext(&sharedContext); diffTypeContext.setFunc(funcInst); - if (isBackwardDifferentiableFunc(funcInst) && !funcInst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) - { - if (auto func = as<IRFunc>(funcInst)) - { - if (SLANG_FAILED(prepareFuncForAutoDiff(diffTypeContext, func))) - return; - } - } HashSet<IRInst*> produceDiffSet; HashSet<IRInst*> expectDiffSet; diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index a8cdb5cca..2bbb9618c 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -119,7 +119,6 @@ bool opCanBeConstExpr(IROp op) case kIROp_swizzle: case kIROp_GetElement: case kIROp_FieldExtract: - case kIROp_UpdateField: case kIROp_UpdateElement: case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialValue: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 817edaa83..6b6b3924a 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -311,8 +311,7 @@ INST(Call, call, 1, 0) INST(RTTIObject, rtti_object, 0, 0) INST(Alloca, alloca, 1, 0) -INST(UpdateElement, updateElement, 3, 0) -INST(UpdateField, updateField, 3, 0) +INST(UpdateElement, updateElement, 2, 0) INST(PackAnyValue, packAnyValue, 1, 0) INST(UnpackAnyValue, unpackAnyValue, 1, 0) @@ -762,6 +761,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// in an intermediary struct for reuse in backward propagation phase. INST(PrimalValueStructKeyDecoration, primalValueKey, 1, 0) + /// Used by the auto-diff pass to mark the primal element type of an + /// forward-differentiated updateElement inst. + INST(PrimalElementTypeDecoration, primalElementType, 1, 0) + /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 10c490f3c..405df4073 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -689,6 +689,18 @@ struct IRPrimalValueStructKeyDecoration : IRDecoration IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); } }; +struct IRPrimalElementTypeDecoration : IRDecoration +{ + enum + { + kOp = kIROp_PrimalElementTypeDecoration + }; + + IR_LEAF_ISA(PrimalElementTypeDecoration) + + IRInst* getPrimalElementType() { return getOperand(0); } +}; + struct IRMixedDifferentialInstDecoration : IRDecoration { enum @@ -2170,28 +2182,15 @@ struct IRUpdateElement : IRInst IR_LEAF_ISA(UpdateElement) IRInst* getOldValue() { return getOperand(0); } - IRInst* getIndex() { return getOperand(1); } - IRInst* getElementValue() { return getOperand(2); } - IRInst* getPrimalElementType() - { - if (getOperandCount() != 4) - return nullptr; - return getOperand(3); - } -}; - -struct IRUpdateField : IRInst -{ - IR_LEAF_ISA(UpdateField) - - IRInst* getOldValue() { return getOperand(0); } - IRInst* getFieldKey() { return getOperand(1); } - IRInst* getElementValue() { return getOperand(2); } - IRInst* getPrimalElementType() + IRInst* getElementValue() { return getOperand(1); } + IRInst* getAccessKey(UInt index) { return getOperand(2 + index); } + UInt getAccessKeyCount() { return getOperandCount() - 2; } + List<IRInst*> getAccessChain() { - if (getOperandCount() != 4) - return nullptr; - return getOperand(3); + List<IRInst*> result; + for (UInt i = 0; i < getAccessKeyCount(); i++) + result.add(getAccessKey(i)); + return result; } }; @@ -2798,6 +2797,7 @@ public: IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target); IRInst* addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key); + IRInst* addPrimalElementTypeDecoration(IRInst* target, IRInst* type); // Add a differentiable type entry to the appropriate dictionary. IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness); @@ -3148,13 +3148,21 @@ public: IRInst* base, IRInst* index); + IRInst* emitElementExtract( + IRInst* base, + IRInst* index); + + IRInst* emitElementExtract( + IRInst* base, + const ArrayView<IRInst*>& accessChain); + IRInst* emitElementAddress( IRType* type, IRInst* basePtr, IRInst* index); IRInst* emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement); - IRInst* emitUpdateField(IRInst* base, IRInst* fieldKey, IRInst* newFieldVal); + IRInst* emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement); IRInst* emitGetAddress( IRType* type, diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 16f6cd9b9..fd0b4577a 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -11,6 +11,91 @@ struct PeepholeContext : InstPassBase bool changed = false; + bool tryFoldElementExtractFromUpdateInst(IRInst* inst) + { + bool isAccessChainEqual = false; + bool isAccessChainNotEqual = false; + List<IRInst*> chainKey; + IRInst* chainNode = inst; + for (;;) + { + switch (chainNode->getOp()) + { + case kIROp_FieldExtract: + case kIROp_GetElement: + chainKey.add(chainNode->getOperand(1)); + chainNode = chainNode->getOperand(0); + continue; + } + break; + } + chainKey.reverse(); + if (auto updateInst = as<IRUpdateElement>(chainNode)) + { + if (updateInst->getAccessKeyCount() > (UInt)chainKey.getCount()) + return false; + + isAccessChainEqual = true; + for (UInt i = 0; i < (UInt)chainKey.getCount(); i++) + { + if (updateInst->getAccessKey(i) != chainKey[i]) + { + isAccessChainEqual = false; + if (as<IRStructKey>(chainKey[i])) + { + isAccessChainNotEqual = true; + break; + } + else + { + if (auto constIndex1 = as<IRIntLit>(updateInst->getAccessKey(i))) + { + if (auto constIndex2 = as<IRIntLit>(chainKey[i])) + { + if (constIndex1->getValue() != constIndex2->getValue()) + { + isAccessChainNotEqual = true; + break; + } + } + } + } + } + } + if (isAccessChainEqual) + { + auto remainingKeys = chainKey.getArrayView( + updateInst->getAccessKeyCount(), + chainKey.getCount() - updateInst->getAccessKeyCount()); + if (remainingKeys.getCount() == 0) + { + inst->replaceUsesWith(updateInst->getElementValue()); + inst->removeAndDeallocate(); + return true; + } + else if (remainingKeys.getCount() > 0) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + auto newValue = builder.emitElementExtract(updateInst->getElementValue(), remainingKeys); + inst->replaceUsesWith(newValue); + inst->removeAndDeallocate(); + return true; + } + } + else if (isAccessChainNotEqual) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + auto newInst = builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } + return false; + } + void processInst(IRInst* inst) { switch (inst->getOp()) @@ -84,19 +169,9 @@ struct PeepholeContext : InstPassBase } } } - else if (auto updateField = as<IRUpdateField>(inst->getOperand(0))) + else { - if (inst->getOperand(1) == updateField->getFieldKey()) - { - inst->replaceUsesWith(updateField->getElementValue()); - inst->removeAndDeallocate(); - changed = true; - } - else - { - inst->setOperand(0, updateField->getOldValue()); - changed = true; - } + changed = tryFoldElementExtractFromUpdateInst(inst); } break; case kIROp_GetElement: @@ -119,32 +194,18 @@ struct PeepholeContext : InstPassBase inst->removeAndDeallocate(); changed = true; } - else if (auto updateElement = as<IRUpdateElement>(inst->getOperand(0))) + else { - if (inst->getOperand(1) == updateElement->getIndex()) - { - inst->replaceUsesWith(updateElement->getElementValue()); - inst->removeAndDeallocate(); - changed = true; - } - else if (auto constIndex1 = as<IRIntLit>(inst->getOperand(1))) - { - if (auto constIndex2 = as<IRIntLit>(updateElement->getIndex())) - { - // If we can determine that the indices does not match, - // then reduce the original value operand to before the update. - if (constIndex1->getValue() != constIndex2->getValue()) - { - inst->setOperand(0, updateElement->getOldValue()); - changed = true; - } - } - } + changed = tryFoldElementExtractFromUpdateInst(inst); } break; case kIROp_UpdateElement: { - if (auto constIndex = as<IRIntLit>(inst->getOperand(1))) + auto updateInst = as<IRUpdateElement>(inst); + if (updateInst->getAccessKeyCount() != 1) + break; + auto key = updateInst->getAccessKey(0); + if (auto constIndex = as<IRIntLit>(key)) { auto oldVal = inst->getOperand(0); if (oldVal->getOp() == kIROp_MakeArray || @@ -179,44 +240,43 @@ struct PeepholeContext : InstPassBase } } } - } - break; - case kIROp_UpdateField: - { - auto oldVal = inst->getOperand(0); - if (oldVal->getOp() == kIROp_MakeStruct) + else if (auto structKey = as<IRStructKey>(key)) { - auto structType = as<IRStructType>(inst->getDataType()); - if (!structType) break; - List<IRInst*> args; - UInt i = 0; - bool isValid = true; - for (auto field : structType->getFields()) + auto oldVal = inst->getOperand(0); + if (oldVal->getOp() == kIROp_MakeStruct) { - IRInst* arg = nullptr; - if (i < oldVal->getOperandCount()) - arg = oldVal->getOperand(i); - if (field->getKey() == inst->getOperand(1)) - arg = inst->getOperand(2); - if (arg) + auto structType = as<IRStructType>(inst->getDataType()); + if (!structType) break; + List<IRInst*> args; + UInt i = 0; + bool isValid = true; + for (auto field : structType->getFields()) { - args.add(arg); + IRInst* arg = nullptr; + if (i < oldVal->getOperandCount()) + arg = oldVal->getOperand(i); + if (field->getKey() == inst->getOperand(1)) + arg = inst->getOperand(2); + if (arg) + { + args.add(arg); + } + else + { + isValid = false; + break; + } + i++; } - else + if (isValid) { - isValid = false; - break; + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer()); + inst->replaceUsesWith(makeStruct); + inst->removeAndDeallocate(); + changed = true; } - i++; - } - if (isValid) - { - IRBuilder builder(&sharedBuilderStorage); - builder.setInsertBefore(inst); - auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer()); - inst->replaceUsesWith(makeStruct); - inst->removeAndDeallocate(); - changed = true; } } } diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index bcf0907df..9bd681115 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -38,7 +38,6 @@ struct RedundancyRemovalContext case kIROp_GetElement: case kIROp_GetElementPtr: case kIROp_UpdateElement: - case kIROp_UpdateField: case kIROp_LookupWitness: case kIROp_Specialize: case kIROp_OptionalHasValue: diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2960d942c..0434ff682 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3872,6 +3872,11 @@ namespace Slang return addDecoration(target, kIROp_PrimalValueStructKeyDecoration, key); } + IRInst* IRBuilder::addPrimalElementTypeDecoration(IRInst* target, IRInst* type) + { + return addDecoration(target, kIROp_PrimalElementTypeDecoration, type); + } + RefPtr<IRModule> IRModule::create(Session* session) { RefPtr<IRModule> module = new IRModule(session); @@ -4355,6 +4360,65 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitElementExtract( + IRInst* base, + IRInst* index) + { + IRType* type = nullptr; + if (auto arrayType = as<IRArrayType>(base->getDataType())) + { + type = arrayType->getElementType(); + } + else if (auto vectorType = as<IRVectorType>(base->getDataType())) + { + type = vectorType->getElementType(); + } + else if (auto matrixType = as<IRMatrixType>(base->getDataType())) + { + type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); + } + SLANG_RELEASE_ASSERT(type); + auto inst = createInst<IRFieldAddress>( + this, + kIROp_GetElement, + type, + base, + index); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitElementExtract( + IRInst* base, + const ArrayView<IRInst*>& accessChain) + { + for (auto access : accessChain) + { + IRType* resultType = nullptr; + if (auto structKey = as<IRStructKey>(access)) + { + auto structType = as<IRStructType>(base->getDataType()); + SLANG_RELEASE_ASSERT(structType); + for (auto field : structType->getFields()) + { + if (field->getKey() == structKey) + { + resultType = field->getFieldType(); + break; + } + } + SLANG_RELEASE_ASSERT(resultType); + base = emitFieldExtract(resultType, base, structKey); + } + else + { + base = emitElementExtract(base, access); + } + } + return base; + } + IRInst* IRBuilder::emitElementAddress( IRType* type, IRInst* basePtr, @@ -4378,23 +4442,21 @@ namespace Slang kIROp_UpdateElement, base->getFullType(), base, - index, - newElement); + newElement, + index); addInst(inst); return inst; } - IRInst* IRBuilder::emitUpdateField(IRInst* base, IRInst* fieldKey, IRInst* newFieldVal) + IRInst* IRBuilder::emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement) { - auto inst = createInst<IRUpdateField>( - this, - kIROp_UpdateField, - base->getFullType(), - base, - fieldKey, - newFieldVal); - + List<IRInst*> args; + args.add(base); + args.add(newElement); + args.addRange(accessChain); + auto inst = createInst<IRUpdateElement>( + this, kIROp_UpdateElement, base->getFullType(), (Int)args.getCount(), args.getBuffer()); addInst(inst); return inst; } @@ -6663,7 +6725,6 @@ namespace Slang case kIROp_GetElement: case kIROp_GetElementPtr: case kIROp_UpdateElement: - case kIROp_UpdateField: case kIROp_MeshOutputRef: case kIROp_MakeVectorFromScalar: case kIROp_swizzle: diff --git a/tests/autodiff/reverse-addr-eliminate.slang b/tests/autodiff/reverse-addr-eliminate.slang index daa6fa32b..e23e83e6a 100644 --- a/tests/autodiff/reverse-addr-eliminate.slang +++ b/tests/autodiff/reverse-addr-eliminate.slang @@ -4,6 +4,11 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +struct D : IDifferentiable +{ + float n; + float m; +} struct C : IDifferentiable { float3 t; @@ -22,6 +27,7 @@ struct A : IDifferentiable float y; B fb; C aarr[3]; + D dv; }; [BackwardDifferentiable] @@ -33,7 +39,9 @@ A f(A a, int i) aout.x = aout.y + 5 * a.x; aout.aarr[1].t = float3(a.y, 0.0, a.x); aout.aarr[1].t = float3(a.y, 1.0, a.x + 1.0); - + D nd = { a.x * 4.0f, 1.0f }; + aout.dv = nd; + aout.dv.m = aout.dv.n * 0.5f; // Test that writes to a potentially dynamic address multiple times // is allowed and will propagate the correct derivative. aout.fb.arr[i].v = a.x * 2.0; // since this value is overwritten, the diff will not accumulate to a.x @@ -48,9 +56,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) var dpa = diffPair(a); - A.Differential dout = { 1.0, 1.0, { float2(0), { { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } }, { { float3(1.0), 1.0 }, { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } }; + A.Differential dout = { 1.0, 1.0, { float2(0), { { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } }, { { float3(1.0), 1.0 }, { float3(1.0), 1.0 }, { float3(1.0), 1.0 } }, {1.0, 1.0} }; __bwd_diff(f)(dpa, 1, dout); - outputBuffer[0] = dpa.d.x; // Expect: 17 + outputBuffer[0] = dpa.d.x; // Expect: 23 outputBuffer[1] = dpa.d.y; // Expect: 0 } diff --git a/tests/autodiff/reverse-addr-eliminate.slang.expected.txt b/tests/autodiff/reverse-addr-eliminate.slang.expected.txt index dd367f3f5..fddc3120a 100644 --- a/tests/autodiff/reverse-addr-eliminate.slang.expected.txt +++ b/tests/autodiff/reverse-addr-eliminate.slang.expected.txt @@ -1,5 +1,5 @@ type: float -17.000000 +23.000000 1.000000 0.000000 0.000000 |
