summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-26 22:21:29 -0400
committerGitHub <noreply@github.com>2022-10-26 19:21:29 -0700
commitf7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch)
tree574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/slang-ir-diff-jvp.cpp
parent939be44ca23476e622dfb24a592383fe2a1da61f (diff)
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp282
1 files changed, 162 insertions, 120 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index b97556ab1..1a86506b3 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -1091,100 +1091,106 @@ struct JVPTranscriber
// in the current transcription context.
//
InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
- {
- if (as<IRFunc>(origCall->getCallee()))
- {
- auto origCallee = origCall->getCallee();
+ {
+
+ IRInst* origCallee = origCall->getCallee();
- // Since concrete functions are globals, the primal callee is the same
- // as the original callee.
+ if (!origCallee)
+ {
+ // Note that this can only happen if the callee is a result
+ // of a higher-order operation. For now, we assume that we cannot
+ // differentiate such calls safely.
+ // TODO(sai): Should probably get checked in the front-end.
//
- auto primalCallee = origCallee;
+ getSink()->diagnose(origCall->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "attempting to differentiate unresolved callee");
+
+ return InstPair(nullptr, nullptr);
+ }
- IRInst* diffCallee = nullptr;
+ // Since concrete functions are globals, the primal callee is the same
+ // as the original callee.
+ //
+ auto primalCallee = origCallee;
- if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>())
- {
- // If the user has already provided an differentiated implementation, use that.
- diffCallee = derivativeReferenceDecor->getJVPFunc();
- }
- else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>())
+ IRInst* diffCallee = nullptr;
+
+ if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ // If the user has already provided an differentiated implementation, use that.
+ diffCallee = derivativeReferenceDecor->getJVPFunc();
+ }
+ else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>())
+ {
+ // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
+ // to generate the implementation.
+ diffCallee = builder->emitJVPDifferentiateInst(
+ differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
+ primalCallee);
+ }
+ else
+ {
+ // The callee is non differentiable, just return primal value with null diff value.
+ IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
+ return InstPair(primalCall, nullptr);
+ }
+
+ List<IRInst*> args;
+ // Go over the parameter list and create pairs for each input (if required)
+ for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
+ {
+ auto origArg = origCall->getArg(ii);
+ auto primalArg = findOrTranscribePrimalInst(builder, origArg);
+ SLANG_ASSERT(primalArg);
+
+ auto primalType = primalArg->getDataType();
+ auto diffArg = findOrTranscribeDiffInst(builder, origArg);
+
+ if (!diffArg)
+ diffArg = getDifferentialZeroOfType(builder, primalType);
+
+ if (auto pairType = tryGetDiffPairType(builder, primalType))
{
- // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
- // to generate the implementation.
- diffCallee = builder->emitJVPDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
- primalCallee);
+ // If a pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffArg);
+ auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
+ args.add(diffPair);
}
else
{
- // The callee is non differentiable, just return primal value with null diff value.
- IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
- return InstPair(primalCall, nullptr);
+ // Add original/primal argument.
+ args.add(primalArg);
}
+ }
+
+ IRType* diffReturnType = nullptr;
+ diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
- List<IRInst*> args;
- // Go over the parameter list and create pairs for each input (if required)
- for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
- {
- auto origArg = origCall->getArg(ii);
- auto primalArg = findOrTranscribePrimalInst(builder, origArg);
- SLANG_ASSERT(primalArg);
-
- auto primalType = primalArg->getDataType();
- auto diffArg = findOrTranscribeDiffInst(builder, origArg);
-
- if (!diffArg)
- diffArg = getDifferentialZeroOfType(builder, primalType);
-
- if (auto pairType = tryGetDiffPairType(builder, primalType))
- {
- // If a pair type can be formed, this must be non-null.
- SLANG_RELEASE_ASSERT(diffArg);
- auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
- args.add(diffPair);
- }
- else
- {
- // Add original/primal argument.
- args.add(primalArg);
- }
- }
-
- IRType* diffReturnType = nullptr;
- diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
- SLANG_ASSERT(diffReturnType);
+ if (!diffReturnType)
+ {
+ SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType);
+ diffReturnType = builder->getVoidType();
+ }
- auto callInst = builder->emitCallInst(
- diffReturnType,
- diffCallee,
- args);
+ auto callInst = builder->emitCallInst(
+ diffReturnType,
+ diffCallee,
+ args);
+ if (diffReturnType->getOp() != kIROp_VoidType)
+ {
IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst);
IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst);
-
return InstPair(primalResultValue, diffResultValue);
}
- else if(as<IRSpecialize>(origCall->getCallee()) ||
- as<IRLookupWitnessMethod>(origCall->getCallee()))
- {
- getSink()->diagnose(origCall->sourceLoc,
- Diagnostics::unimplemented,
- "attempting to differentiate unspecialized callee or an interface method");
- }
else
{
- // Note that this can only happen if the callee is a result
- // of a higher-order operation. For now, we assume that we cannot
- // differentiate such calls safely.
- // TODO(sai): Should probably get checked in the front-end.
- //
- getSink()->diagnose(origCall->sourceLoc,
- Diagnostics::internalCompilerError,
- "attempting to differentiate unresolved callee");
+ // Return the inst itself if the return value is void.
+ // This is fine since these values should never actually be used anywhere.
+ //
+ return InstPair(callInst, callInst);
}
-
- return InstPair(nullptr, nullptr);
}
InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
@@ -1314,17 +1320,35 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
- InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+ InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize)
{
- // This is slightly counter-intuitive, but we don't perform any differentiation
- // logic here. We simple clone the original specialize which points to the original function,
- // or the cloned version in case we're inside a generic scope.
- // The differentiation logic is inserted later when this is used in an IRCall.
- // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Specialize(Fn))
- // rather than have Specialize(JVPDifferentiate(Fn))
- //
- auto diffSpecialize = cloneInst(&cloneEnv, builder, origSpecialize);
- return InstPair(diffSpecialize, diffSpecialize);
+ // In general, we should not see any specialize insts at this stage.
+ // The exceptions are target intrinsics.
+ auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
+ if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>())
+ {
+ // Look for an IRJVPDerivativeReferenceDecoration on the specialize inst.
+ // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
+ // can be specialized, so we put a decoration on the IRSpecialize)
+ //
+ if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ auto jvpFunc = jvpFuncDecoration->getJVPFunc();
+
+ // Make sure this isn't itself a specialize .
+ SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
+
+ return InstPair(jvpFunc, jvpFunc);
+ }
+ }
+ else
+ {
+ getSink()->diagnose(origSpecialize->sourceLoc,
+ Diagnostics::unexpected,
+ "should not be attempting to differentiate anything specialized here.");
+ }
+
+ return InstPair(nullptr, nullptr);
}
InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup)
@@ -1777,10 +1801,7 @@ struct JVPTranscriber
return transcribeConst(builder, origInst);
case kIROp_Specialize:
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::unexpected,
- "should not be attempting to differentiate anything specialized here.");
- return InstPair(nullptr, nullptr);
+ return transcribeSpecialize(builder, as<IRSpecialize>(origInst));
case kIROp_lookup_interface_method:
return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
@@ -1972,7 +1993,10 @@ struct JVPDerivativeContext
if (auto specializeInst = as<IRSpecialize>(baseInst))
{
- baseFunction = as<IRGlobalValueWithCode>(specializeInst->getBase());
+ // Certain specialize insts come with a derivative
+ // reference attached. Skip such instructions.
+ //
+ if (lookupJVPReference(specializeInst)) continue;
}
else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst))
{
@@ -2026,6 +2050,12 @@ struct JVPDerivativeContext
{
builder->setInsertBefore(pairType);
+ if (!as<IRType>(pairType->getValueType()))
+ {
+ // Do not handle non-concrete types.
+ return nullptr;
+ }
+
auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
builder,
pairType->getValueType());
@@ -2054,19 +2084,20 @@ struct JVPDerivativeContext
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
- auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext);
-
- builder->setInsertBefore(makePairInst);
-
- List<IRInst*> operands;
- operands.add(makePairInst->getPrimalValue());
- operands.add(makePairInst->getDifferentialValue());
+ if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext))
+ {
+ builder->setInsertBefore(makePairInst);
+
+ List<IRInst*> operands;
+ operands.add(makePairInst->getPrimalValue());
+ operands.add(makePairInst->getDifferentialValue());
- auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands);
- makePairInst->replaceUsesWith(makeStructInst);
- makePairInst->removeAndDeallocate();
+ auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands);
+ makePairInst->replaceUsesWith(makeStructInst);
+ makePairInst->removeAndDeallocate();
- return makeStructInst;
+ return makeStructInst;
+ }
}
return nullptr;
@@ -2074,30 +2105,30 @@ struct JVPDerivativeContext
IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
{
-
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
{
- lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext);
-
- builder->setInsertBefore(getDiffInst);
-
- auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
- getDiffInst->replaceUsesWith(diffFieldExtract);
- getDiffInst->removeAndDeallocate();
-
- return diffFieldExtract;
+ if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext))
+ {
+ builder->setInsertBefore(getDiffInst);
+
+ auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ getDiffInst->replaceUsesWith(diffFieldExtract);
+ getDiffInst->removeAndDeallocate();
+ return diffFieldExtract;
+ }
}
else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
{
- lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext);
-
- builder->setInsertBefore(getPrimalInst);
+ if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext))
+ {
+ builder->setInsertBefore(getPrimalInst);
- auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
- getPrimalInst->replaceUsesWith(primalFieldExtract);
- getPrimalInst->removeAndDeallocate();
+ auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ getPrimalInst->replaceUsesWith(primalFieldExtract);
+ getPrimalInst->removeAndDeallocate();
- return primalFieldExtract;
+ return primalFieldExtract;
+ }
}
return nullptr;
@@ -2295,9 +2326,9 @@ bool processJVPDerivativeMarkers(
return changed;
}
-void stripAutoDiffDecorations(IRModule* module)
+void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
- for (auto inst : module->getGlobalInsts())
+ for (auto inst : parent->getChildren())
{
for (auto decor = inst->getFirstDecoration(); decor; )
{
@@ -2313,7 +2344,18 @@ void stripAutoDiffDecorations(IRModule* module)
}
decor = next;
}
+
+ if (inst->getFirstChild() != nullptr)
+ {
+ stripAutoDiffDecorationsFromChildren(inst);
+ }
}
}
+void stripAutoDiffDecorations(IRModule* module)
+{
+ stripAutoDiffDecorationsFromChildren(module->getModuleInst());
+}
+
+
}