From a911ca6e06ce41e403b80fe6054162393491c8ac Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 13 Mar 2023 10:57:28 -0700 Subject: Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695) * Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 12 +- source/slang/slang-check-decl.cpp | 42 ++++-- source/slang/slang-ir-autodiff-fwd.cpp | 160 ++++++++++++++++----- source/slang/slang-ir-autodiff-fwd.h | 4 +- source/slang/slang-ir-autodiff-pairs.cpp | 123 ++++++++++++++-- source/slang/slang-ir-autodiff-pairs.h | 6 +- source/slang/slang-ir-autodiff-rev.cpp | 75 +--------- .../slang/slang-ir-autodiff-transcriber-base.cpp | 21 ++- source/slang/slang-ir-autodiff-transpose.h | 118 ++++++++++++++- source/slang/slang-ir-autodiff.cpp | 51 ++++--- source/slang/slang-ir-autodiff.h | 19 ++- source/slang/slang-ir-check-differentiability.cpp | 15 +- source/slang/slang-ir-inst-defs.h | 11 ++ source/slang/slang-ir-insts.h | 44 +++++- source/slang/slang-ir-util.cpp | 1 + source/slang/slang-ir.cpp | 62 ++++++++ source/slang/slang-ir.h | 12 +- source/slang/slang-lower-to-ir.cpp | 16 ++- source/slang/slang-syntax.cpp | 19 +++ 19 files changed, 629 insertions(+), 182 deletions(-) (limited to 'source/slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 4301eda94..ada052cd8 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -39,30 +39,30 @@ attribute_syntax [NoDiffThis] : NoDiffThisAttribute; __generic __magic_type(DifferentialPairType) -__intrinsic_type($(kIROp_DifferentialPairType)) +__intrinsic_type($(kIROp_DifferentialPairUserCodeType)) struct DifferentialPair : IDifferentiable { typedef DifferentialPair Differential; typedef T.Differential DifferentialElementType; - __intrinsic_op($(kIROp_MakeDifferentialPair)) + __intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) __init(T _primal, T.Differential _differential); property p : T { - __intrinsic_op($(kIROp_DifferentialPairGetPrimal)) + __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode)) get; } property v : T { - __intrinsic_op($(kIROp_DifferentialPairGetPrimal)) + __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode)) get; } property d : T.Differential { - __intrinsic_op($(kIROp_DifferentialPairGetDifferential)) + __intrinsic_op($(kIROp_DifferentialPairGetDifferentialUserCode)) get; } @@ -105,7 +105,7 @@ struct DifferentialPair : IDifferentiable }; __generic -__intrinsic_op($(kIROp_MakeDifferentialPair)) +__intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) DifferentialPair diffPair(T primal, T.Differential diff); __generic diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5cd7fba45..ea8bec2bb 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1506,19 +1506,37 @@ namespace Slang aggTypeDecl->members.add(diffField); aggTypeDecl->invalidateMemberDictionary(); + // Inject a `DerivativeMember` modifier on the differential field to point to itself. + { + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = diffMemberType; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->create(); + baseTypeType->type = differentialType; + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(diffField); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(diffField, derivativeMemberModifier); + } + // Inject a `DerivativeMember` modifier on the original decl. - auto derivativeMemberModifier = m_astBuilder->create(); - auto fieldLookupExpr = m_astBuilder->create(); - fieldLookupExpr->type.type = diffMemberType; - auto baseTypeExpr = m_astBuilder->create(); - baseTypeExpr->base.type = differentialType; - auto baseTypeType = m_astBuilder->create(); - baseTypeType->type = differentialType; - baseTypeExpr->type.type = baseTypeType; - fieldLookupExpr->baseExpression = baseTypeExpr; - fieldLookupExpr->declRef = makeDeclRef(diffField); - derivativeMemberModifier->memberDeclRef = fieldLookupExpr; - addModifier(member, derivativeMemberModifier); + { + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = diffMemberType; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->create(); + baseTypeType->type = differentialType; + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(diffField); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + } }; // Make the Differential type itself conform to `IDifferential` interface. diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 2090cd4dc..7057a5835 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -8,6 +8,9 @@ #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-single-return.h" +#include "slang-ir-addr-inst-elimination.h" +#include "slang-ir-ssa-simplification.h" +#include "slang-ir-validate.h" namespace Slang { @@ -234,7 +237,7 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig auto primalElement = builder->emitDifferentialPairGetPrimal(load); auto diffElement = builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load); return InstPair(primalElement, diffElement); } } @@ -938,7 +941,10 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI if (decor) diffAccessChain.add(decor->getDerivativeMemberStructKey()); else - return InstPair(primalUpdateField, nullptr); + { + auto diffBase = findOrTranscribeDiffInst(builder, origBase); + return InstPair(primalUpdateField, diffBase); + } } else { @@ -947,24 +953,26 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI } if (auto diffType = differentiateType(builder, originalInst->getDataType())) { - if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + auto diffBase = findOrTranscribeDiffInst(builder, origBase); + if (!diffBase) { - if (auto diffVal = findOrTranscribeDiffInst(builder, origVal)) - { - auto primalElementType = primalVal->getDataType(); + diffBase = getDifferentialZeroOfType(builder, origBase->getDataType()); + } + if (auto diffVal = findOrTranscribeDiffInst(builder, origVal)) + { + auto primalElementType = primalVal->getDataType(); - diffUpdateElement = builder->emitUpdateElement( - diffBase, diffAccessChain, diffVal); - builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); - } - else - { - auto primalElementType = primalVal->getDataType(); - auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType); - diffUpdateElement = builder->emitUpdateElement( - diffBase, diffAccessChain, zeroElementDiff); - builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); - } + diffUpdateElement = builder->emitUpdateElement( + diffBase, diffAccessChain, diffVal); + builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); + } + else + { + auto primalElementType = primalVal->getDataType(); + auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType); + diffUpdateElement = builder->emitUpdateElement( + diffBase, diffAccessChain, zeroElementDiff); + builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); } } return InstPair(primalUpdateField, diffUpdateElement); @@ -1121,7 +1129,7 @@ InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* return InstPair(diffIfElse, diffIfElse); } -InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) +InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPairUserCode* origInst) { auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue()); SLANG_ASSERT(primalVal); @@ -1140,9 +1148,9 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType()); auto diffPairType = findOrTranscribeDiffInst(builder, origInst->getFullType()); - auto primalPair = builder->emitMakeDifferentialPair( + auto primalPair = builder->emitMakeDifferentialPairUserCode( (IRType*)primalPairType, primalVal, diffPrimalVal); - auto diffPair = builder->emitMakeDifferentialPair( + auto diffPair = builder->emitMakeDifferentialPairUserCode( (IRType*)diffPairType, primalDiffVal, diffDiffVal); @@ -1152,8 +1160,8 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) { SLANG_ASSERT( - origInst->getOp() == kIROp_DifferentialPairGetDifferential || - origInst->getOp() == kIROp_DifferentialPairGetPrimal); + origInst->getOp() == kIROp_DifferentialPairGetDifferentialUserCode || + origInst->getOp() == kIROp_DifferentialPairGetPrimalUserCode); auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0)); SLANG_ASSERT(primalVal); @@ -1165,10 +1173,10 @@ InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder* auto primalResult = builder->emitIntrinsicInst((IRType*)primalType, origInst->getOp(), 1, &primalVal); - auto diffValPairType = as(diffVal->getDataType()); + auto diffValPairType = as(diffVal->getDataType()); IRInst* diffResultType = nullptr; - if (origInst->getOp() == kIROp_DifferentialPairGetDifferential) - diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType); + if (origInst->getOp() == kIROp_DifferentialPairGetDifferentialUserCode) + diffResultType = differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffValPairType); else diffResultType = diffValPairType->getValueType(); auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal); @@ -1318,6 +1326,8 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I // Mark the generated derivative function itself as differentiable. builder.addForwardDifferentiableDecoration(diffFunc); + if (isBackwardDifferentiableFunc(origFunc)) + builder.addBackwardDifferentiableDecoration(diffFunc); // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. if (auto dictDecor = origFunc->findDecoration()) @@ -1349,23 +1359,105 @@ void ForwardDiffTranscriber::checkAutodiffInstDecorations(IRFunc* fwdFunc) } } +void insertTempVarForMutableParams(IRModule* module, IRFunc* func) +{ + IRBuilder builder(module); + auto firstBlock = func->getFirstBlock(); + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + + OrderedDictionary mapParamToTempVar; + List params; + for (auto param : firstBlock->getParams()) + { + if (auto ptrType = as(param->getDataType())) + { + params.add(param); + } + } + + for (auto param : params) + { + auto ptrType = as(param->getDataType()); + auto tempVar = builder.emitVar(ptrType->getValueType()); + param->replaceUsesWith(tempVar); + mapParamToTempVar[param] = tempVar; + if (ptrType->getOp() != kIROp_OutType) + { + builder.emitStore(tempVar, builder.emitLoad(param)); + } + else + { + builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType())); + } + } + + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Return) + { + builder.setInsertBefore(inst); + for (auto& kv : mapParamToTempVar) + { + builder.emitStore(kv.Key, builder.emitLoad(kv.Value)); + } + } + } + } +} + +struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy +{ + DifferentiableTypeConformanceContext* diffTypeContext; + + virtual bool shouldConvertAddrInst(IRInst*) override + { + return true; + } +}; + +SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) +{ + insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); + + AutoDiffAddressConversionPolicy cvtPolicty; + cvtPolicty.diffTypeContext = &differentiableTypeConformanceContext; + auto result = eliminateAddressInsts(&cvtPolicty, func, sink); + + if (SLANG_SUCCEEDED(result)) + { + disableIRValidationAtInsert(); + simplifyFunc(func); + enableIRValidationAtInsert(); + } + return result; +} + // Transcribe a function definition. InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) { IRBuilder builder = *inBuilder; + builder.setInsertBefore(primalFunc); + + // Create a clone for original func and run additional transformations on the clone. + IRCloneEnv env; + auto primalFuncClone = as(cloneInst(&env, &builder, primalFunc)); + prepareFuncForForwardDiff(primalFuncClone); + builder.setInsertInto(diffFunc); - differentiableTypeConformanceContext.setFunc(primalFunc); + differentiableTypeConformanceContext.setFunc(primalFuncClone); mapInOutParamToWriteBackValue.Clear(); // Transcribe children from origFunc into diffFunc - for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) + for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock()) this->transcribe(&builder, block); // Some of the transcribed blocks can appear 'out-of-order'. Although this // shouldn't be an issue, for consistency, we put them back in order. - for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) + for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock()) as(lookupDiffInst(block))->insertAtEnd(diffFunc); for (auto block : diffFunc->getBlocks()) @@ -1507,11 +1599,11 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_Switch: return transcribeSwitch(builder, as(origInst)); - case kIROp_MakeDifferentialPair: - return transcribeMakeDifferentialPair(builder, as(origInst)); + case kIROp_MakeDifferentialPairUserCode: + return transcribeMakeDifferentialPair(builder, as(origInst)); - case kIROp_DifferentialPairGetPrimal: - case kIROp_DifferentialPairGetDifferential: + case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPairGetDifferentialUserCode: return transcribeDifferentialPairGetElement(builder, origInst); case kIROp_ExtractExistentialValue: @@ -1612,7 +1704,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam return InstPair( builder->emitDifferentialPairGetPrimal(diffPairParam), builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, pairType), diffPairParam)); } else if (auto pairPtrType = as(diffPairType)) diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 5b79a6c54..6032c2319 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -77,7 +77,7 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch); - InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst); + InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPairUserCode* origInst); InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); @@ -100,6 +100,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase void checkAutodiffInstDecorations(IRFunc* fwdFunc); + SlangResult prepareFuncForForwardDiff(IRFunc* func); + // Create an empty func to represent the transcribed func of `origFunc`. virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index 9d761764c..7b16c0213 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -24,10 +24,10 @@ struct DiffPairLoweringPass : InstPassBase IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) { - if (auto makePairInst = as(inst)) + if (auto makePairInst = as(inst)) { bool isTrivial = false; - auto pairType = as(makePairInst->getDataType()); + auto pairType = as(makePairInst->getDataType()); if (auto loweredPairType = lowerPairType(builder, pairType)) { builder->setInsertBefore(makePairInst); @@ -52,7 +52,7 @@ struct DiffPairLoweringPass : InstPassBase IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst) { - if (auto getDiffInst = as(inst)) + if (auto getDiffInst = as(inst)) { auto pairType = getDiffInst->getBase()->getDataType(); if (auto pairPtrType = as(pairType)) @@ -70,7 +70,7 @@ struct DiffPairLoweringPass : InstPassBase return diffFieldExtract; } } - else if (auto getPrimalInst = as(inst)) + else if (auto getPrimalInst = as(inst)) { auto pairType = getPrimalInst->getBase()->getDataType(); if (auto pairPtrType = as(pairType)) @@ -106,10 +106,12 @@ struct DiffPairLoweringPass : InstPassBase { case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferentialUserCode: + case kIROp_DifferentialPairGetPrimalUserCode: lowerPairAccess(builder, inst); break; - case kIROp_MakeDifferentialPair: + case kIROp_MakeDifferentialPairUserCode: lowerMakePair(builder, inst); break; @@ -119,12 +121,15 @@ struct DiffPairLoweringPass : InstPassBase }); OrderedDictionary pendingReplacements; - processInstsOfType(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) + processAllInsts([&](IRInst* inst) { - if (auto loweredType = lowerPairType(builder, inst)) + if (auto pairType = as(inst)) { - pendingReplacements.Add(inst, loweredType); - modified = true; + if (auto loweredType = lowerPairType(builder, pairType)) + { + pendingReplacements.Add(pairType, loweredType); + modified = true; + } } }); for (auto replacement : pendingReplacements) @@ -158,4 +163,104 @@ bool processPairTypes(AutoDiffSharedContext* context) return pairLoweringPass.processModule(); } +struct DifferentialPairUserCodeTranscribePass : public InstPassBase +{ + DifferentialPairUserCodeTranscribePass(IRModule* module) + :InstPassBase(module) + {} + + IRInst* rewritePairType(IRBuilder* builder, IRType* pairType) + { + builder->setInsertBefore(pairType); + auto originalPairType = as(pairType); + return builder->getDifferentialPairUserCodeType(originalPairType->getValueType(), originalPairType->getWitness()); + } + + IRInst* rewriteMakePair(IRBuilder* builder, IRMakeDifferentialPair* inst) + { + auto pairType = as(inst->getFullType()); + builder->setInsertBefore(inst); + auto newInst = builder->emitMakeDifferentialPairUserCode( + (IRType*)pairType, inst->getPrimalValue(), inst->getDifferentialValue()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return newInst; + } + + IRInst* rewritePairAccess(IRBuilder* builder, IRInst* inst) + { + if (auto getDiffInst = as(inst)) + { + builder->setInsertBefore(inst); + + auto newInst = builder->emitDifferentialPairGetDifferentialUserCode( + (IRType*)inst->getFullType(), getDiffInst->getBase()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + else if (auto getPrimalInst = as(inst)) + { + builder->setInsertBefore(inst); + auto newInst = builder->emitDifferentialPairGetPrimalUserCode(getPrimalInst->getBase()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + return inst; + } + + bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren) + { + SLANG_UNUSED(instWithChildren); + + bool modified = false; + + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_DifferentialPairGetDifferential: + case kIROp_DifferentialPairGetPrimal: + rewritePairAccess(builder, inst); + break; + + case kIROp_MakeDifferentialPair: + rewriteMakePair(builder, as(inst)); + break; + + default: + break; + } + }); + + OrderedDictionary pendingReplacements; + processInstsOfType(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) + { + if (auto loweredType = rewritePairType(builder, inst)) + { + pendingReplacements.Add(inst, loweredType); + modified = true; + } + }); + for (auto replacement : pendingReplacements) + { + replacement.Key->replaceUsesWith(replacement.Value); + replacement.Key->removeAndDeallocate(); + } + + return modified; + } + + bool processModule() + { + IRBuilder builder(module); + return processInstWithChildren(&builder, module->getModuleInst()); + } +}; + +void rewriteDifferentialPairToUserCode(IRModule* module) +{ + DifferentialPairUserCodeTranscribePass pairRewritePass(module); + pairRewritePass.processModule(); +} + } diff --git a/source/slang/slang-ir-autodiff-pairs.h b/source/slang/slang-ir-autodiff-pairs.h index 44321ae9b..8f9e77145 100644 --- a/source/slang/slang-ir-autodiff-pairs.h +++ b/source/slang/slang-ir-autodiff-pairs.h @@ -18,4 +18,8 @@ namespace Slang bool processPairTypes(AutoDiffSharedContext* context); -} \ No newline at end of file +// Rewrites all uses of `DifferentialPairType` into `DifferentialPairUserCodeType` in the original func, +// so they are not to be confused with real mixed differential code generated by forward diff pass. +void rewriteDifferentialPairToUserCode(IRModule* module); + +} diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index d7cce7c53..328af4867 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -2,14 +2,12 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" -#include "slang-ir-eliminate-phis.h" #include "slang-ir-autodiff-cfg-norm.h" #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-autodiff-fwd.h" #include "slang-ir-single-return.h" -#include "slang-ir-addr-inst-elimination.h" #include "slang-ir-eliminate-multilevel-break.h" #include "slang-ir-init-local-var.h" #include "slang-ir-redundancy-removal.h" @@ -516,65 +514,6 @@ namespace Slang builder.emitBranch(firstBlock); } - void insertTempVarForMutableParams(IRModule* module, IRFunc* func) - { - IRBuilder builder(module); - auto firstBlock = func->getFirstBlock(); - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - - OrderedDictionary mapParamToTempVar; - List params; - for (auto param : firstBlock->getParams()) - { - if (auto ptrType = as(param->getDataType())) - { - params.add(param); - } - } - - for (auto param : params) - { - auto ptrType = as(param->getDataType()); - auto tempVar = builder.emitVar(ptrType->getValueType()); - param->replaceUsesWith(tempVar); - mapParamToTempVar[param] = tempVar; - if (ptrType->getOp() != kIROp_OutType) - { - builder.emitStore(tempVar, builder.emitLoad(param)); - } - else - { - builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType())); - } - } - - for (auto block : func->getBlocks()) - { - for (auto inst : block->getChildren()) - { - if (inst->getOp() == kIROp_Return) - { - builder.setInsertBefore(inst); - for (auto& kv : mapParamToTempVar) - { - builder.emitStore(kv.Key, builder.emitLoad(kv.Value)); - } - } - } - } - } - - - struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy - { - DifferentiableTypeConformanceContext* diffTypeContext; - - virtual bool shouldConvertAddrInst(IRInst*) override - { - return true; - } - }; - SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func) { DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext); @@ -592,19 +531,7 @@ namespace Slang IRCFGNormalizationPass cfgPass = {this->getSink()}; normalizeCFG(autoDiffSharedContext->moduleInst->getModule(), func); - insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); - - AutoDiffAddressConversionPolicy cvtPolicty; - cvtPolicty.diffTypeContext = &diffTypeContext; - auto result = eliminateAddressInsts(&cvtPolicty, func, sink); - - if (SLANG_SUCCEEDED(result)) - { - disableIRValidationAtInsert(); - simplifyFunc(func); - enableIRValidationAtInsert(); - } - return result; + return SLANG_OK; } // Create a copy of originalFunc's forward derivative in the same generic context (if any) of diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index ed122c862..091e7f1ab 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -304,8 +304,16 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy auto primalPairType = as(primalType); return getOrCreateDiffPairType( builder, - pairBuilder->getDiffTypeFromPairType(builder, primalPairType), - pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType), + differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_DifferentialPairUserCodeType: + { + auto primalPairType = as(primalType); + return builder->getDifferentialPairUserCodeType( + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType), + differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(builder, primalPairType)); } case kIROp_FuncType: @@ -634,6 +642,15 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I builder->markInstAsDifferential(makeDiffPair, as(diffType)->getValueType()); return makeDiffPair; } + case kIROp_DifferentialPairUserCodeType: + { + auto makeDiffPair = builder->emitMakeDifferentialPairUserCode( + diffType, + getDifferentialZeroOfType(builder, as(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as(diffType)->getValueType())); + builder->markInstAsDifferential(makeDiffPair, as(diffType)->getValueType()); + return makeDiffPair; + } } if (auto arrayType = as(primalType)) diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 23f57032d..1cd6a0e33 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -24,7 +24,7 @@ struct DiffTransposePass GetElement, GetDifferential, FieldExtract, - + DifferentialPairGetElementUserCode, Invalid }; @@ -1704,7 +1704,16 @@ struct DiffTransposePass case kIROp_DifferentialPairGetDifferential: return transposeGetDifferential(builder, as(fwdInst), revValue); - + + case kIROp_MakeDifferentialPairUserCode: + return transposeMakePairUserCode(builder, as(fwdInst), revValue); + + case kIROp_DifferentialPairGetPrimalUserCode: + return transposeGetPrimalUserCode(builder, as(fwdInst), revValue); + + case kIROp_DifferentialPairGetDifferentialUserCode: + return transposeGetDifferentialUserCode(builder, as(fwdInst), revValue); + case kIROp_MakeVector: return transposeMakeVector(builder, fwdInst, revValue); case kIROp_MakeVectorFromScalar: @@ -1878,6 +1887,47 @@ struct DiffTransposePass fwdGetDiff))); } + TranspositionResult transposeMakePairUserCode(IRBuilder* builder, IRMakeDifferentialPairUserCode* fwdMakePair, IRInst* revValue) + { + List gradients; + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakePair->getPrimalValue(), + builder->emitDifferentialPairGetPrimalUserCode(revValue), + fwdMakePair)); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakePair->getDifferentialValue(), + builder->emitDifferentialPairGetDifferentialUserCode( + fwdMakePair->getDifferentialValue()->getFullType(), revValue), + fwdMakePair)); + return TranspositionResult(gradients); + } + + TranspositionResult transposeGetDifferentialUserCode(IRBuilder*, IRDifferentialPairGetDifferentialUserCode* fwdGetDiff, IRInst* revValue) + { + // (A = x.p) -> (dX = DiffPairUserCode(dA, 0)) + return TranspositionResult( + List( + RevGradient( + RevGradient::Flavor::DifferentialPairGetElementUserCode, + fwdGetDiff->getBase(), + revValue, + fwdGetDiff))); + } + + TranspositionResult transposeGetPrimalUserCode(IRBuilder*, IRDifferentialPairGetPrimalUserCode* fwdGetPrimal, IRInst* revValue) + { + // (A = x.p) -> (dX = DiffPairUserCode(0, dA)) + return TranspositionResult( + List( + RevGradient( + RevGradient::Flavor::DifferentialPairGetElementUserCode, + fwdGetPrimal->getBase(), + revValue, + fwdGetPrimal))); + } + TranspositionResult transposeMakeVectorFromScalar(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) { auto vectorType = as(revValue->getDataType()); @@ -2497,6 +2547,40 @@ struct DiffTransposePass return materializeSimpleGradients(builder, aggPrimalType, simpleGradients); } + RevGradient materializeDifferentialPairUserCodeGetElementGradients(IRBuilder* builder, IRType* aggPrimalType, List gradients) + { + List simpleGradients; + + for (auto gradient : gradients) + { + // Peek at the fwd-mode get element inst to see what type we need to materialize. + if (auto fwdGetDiff = as(gradient.fwdGradInst)) + { + auto baseType = as(diffTypeContext.getDifferentialForType( + builder, + fwdGetDiff->getBase()->getDataType())); + simpleGradients.add( + RevGradient( + gradient.targetInst, + builder->emitMakeDifferentialPairUserCode(baseType, emitDZeroOfDiffInstType(builder, baseType->getValueType()), gradient.revGradInst), + gradient.fwdGradInst)); + } + else if (auto fwdGetPrimal = as(gradient.fwdGradInst)) + { + auto baseType = as(diffTypeContext.getDifferentialForType( + builder, + fwdGetPrimal->getBase()->getDataType())); + simpleGradients.add( + RevGradient( + gradient.targetInst, + builder->emitMakeDifferentialPairUserCode(baseType, gradient.revGradInst, emitDZeroOfDiffInstType(builder, fwdGetPrimal->getFullType())), + gradient.fwdGradInst)); + } + } + + return materializeSimpleGradients(builder, aggPrimalType, simpleGradients); + } + RevGradient materializeGradientSet(IRBuilder* builder, IRType* aggPrimalType, List gradients) { switch (gradients[0].flavor) @@ -2513,6 +2597,9 @@ struct DiffTransposePass case RevGradient::Flavor::GetElement: return materializeGetElementGradients(builder, aggPrimalType, gradients); + case RevGradient::Flavor::DifferentialPairGetElementUserCode: + return materializeDifferentialPairUserCodeGetElementGradients(builder, aggPrimalType, gradients); + default: SLANG_ASSERT_FAILURE("Unhandled gradient flavor for materialization"); } @@ -2773,6 +2860,16 @@ struct DiffTransposePass auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType()); return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero); } + else if (auto diffPairUserType = as(primalType)) + { + auto primalZero = emitDZeroOfDiffInstType(builder, diffPairUserType->getValueType()); + auto diffZero = primalZero; + auto diffType = primalZero->getFullType(); + auto diffWitness = diffTypeContext.getDiffTypeWitnessFromPairType(builder, diffPairUserType); + auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); + return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primalZero, diffZero); + } + auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType); // Should exist. @@ -2810,6 +2907,23 @@ struct DiffTransposePass SLANG_UNIMPLEMENTED_X("dadd of dynamic array."); } } + else if (auto diffPairUserType = as(primalType)) + { + auto diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, diffPairUserType); + auto diffWitness = diffTypeContext.getDiffTypeWitnessFromPairType(builder, diffPairUserType); + + auto primal1 = builder->emitDifferentialPairGetPrimalUserCode(op1); + auto primal2 = builder->emitDifferentialPairGetPrimalUserCode(op2); + auto primal = emitDAddOfDiffInstType(builder, diffPairUserType->getValueType(), primal1, primal2); + + auto diff1 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op1); + auto diff2 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op2); + auto diff = emitDAddOfDiffInstType(builder, diffType, diff1, diff2); + + auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); + return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff); + } + auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType); // Should exist. diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 65e880868..edea3847d 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -44,6 +44,18 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK return nullptr; } +static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); +} + +static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); +} + bool isNoDiffType(IRType* paramType) { while (auto ptrType = as(paramType)) @@ -266,25 +278,13 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I return pairStructType; } -IRInst* DifferentialPairTypeBuilder::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type) -{ - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); -} - -IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type) -{ - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); -} - IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( IRBuilder* builder, IRType* originalPairType) { IRInst* result = nullptr; if (pairTypeCache.TryGetValue(originalPairType, result)) return result; - auto pairType = as(originalPairType); + auto pairType = as(originalPairType); if (!pairType) { result = originalPairType; @@ -297,7 +297,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( return result; } - auto diffType = getDiffTypeFromPairType(builder, pairType); + auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType); if (!diffType) return result; result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); @@ -406,18 +406,28 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b } IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairType( - IRBuilder* builder, IRDifferentialPairType* diffPairType) + IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType) { auto witness = diffPairType->getWitness(); SLANG_RELEASE_ASSERT(witness); return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); } +IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + return _getDiffTypeFromPairType(sharedContext, builder, type); +} + +IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + return _getDiffTypeWitnessFromPairType(sharedContext, builder, type); +} + void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { for (auto globalInst : sharedContext->moduleInst->getChildren()) { - if (auto pairType = as(globalInst)) + if (auto pairType = as(globalInst)) { differentiableWitnessDictionary.AddIfNotExists(pairType->getValueType(), pairType->getWitness()); } @@ -505,6 +515,7 @@ void stripTempDecorations(IRInst* inst) case kIROp_AutoDiffOriginalValueDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: case kIROp_PrimalValueStructKeyDecoration: + case kIROp_PrimalElementTypeDecoration: decor->removeAndDeallocate(); break; default: @@ -578,6 +589,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_TupleType: case kIROp_ArrayType: case kIROp_DifferentialPairType: + case kIROp_DifferentialPairUserCodeType: case kIROp_InterfaceType: case kIROp_AnyValueType: case kIROp_ClassType: @@ -832,6 +844,13 @@ struct AutoDiffPass : public InstPassBase if (!changed) break; + + // We have done transcribing the functions, now it is time to demote all DifferentialPair types + // and their operations down to DifferentialPairUserCodeType and *UserCode operations so they + // can be treated just like normal types with no special semantics in future processing, and won't + // be confused with the semantics of a DifferentialPair type during future autodiff code gen. + rewriteDifferentialPairToUserCode(module); + hasChanges |= changed; } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index f757375d8..e7a841323 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -159,7 +159,11 @@ struct DifferentiableTypeConformanceContext IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); - IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairType* diffPairType); + IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType); + + IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); + + IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); // Lookup and return the 'Differential' type declared in the concrete type // in order to conform to the IDifferentiable interface. @@ -180,6 +184,13 @@ struct DifferentiableTypeConformanceContext diffElementType, as(origType)->getElementCount()); } + case kIROp_DifferentialPairUserCodeType: + { + auto diffPairType = as(origType); + auto diffType = getDiffTypeFromPairType(builder, diffPairType); + auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType); + return builder->getDifferentialPairUserCodeType((IRType*)diffType, diffWitness); + } default: return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); } @@ -194,6 +205,8 @@ struct DifferentiableTypeConformanceContext case kIROp_FloatType: case kIROp_HalfType: case kIROp_DoubleType: + case kIROp_DifferentialPairType: + case kIROp_DifferentialPairUserCodeType: return true; case kIROp_VectorType: case kIROp_ArrayType: @@ -244,10 +257,6 @@ struct DifferentialPairTypeBuilder IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType); - IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type); - - IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type); - IRInst* lowerDiffPairType(IRBuilder* builder, IRType* originalPairType); struct PairStructKey diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 14f6394e2..6f97ce076 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -73,16 +73,13 @@ public: bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level) { - if (level == DifferentiableLevel::Forward) + switch (func->getOp()) { - switch (func->getOp()) - { - case kIROp_ForwardDifferentiate: - case kIROp_BackwardDifferentiate: - return true; - default: - break; - } + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + return isDifferentiableFunc(func->getOperand(0), level); + default: + break; } func = getResolvedInstForDecorations(func); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 7411d031c..28c682c91 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -62,6 +62,9 @@ INST(Nop, nop, 0, 0) INST(OptionalType, Optional, 1, HOISTABLE) INST(DifferentialPairType, DiffPair, 1, HOISTABLE) + INST(DifferentialPairUserCodeType, DiffPairUserCode, 1, HOISTABLE) + INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPairUserCodeType) + INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE) /* BindExistentialsTypeBase */ @@ -278,8 +281,16 @@ INST(undefined, undefined, 0, 0) INST(DefaultConstruct, defaultConstruct, 0, 0) INST(MakeDifferentialPair, MakeDiffPair, 2, 0) +INST(MakeDifferentialPairUserCode, MakeDiffPairUserCode, 2, 0) +INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPairUserCode) + INST(DifferentialPairGetDifferential, GetDifferential, 1, 0) +INST(DifferentialPairGetDifferentialUserCode, GetDifferentialUserCode, 1, 0) +INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPairGetDifferentialUserCode) + INST(DifferentialPairGetPrimal, GetPrimal, 1, 0) +INST(DifferentialPairGetPrimalUserCode, GetPrimalUserCode, 1, 0) +INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPairGetPrimalUserCode) INST(Specialize, specialize, 2, HOISTABLE) INST(LookupWitness, lookupWitness, 2, HOISTABLE) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index ae31219bd..f3181ecf7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2259,24 +2259,49 @@ struct IRGetTupleElement : IRInst // An Instruction that creates a differential pair value from a // primal and differential. -struct IRMakeDifferentialPair : IRInst + +struct IRMakeDifferentialPairBase : IRInst { - IR_LEAF_ISA(MakeDifferentialPair) + IR_PARENT_ISA(MakeDifferentialPairBase) IRInst* getPrimalValue() { return getOperand(0); } IRInst* getDifferentialValue() { return getOperand(1); } }; +struct IRMakeDifferentialPair : IRMakeDifferentialPairBase +{ + IR_LEAF_ISA(MakeDifferentialPair) +}; +struct IRMakeDifferentialPairUserCode : IRMakeDifferentialPairBase +{ + IR_LEAF_ISA(MakeDifferentialPairUserCode) +}; -struct IRDifferentialPairGetDifferential : IRInst +struct IRDifferentialPairGetDifferentialBase : IRInst { - IR_LEAF_ISA(DifferentialPairGetDifferential) + IR_PARENT_ISA(DifferentialPairGetDifferentialBase) IRInst* getBase() { return getOperand(0); } }; +struct IRDifferentialPairGetDifferential : IRDifferentialPairGetDifferentialBase +{ + IR_LEAF_ISA(DifferentialPairGetDifferential) +}; +struct IRDifferentialPairGetDifferentialUserCode : IRDifferentialPairGetDifferentialBase +{ + IR_LEAF_ISA(DifferentialPairGetDifferentialUserCode) +}; -struct IRDifferentialPairGetPrimal : IRInst +struct IRDifferentialPairGetPrimalBase : IRInst { - IR_LEAF_ISA(DifferentialPairGetPrimal) + IR_PARENT_ISA(DifferentialPairGetPrimalBase) IRInst* getBase() { return getOperand(0); } }; +struct IRDifferentialPairGetPrimal : IRDifferentialPairGetPrimalBase +{ + IR_LEAF_ISA(DifferentialPairGetPrimal) +}; +struct IRDifferentialPairGetPrimalUserCode : IRDifferentialPairGetPrimalBase +{ + IR_LEAF_ISA(DifferentialPairGetPrimalUserCode) +}; struct IRDetachDerivative : IRInst { @@ -2717,6 +2742,10 @@ public: IRType* valueType, IRInst* witnessTable); + IRDifferentialPairUserCodeType* getDifferentialPairUserCodeType( + IRType* valueType, + IRInst* witnessTable); + IRBackwardDiffIntermediateContextType* getBackwardDiffIntermediateContextType(IRInst* func); IRFuncType* getFuncType( @@ -2832,6 +2861,7 @@ public: IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target); @@ -2966,6 +2996,8 @@ public: IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair); IRInst* emitMakeVector( IRType* type, UInt argCount, diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 13920b011..254734965 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -174,6 +174,7 @@ bool isValueType(IRInst* dataType) case kIROp_ResultType: case kIROp_OptionalType: case kIROp_DifferentialPairType: + case kIROp_DifferentialPairUserCodeType: case kIROp_DynamicType: case kIROp_AnyValueType: case kIROp_ArrayType: diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2819a6d83..08c066f5d 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2459,6 +2459,26 @@ namespace Slang if (found) { memoryArena.rewindToCursor(cursor); + + // If the found inst is defined in the same parent as current insert location but + // is located after the insert location, we need to move it to the insert location. + auto foundInst = *found; + if (foundInst->getParent() && foundInst->getParent() == getInsertLoc().getParent() && + getInsertLoc().getMode() == IRInsertLoc::Mode::Before) + { + auto insertLoc = getInsertLoc().getInst(); + bool isAfter = false; + for (auto cur = insertLoc->next; cur; cur = cur->next) + { + if (cur == foundInst) + { + isAfter = true; + break; + } + } + if (isAfter) + foundInst->insertBefore(insertLoc); + } return *found; } } @@ -2779,6 +2799,17 @@ namespace Slang operands); } + IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType( + IRType* valueType, + IRInst* witnessTable) + { + IRInst* operands[] = { valueType, witnessTable }; + return (IRDifferentialPairUserCodeType*)getType( + kIROp_DifferentialPairUserCodeType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType( IRInst* func) { @@ -3162,6 +3193,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential) + { + SLANG_RELEASE_ASSERT(as(type)); + SLANG_RELEASE_ASSERT(as(type)->getValueType() != nullptr); + + IRInst* args[] = { primal, differential }; + auto inst = createInstWithTrailingArgs( + this, kIROp_MakeDifferentialPairUserCode, type, 2, args); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitSpecializeInst( IRType* type, IRInst* genericVal, @@ -3751,6 +3794,25 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) + { + SLANG_ASSERT(as(diffPair->getDataType())); + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPairGetDifferentialUserCode, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimalUserCode(IRInst* diffPair) + { + auto valueType = cast(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPairGetPrimalUserCode, + 1, + &diffPair); + } IRInst* IRBuilder::emitMakeMatrix( IRType* type, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index e22ea8a36..14a216fd2 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1448,14 +1448,24 @@ SIMPLE_IR_TYPE(TypeKind, Kind); // SIMPLE_IR_TYPE(GenericKind, Kind) -struct IRDifferentialPairType : IRType +struct IRDifferentialPairTypeBase : IRType { IRType* getValueType() { return (IRType*)getOperand(0); } IRInst* getWitness() { return (IRInst*)getOperand(1); } + IR_PARENT_ISA(DifferentialPairTypeBase) +}; + +struct IRDifferentialPairType : IRDifferentialPairTypeBase +{ IR_LEAF_ISA(DifferentialPairType) }; +struct IRDifferentialPairUserCodeType : IRDifferentialPairTypeBase +{ + IR_LEAF_ISA(DifferentialPairUserCodeType) +}; + struct IRBackwardDiffIntermediateContextType : IRType { IRInst* getFunc() { return getOperand(0); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d8912cbd4..5e6213205 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7285,10 +7285,18 @@ struct DeclLoweringVisitor : DeclVisitor builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount()); } - void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember) + void lowerDerivativeMemberModifier(IRInst* inst, Decl* memberDecl, DerivativeMemberAttribute* derivativeMember) { - ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl); - auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + IRInst* key = nullptr; + if (derivativeMember->memberDeclRef->declRef.getDecl() == memberDecl) + { + key = inst; + } + else + { + ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl); + key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + } SLANG_RELEASE_ASSERT(as(key)); auto builder = getBuilder(); builder->addDecoration(inst, kIROp_DerivativeMemberDecoration, key); @@ -7358,7 +7366,7 @@ struct DeclLoweringVisitor : DeclVisitor } if (auto derivativeMemberModifier = fieldDecl->findModifier()) { - lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier); + lowerDerivativeMemberModifier(irFieldKey, fieldDecl, derivativeMemberModifier); } // We allow a field to be marked as a target intrinsic, diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 6076a41ca..470f5f983 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1232,6 +1232,25 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } break; } + + // Hard code implementation of T.Differential.Differential == T.Differential rule. + if (auto builtinReq = substDeclRef.getDecl()->findModifier()) + { + if (builtinReq->kind == BuiltinRequirementKind::DifferentialType) + { + // Is the concrete type a Differential associated type? + if (auto innerDeclRefType = as(thisSubst->witness->sub)) + { + if (auto innerBuiltinReq = innerDeclRefType->declRef.decl->findModifier()) + { + if (innerBuiltinReq->kind == BuiltinRequirementKind::DifferentialType) + { + return innerDeclRefType; + } + } + } + } + } } } } -- cgit v1.2.3