diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-26 22:21:29 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-26 19:21:29 -0700 |
| commit | f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch) | |
| tree | 574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/slang-ir-diff-jvp.cpp | |
| parent | 939be44ca23476e622dfb24a592383fe2a1da61f (diff) | |
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 282 |
1 files changed, 162 insertions, 120 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index b97556ab1..1a86506b3 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -1091,100 +1091,106 @@ struct JVPTranscriber // in the current transcription context. // InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) - { - if (as<IRFunc>(origCall->getCallee())) - { - auto origCallee = origCall->getCallee(); + { + + IRInst* origCallee = origCall->getCallee(); - // Since concrete functions are globals, the primal callee is the same - // as the original callee. + if (!origCallee) + { + // Note that this can only happen if the callee is a result + // of a higher-order operation. For now, we assume that we cannot + // differentiate such calls safely. + // TODO(sai): Should probably get checked in the front-end. // - auto primalCallee = origCallee; + getSink()->diagnose(origCall->sourceLoc, + Diagnostics::internalCompilerError, + "attempting to differentiate unresolved callee"); + + return InstPair(nullptr, nullptr); + } - IRInst* diffCallee = nullptr; + // Since concrete functions are globals, the primal callee is the same + // as the original callee. + // + auto primalCallee = origCallee; - if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>()) - { - // If the user has already provided an differentiated implementation, use that. - diffCallee = derivativeReferenceDecor->getJVPFunc(); - } - else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>()) + IRInst* diffCallee = nullptr; + + if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + // If the user has already provided an differentiated implementation, use that. + diffCallee = derivativeReferenceDecor->getJVPFunc(); + } + else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>()) + { + // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass + // to generate the implementation. + diffCallee = builder->emitJVPDifferentiateInst( + differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), + primalCallee); + } + else + { + // The callee is non differentiable, just return primal value with null diff value. + IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); + return InstPair(primalCall, nullptr); + } + + List<IRInst*> args; + // Go over the parameter list and create pairs for each input (if required) + for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) + { + auto origArg = origCall->getArg(ii); + auto primalArg = findOrTranscribePrimalInst(builder, origArg); + SLANG_ASSERT(primalArg); + + auto primalType = primalArg->getDataType(); + auto diffArg = findOrTranscribeDiffInst(builder, origArg); + + if (!diffArg) + diffArg = getDifferentialZeroOfType(builder, primalType); + + if (auto pairType = tryGetDiffPairType(builder, primalType)) { - // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass - // to generate the implementation. - diffCallee = builder->emitJVPDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), - primalCallee); + // If a pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffArg); + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); + args.add(diffPair); } else { - // The callee is non differentiable, just return primal value with null diff value. - IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); - return InstPair(primalCall, nullptr); + // Add original/primal argument. + args.add(primalArg); } + } + + IRType* diffReturnType = nullptr; + diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); - List<IRInst*> args; - // Go over the parameter list and create pairs for each input (if required) - for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) - { - auto origArg = origCall->getArg(ii); - auto primalArg = findOrTranscribePrimalInst(builder, origArg); - SLANG_ASSERT(primalArg); - - auto primalType = primalArg->getDataType(); - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - - if (!diffArg) - diffArg = getDifferentialZeroOfType(builder, primalType); - - if (auto pairType = tryGetDiffPairType(builder, primalType)) - { - // If a pair type can be formed, this must be non-null. - SLANG_RELEASE_ASSERT(diffArg); - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); - args.add(diffPair); - } - else - { - // Add original/primal argument. - args.add(primalArg); - } - } - - IRType* diffReturnType = nullptr; - diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); - SLANG_ASSERT(diffReturnType); + if (!diffReturnType) + { + SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType); + diffReturnType = builder->getVoidType(); + } - auto callInst = builder->emitCallInst( - diffReturnType, - diffCallee, - args); + auto callInst = builder->emitCallInst( + diffReturnType, + diffCallee, + args); + if (diffReturnType->getOp() != kIROp_VoidType) + { IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst); IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst); - return InstPair(primalResultValue, diffResultValue); } - else if(as<IRSpecialize>(origCall->getCallee()) || - as<IRLookupWitnessMethod>(origCall->getCallee())) - { - getSink()->diagnose(origCall->sourceLoc, - Diagnostics::unimplemented, - "attempting to differentiate unspecialized callee or an interface method"); - } else { - // Note that this can only happen if the callee is a result - // of a higher-order operation. For now, we assume that we cannot - // differentiate such calls safely. - // TODO(sai): Should probably get checked in the front-end. - // - getSink()->diagnose(origCall->sourceLoc, - Diagnostics::internalCompilerError, - "attempting to differentiate unresolved callee"); + // Return the inst itself if the return value is void. + // This is fine since these values should never actually be used anywhere. + // + return InstPair(callInst, callInst); } - - return InstPair(nullptr, nullptr); } InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) @@ -1314,17 +1320,35 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } - InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) + InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize) { - // This is slightly counter-intuitive, but we don't perform any differentiation - // logic here. We simple clone the original specialize which points to the original function, - // or the cloned version in case we're inside a generic scope. - // The differentiation logic is inserted later when this is used in an IRCall. - // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Specialize(Fn)) - // rather than have Specialize(JVPDifferentiate(Fn)) - // - auto diffSpecialize = cloneInst(&cloneEnv, builder, origSpecialize); - return InstPair(diffSpecialize, diffSpecialize); + // In general, we should not see any specialize insts at this stage. + // The exceptions are target intrinsics. + auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); + if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>()) + { + // Look for an IRJVPDerivativeReferenceDecoration on the specialize inst. + // (Normally, this would be on the inner IRFunc, but in this case only the JVP func + // can be specialized, so we put a decoration on the IRSpecialize) + // + if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + auto jvpFunc = jvpFuncDecoration->getJVPFunc(); + + // Make sure this isn't itself a specialize . + SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); + + return InstPair(jvpFunc, jvpFunc); + } + } + else + { + getSink()->diagnose(origSpecialize->sourceLoc, + Diagnostics::unexpected, + "should not be attempting to differentiate anything specialized here."); + } + + return InstPair(nullptr, nullptr); } InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup) @@ -1777,10 +1801,7 @@ struct JVPTranscriber return transcribeConst(builder, origInst); case kIROp_Specialize: - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::unexpected, - "should not be attempting to differentiate anything specialized here."); - return InstPair(nullptr, nullptr); + return transcribeSpecialize(builder, as<IRSpecialize>(origInst)); case kIROp_lookup_interface_method: return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); @@ -1972,7 +1993,10 @@ struct JVPDerivativeContext if (auto specializeInst = as<IRSpecialize>(baseInst)) { - baseFunction = as<IRGlobalValueWithCode>(specializeInst->getBase()); + // Certain specialize insts come with a derivative + // reference attached. Skip such instructions. + // + if (lookupJVPReference(specializeInst)) continue; } else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst)) { @@ -2026,6 +2050,12 @@ struct JVPDerivativeContext { builder->setInsertBefore(pairType); + if (!as<IRType>(pairType->getValueType())) + { + // Do not handle non-concrete types. + return nullptr; + } + auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( builder, pairType->getValueType()); @@ -2054,19 +2084,20 @@ struct JVPDerivativeContext if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) { - auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext); - - builder->setInsertBefore(makePairInst); - - List<IRInst*> operands; - operands.add(makePairInst->getPrimalValue()); - operands.add(makePairInst->getDifferentialValue()); + if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext)) + { + builder->setInsertBefore(makePairInst); + + List<IRInst*> operands; + operands.add(makePairInst->getPrimalValue()); + operands.add(makePairInst->getDifferentialValue()); - auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands); - makePairInst->replaceUsesWith(makeStructInst); - makePairInst->removeAndDeallocate(); + auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands); + makePairInst->replaceUsesWith(makeStructInst); + makePairInst->removeAndDeallocate(); - return makeStructInst; + return makeStructInst; + } } return nullptr; @@ -2074,30 +2105,30 @@ struct JVPDerivativeContext IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext) { - if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) { - lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext); - - builder->setInsertBefore(getDiffInst); - - auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); - getDiffInst->replaceUsesWith(diffFieldExtract); - getDiffInst->removeAndDeallocate(); - - return diffFieldExtract; + if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext)) + { + builder->setInsertBefore(getDiffInst); + + auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); + getDiffInst->replaceUsesWith(diffFieldExtract); + getDiffInst->removeAndDeallocate(); + return diffFieldExtract; + } } else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) { - lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext); - - builder->setInsertBefore(getPrimalInst); + if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext)) + { + builder->setInsertBefore(getPrimalInst); - auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); - getPrimalInst->replaceUsesWith(primalFieldExtract); - getPrimalInst->removeAndDeallocate(); + auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + getPrimalInst->replaceUsesWith(primalFieldExtract); + getPrimalInst->removeAndDeallocate(); - return primalFieldExtract; + return primalFieldExtract; + } } return nullptr; @@ -2295,9 +2326,9 @@ bool processJVPDerivativeMarkers( return changed; } -void stripAutoDiffDecorations(IRModule* module) +void stripAutoDiffDecorationsFromChildren(IRInst* parent) { - for (auto inst : module->getGlobalInsts()) + for (auto inst : parent->getChildren()) { for (auto decor = inst->getFirstDecoration(); decor; ) { @@ -2313,7 +2344,18 @@ void stripAutoDiffDecorations(IRModule* module) } decor = next; } + + if (inst->getFirstChild() != nullptr) + { + stripAutoDiffDecorationsFromChildren(inst); + } } } +void stripAutoDiffDecorations(IRModule* module) +{ + stripAutoDiffDecorationsFromChildren(module->getModuleInst()); +} + + } |
