diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-27 14:38:44 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-27 14:38:44 -0700 |
| commit | f9b1c565abbfc93bf2c8d4742f3db13e07db5e62 (patch) | |
| tree | 4a4add20fd1db56df5984c20264389d4f23fc8f0 /source/slang/slang-ir-diff-jvp.cpp | |
| parent | e6dc9a9eed58bdfd9c6f4016864acfe60381f927 (diff) | |
More renaming in jvp pass. (#2475)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 57 |
1 files changed, 20 insertions, 37 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 7e6fd30dd..73818dbb1 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -1116,12 +1116,12 @@ struct JVPTranscriber IRInst* diffCallee = nullptr; - if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>()) { // If the user has already provided an differentiated implementation, use that. - diffCallee = derivativeReferenceDecor->getJVPFunc(); + diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); } - else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>()) + else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>()) { // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass // to generate the implementation. @@ -1327,13 +1327,13 @@ struct JVPTranscriber auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>()) { - // Look for an IRJVPDerivativeReferenceDecoration on the specialize inst. + // Look for an IRForwardDerivativeDecoration on the specialize inst. // (Normally, this would be on the inner IRFunc, but in this case only the JVP func // can be specialized, so we put a decoration on the IRSpecialize) // - if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) { - auto jvpFunc = jvpFuncDecoration->getJVPFunc(); + auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc(); // Make sure this isn't itself a specialize . SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); @@ -1450,7 +1450,7 @@ struct JVPTranscriber IRInst* origBase = originalInst->getOperand(0); auto primalBase = findOrTranscribePrimalInst(builder, origBase); auto field = originalInst->getOperand(1); - auto derivativeRefDecor = field->findDecoration<IRJVPDerivativeMemberReferenceDecoration>(); + auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>(); auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); IRInst* primalOperands[] = { primalBase, field }; @@ -1957,8 +1957,8 @@ struct JVPDerivativeContext IRInst* lookupJVPReference(IRInst* primalFunction) { - if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>()) - return jvpDefinition->getJVPFunc(); + if(auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) + return jvpDefinition->getForwardDerivativeFunc(); return nullptr; } @@ -2010,13 +2010,13 @@ struct JVPDerivativeContext // if (lookupJVPReference(baseFunction)) continue; - if (isMarkedForJVP(baseFunction)) + if (isMarkedForForwardDifferentiation(baseFunction)) { if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) { IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction); SLANG_ASSERT(diffFunc); - builder->addJVPDerivativeReferenceDecoration(baseFunction, diffFunc); + builder->addForwardDerivativeDecoration(baseFunction, diffFunc); workQueue->push(diffFunc); } else @@ -2210,15 +2210,15 @@ struct JVPDerivativeContext } // Checks decorators to see if the function should - // be differentiated (kIROp_JVPDerivativeMarkerDecoration) + // be differentiated (kIROp_ForwardDifferentiableDecoration) // - bool isMarkedForJVP(IRGlobalValueWithCode* callable) + bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable) { for(auto decoration = callable->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration()) { - if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration) + if (decoration->getOp() == kIROp_ForwardDifferentiableDecoration) { return true; } @@ -2226,24 +2226,7 @@ struct JVPDerivativeContext return false; } - // Removes the JVPDerivativeMarkerDecoration from the provided callable, - // if it exists. - // - void unmarkForJVP(IRGlobalValueWithCode* callable) - { - for(auto decoration = callable->getFirstDecoration(); - decoration; - decoration = decoration->getNextDecoration()) - { - if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration) - { - decoration->removeAndDeallocate(); - return; - } - } - } - - IRStringLit* getJVPFuncName(IRBuilder* builder, + IRStringLit* getForwardDerivativeFuncName(IRBuilder* builder, IRInst* func) { auto oldLoc = builder->getInsertLoc(); @@ -2252,11 +2235,11 @@ struct JVPDerivativeContext IRStringLit* name = nullptr; if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) { - name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice()); + name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice()); } else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) { - name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice()); + name = builder->getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice()); } builder->setInsertLoc(oldLoc); @@ -2309,7 +2292,7 @@ struct JVPDerivativeContext // Set up context and call main process method. // -bool processJVPDerivativeMarkers( +bool processForwardDifferentiableFuncs( IRModule* module, DiagnosticSink* sink, IRJVPDerivativePassOptions const&) @@ -2335,8 +2318,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) auto next = decor->getNextDecoration(); switch (decor->getOp()) { - case kIROp_JVPDerivativeReferenceDecoration: - case kIROp_JVPDerivativeMemberReferenceDecoration: + case kIROp_ForwardDerivativeDecoration: + case kIROp_DerivativeMemberDecoration: decor->removeAndDeallocate(); break; default: |
