summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-27 14:38:44 -0700
committerGitHub <noreply@github.com>2022-10-27 14:38:44 -0700
commitf9b1c565abbfc93bf2c8d4742f3db13e07db5e62 (patch)
tree4a4add20fd1db56df5984c20264389d4f23fc8f0 /source
parente6dc9a9eed58bdfd9c6f4016864acfe60381f927 (diff)
More renaming in jvp pass. (#2475)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-ir-diff-call.cpp8
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp57
-rw-r--r--source/slang/slang-ir-diff-jvp.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h28
-rw-r--r--source/slang/slang-ir-link.cpp2
-rw-r--r--source/slang/slang-ir-specialize.cpp8
-rw-r--r--source/slang/slang-lower-to-ir.cpp6
9 files changed, 51 insertions, 68 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 896171f32..fcdee78ea 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -377,7 +377,7 @@ Result linkAndOptimizeIR(
// Process higher-order calles to auto-diff passes.
// 1. Generate JVP code wherever necessary. (Linearization or "forward-mode" pass)
- processJVPDerivativeMarkers(irModule, sink);
+ processForwardDifferentiableFuncs(irModule, sink);
// 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass)
// processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index 34e7e3de0..a574d6b7e 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -62,9 +62,9 @@ struct DerivativeCallProcessContext
//
if (auto origSpecialize = as<IRSpecialize>(origCallable))
{
- if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
{
- jvpCallable = jvpSpecRefDecorator->getJVPFunc();
+ jvpCallable = jvpSpecRefDecorator->getForwardDerivativeFunc();
}
}
@@ -73,9 +73,9 @@ struct DerivativeCallProcessContext
// Check for the 'JVPDerivativeReference' decorator on the
// base function.
//
- if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto jvpRefDecorator = origCallable->findDecoration<IRForwardDerivativeDecoration>())
{
- jvpCallable = jvpRefDecorator->getJVPFunc();
+ jvpCallable = jvpRefDecorator->getForwardDerivativeFunc();
}
SLANG_ASSERT(jvpCallable);
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 7e6fd30dd..73818dbb1 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -1116,12 +1116,12 @@ struct JVPTranscriber
IRInst* diffCallee = nullptr;
- if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>())
{
// If the user has already provided an differentiated implementation, use that.
- diffCallee = derivativeReferenceDecor->getJVPFunc();
+ diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
}
- else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>())
+ else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>())
{
// If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
// to generate the implementation.
@@ -1327,13 +1327,13 @@ struct JVPTranscriber
auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>())
{
- // Look for an IRJVPDerivativeReferenceDecoration on the specialize inst.
+ // Look for an IRForwardDerivativeDecoration on the specialize inst.
// (Normally, this would be on the inner IRFunc, but in this case only the JVP func
// can be specialized, so we put a decoration on the IRSpecialize)
//
- if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
{
- auto jvpFunc = jvpFuncDecoration->getJVPFunc();
+ auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc();
// Make sure this isn't itself a specialize .
SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
@@ -1450,7 +1450,7 @@ struct JVPTranscriber
IRInst* origBase = originalInst->getOperand(0);
auto primalBase = findOrTranscribePrimalInst(builder, origBase);
auto field = originalInst->getOperand(1);
- auto derivativeRefDecor = field->findDecoration<IRJVPDerivativeMemberReferenceDecoration>();
+ auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>();
auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType());
IRInst* primalOperands[] = { primalBase, field };
@@ -1957,8 +1957,8 @@ struct JVPDerivativeContext
IRInst* lookupJVPReference(IRInst* primalFunction)
{
- if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>())
- return jvpDefinition->getJVPFunc();
+ if(auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
+ return jvpDefinition->getForwardDerivativeFunc();
return nullptr;
}
@@ -2010,13 +2010,13 @@ struct JVPDerivativeContext
//
if (lookupJVPReference(baseFunction)) continue;
- if (isMarkedForJVP(baseFunction))
+ if (isMarkedForForwardDifferentiation(baseFunction))
{
if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
{
IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction);
SLANG_ASSERT(diffFunc);
- builder->addJVPDerivativeReferenceDecoration(baseFunction, diffFunc);
+ builder->addForwardDerivativeDecoration(baseFunction, diffFunc);
workQueue->push(diffFunc);
}
else
@@ -2210,15 +2210,15 @@ struct JVPDerivativeContext
}
// Checks decorators to see if the function should
- // be differentiated (kIROp_JVPDerivativeMarkerDecoration)
+ // be differentiated (kIROp_ForwardDifferentiableDecoration)
//
- bool isMarkedForJVP(IRGlobalValueWithCode* callable)
+ bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable)
{
for(auto decoration = callable->getFirstDecoration();
decoration;
decoration = decoration->getNextDecoration())
{
- if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
+ if (decoration->getOp() == kIROp_ForwardDifferentiableDecoration)
{
return true;
}
@@ -2226,24 +2226,7 @@ struct JVPDerivativeContext
return false;
}
- // Removes the JVPDerivativeMarkerDecoration from the provided callable,
- // if it exists.
- //
- void unmarkForJVP(IRGlobalValueWithCode* callable)
- {
- for(auto decoration = callable->getFirstDecoration();
- decoration;
- decoration = decoration->getNextDecoration())
- {
- if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
- {
- decoration->removeAndDeallocate();
- return;
- }
- }
- }
-
- IRStringLit* getJVPFuncName(IRBuilder* builder,
+ IRStringLit* getForwardDerivativeFuncName(IRBuilder* builder,
IRInst* func)
{
auto oldLoc = builder->getInsertLoc();
@@ -2252,11 +2235,11 @@ struct JVPDerivativeContext
IRStringLit* name = nullptr;
if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
{
- name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice());
+ name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice());
}
else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
{
- name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice());
+ name = builder->getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice());
}
builder->setInsertLoc(oldLoc);
@@ -2309,7 +2292,7 @@ struct JVPDerivativeContext
// Set up context and call main process method.
//
-bool processJVPDerivativeMarkers(
+bool processForwardDifferentiableFuncs(
IRModule* module,
DiagnosticSink* sink,
IRJVPDerivativePassOptions const&)
@@ -2335,8 +2318,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
auto next = decor->getNextDecoration();
switch (decor->getOp())
{
- case kIROp_JVPDerivativeReferenceDecoration:
- case kIROp_JVPDerivativeMemberReferenceDecoration:
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_DerivativeMemberDecoration:
decor->removeAndDeallocate();
break;
default:
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
index 8ab4e0e8f..01ac15d6c 100644
--- a/source/slang/slang-ir-diff-jvp.h
+++ b/source/slang/slang-ir-diff-jvp.h
@@ -13,7 +13,7 @@ namespace Slang
// Nothing for now..
};
- bool processJVPDerivativeMarkers(
+ bool processForwardDifferentiableFuncs(
IRModule* module,
DiagnosticSink* sink,
IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions());
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index ccde80476..1d1db14f9 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -700,15 +700,15 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(SPIRVOpDecoration, spirvOpDecoration, 1, 0)
/// Decorated function is marked for the forward-mode differentiation pass.
- INST(JVPDerivativeMarkerDecoration, differentiateJvp, 0, 0)
+ INST(ForwardDifferentiableDecoration, forwardDifferentiable, 0, 0)
/// Used by the auto-diff pass to hold a reference to the
/// generated derivative function.
- INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0)
+ INST(ForwardDerivativeDecoration, jvpFnReference, 1, 0)
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
- INST(JVPDerivativeMemberReferenceDecoration, derivativeMemberDecoration, 1, 0)
+ INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
/// Marks a class type as a COM interface implementation, which enables
/// the witness table to be easily picked up by emit.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 95202d9d0..deb81134b 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -546,33 +546,33 @@ struct IRSequentialIDDecoration : IRDecoration
IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); }
};
-struct IRJVPDerivativeMarkerDecoration : IRDecoration
+struct IRForwardDifferentiableDecoration : IRDecoration
{
enum
{
- kOp = kIROp_JVPDerivativeMarkerDecoration
+ kOp = kIROp_ForwardDifferentiableDecoration
};
- IR_LEAF_ISA(JVPDerivativeMarkerDecoration)
+ IR_LEAF_ISA(ForwardDifferentiableDecoration)
};
-struct IRJVPDerivativeReferenceDecoration : IRDecoration
+struct IRForwardDerivativeDecoration : IRDecoration
{
enum
{
- kOp = kIROp_JVPDerivativeReferenceDecoration
+ kOp = kIROp_ForwardDerivativeDecoration
};
- IR_LEAF_ISA(JVPDerivativeReferenceDecoration)
+ IR_LEAF_ISA(ForwardDerivativeDecoration)
- IRInst* getJVPFunc() { return getOperand(0); }
+ IRInst* getForwardDerivativeFunc() { return getOperand(0); }
};
-struct IRJVPDerivativeMemberReferenceDecoration : IRDecoration
+struct IRDerivativeMemberDecoration : IRDecoration
{
enum
{
- kOp = kIROp_JVPDerivativeMemberReferenceDecoration
+ kOp = kIROp_DerivativeMemberDecoration
};
- IR_LEAF_ISA(JVPDerivativeMemberReferenceDecoration)
+ IR_LEAF_ISA(DerivativeMemberDecoration)
IRInst* getDerivativeMemberStructKey() { return getOperand(0); }
};
@@ -3206,14 +3206,14 @@ public:
addDecoration(value, kIROp_ForceInlineDecoration);
}
- void addJVPDerivativeMarkerDecoration(IRInst* value)
+ void addForwardDifferentiableDecoration(IRInst* value)
{
- addDecoration(value, kIROp_JVPDerivativeMarkerDecoration);
+ addDecoration(value, kIROp_ForwardDifferentiableDecoration);
}
- void addJVPDerivativeReferenceDecoration(IRInst* value, IRInst* jvpFn)
+ void addForwardDerivativeDecoration(IRInst* value, IRInst* jvpFn)
{
- addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn);
+ addDecoration(value, kIROp_ForwardDerivativeDecoration, jvpFn);
}
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 56688abae..eb899b69c 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -448,7 +448,7 @@ static void cloneExtraDecorations(
case kIROp_LayoutDecoration:
case kIROp_PublicDecoration:
case kIROp_SequentialIDDecoration:
- case kIROp_JVPDerivativeReferenceDecoration:
+ case kIROp_ForwardDerivativeDecoration:
if(!clonedInst->findDecorationImpl(decoration->getOp()))
{
cloneInst(context, builder, decoration);
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 53ea99a0c..406e5157c 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -388,15 +388,15 @@ struct SpecializationContext
auto genericReturnVal = findInnerMostGenericReturnVal(genericVal);
if (genericReturnVal->findDecoration<IRTargetIntrinsicDecoration>())
{
- if (auto customDiffRef = genericReturnVal->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto customDiffRef = genericReturnVal->findDecoration<IRForwardDerivativeDecoration>())
{
// If we already have a diff func on this specialize, skip.
- if (auto specDiffRef = specInst->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto specDiffRef = specInst->findDecoration<IRForwardDerivativeDecoration>())
{
return false;
}
- auto specDiffFunc = as<IRSpecialize>(customDiffRef->getJVPFunc());
+ auto specDiffFunc = as<IRSpecialize>(customDiffRef->getForwardDerivativeFunc());
// If the base is specialized, the JVP version must be also be a specialized
// generic.
@@ -436,7 +436,7 @@ struct SpecializationContext
addToWorkList(newDiffFuncType);
addToWorkList(newDiffFunc);
- builder.addJVPDerivativeReferenceDecoration(specInst, newDiffFunc);
+ builder.addForwardDerivativeDecoration(specInst, newDiffFunc);
return true;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index acb7869e0..ae0590105 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7087,7 +7087,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val;
SLANG_RELEASE_ASSERT(as<IRStructKey>(key));
auto builder = getBuilder();
- builder->addDecoration(inst, kIROp_JVPDerivativeMemberReferenceDecoration, key);
+ builder->addDecoration(inst, kIROp_DerivativeMemberDecoration, key);
}
LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl)
@@ -7807,7 +7807,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (decl->findModifier<ForwardDifferentiableAttribute>())
{
- getBuilder()->addJVPDerivativeMarkerDecoration(irFunc);
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
}
// Always force inline diff setter accessor to prevent downstream compiler from complaining
@@ -8222,7 +8222,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
IRInst* jvpFunc = loweredVal.val;
- getBuilder()->addDecoration(irFunc, kIROp_JVPDerivativeReferenceDecoration, jvpFunc);
+ getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, jvpFunc);
// Reset cursor.
subContext->irBuilder->setInsertInto(irFunc);