summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp57
1 files changed, 20 insertions, 37 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 7e6fd30dd..73818dbb1 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -1116,12 +1116,12 @@ struct JVPTranscriber
IRInst* diffCallee = nullptr;
- if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>())
{
// If the user has already provided an differentiated implementation, use that.
- diffCallee = derivativeReferenceDecor->getJVPFunc();
+ diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
}
- else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>())
+ else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>())
{
// If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
// to generate the implementation.
@@ -1327,13 +1327,13 @@ struct JVPTranscriber
auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>())
{
- // Look for an IRJVPDerivativeReferenceDecoration on the specialize inst.
+ // Look for an IRForwardDerivativeDecoration on the specialize inst.
// (Normally, this would be on the inner IRFunc, but in this case only the JVP func
// can be specialized, so we put a decoration on the IRSpecialize)
//
- if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
{
- auto jvpFunc = jvpFuncDecoration->getJVPFunc();
+ auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc();
// Make sure this isn't itself a specialize .
SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
@@ -1450,7 +1450,7 @@ struct JVPTranscriber
IRInst* origBase = originalInst->getOperand(0);
auto primalBase = findOrTranscribePrimalInst(builder, origBase);
auto field = originalInst->getOperand(1);
- auto derivativeRefDecor = field->findDecoration<IRJVPDerivativeMemberReferenceDecoration>();
+ auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>();
auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType());
IRInst* primalOperands[] = { primalBase, field };
@@ -1957,8 +1957,8 @@ struct JVPDerivativeContext
IRInst* lookupJVPReference(IRInst* primalFunction)
{
- if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>())
- return jvpDefinition->getJVPFunc();
+ if(auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
+ return jvpDefinition->getForwardDerivativeFunc();
return nullptr;
}
@@ -2010,13 +2010,13 @@ struct JVPDerivativeContext
//
if (lookupJVPReference(baseFunction)) continue;
- if (isMarkedForJVP(baseFunction))
+ if (isMarkedForForwardDifferentiation(baseFunction))
{
if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
{
IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction);
SLANG_ASSERT(diffFunc);
- builder->addJVPDerivativeReferenceDecoration(baseFunction, diffFunc);
+ builder->addForwardDerivativeDecoration(baseFunction, diffFunc);
workQueue->push(diffFunc);
}
else
@@ -2210,15 +2210,15 @@ struct JVPDerivativeContext
}
// Checks decorators to see if the function should
- // be differentiated (kIROp_JVPDerivativeMarkerDecoration)
+ // be differentiated (kIROp_ForwardDifferentiableDecoration)
//
- bool isMarkedForJVP(IRGlobalValueWithCode* callable)
+ bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable)
{
for(auto decoration = callable->getFirstDecoration();
decoration;
decoration = decoration->getNextDecoration())
{
- if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
+ if (decoration->getOp() == kIROp_ForwardDifferentiableDecoration)
{
return true;
}
@@ -2226,24 +2226,7 @@ struct JVPDerivativeContext
return false;
}
- // Removes the JVPDerivativeMarkerDecoration from the provided callable,
- // if it exists.
- //
- void unmarkForJVP(IRGlobalValueWithCode* callable)
- {
- for(auto decoration = callable->getFirstDecoration();
- decoration;
- decoration = decoration->getNextDecoration())
- {
- if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
- {
- decoration->removeAndDeallocate();
- return;
- }
- }
- }
-
- IRStringLit* getJVPFuncName(IRBuilder* builder,
+ IRStringLit* getForwardDerivativeFuncName(IRBuilder* builder,
IRInst* func)
{
auto oldLoc = builder->getInsertLoc();
@@ -2252,11 +2235,11 @@ struct JVPDerivativeContext
IRStringLit* name = nullptr;
if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
{
- name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice());
+ name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice());
}
else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
{
- name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice());
+ name = builder->getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice());
}
builder->setInsertLoc(oldLoc);
@@ -2309,7 +2292,7 @@ struct JVPDerivativeContext
// Set up context and call main process method.
//
-bool processJVPDerivativeMarkers(
+bool processForwardDifferentiableFuncs(
IRModule* module,
DiagnosticSink* sink,
IRJVPDerivativePassOptions const&)
@@ -2335,8 +2318,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
auto next = decor->getNextDecoration();
switch (decor->getOp())
{
- case kIROp_JVPDerivativeReferenceDecoration:
- case kIROp_JVPDerivativeMemberReferenceDecoration:
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_DerivativeMemberDecoration:
decor->removeAndDeallocate();
break;
default: