summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp229
1 files changed, 115 insertions, 114 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 843428c01..b97556ab1 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -115,7 +115,7 @@ struct DifferentiableTypeConformanceContext
IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
{
- if (auto conformance = lookUpConformanceForType(builder, origType))
+ if (auto conformance = lookUpConformanceForType(builder, origType))
{
if (auto witnessTable = as<IRWitnessTable>(conformance))
{
@@ -144,6 +144,14 @@ struct DifferentiableTypeConformanceContext
//
IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
{
+ switch (origType->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ case kIROp_VectorType:
+ return origType;
+ }
return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey);
}
@@ -1083,8 +1091,7 @@ struct JVPTranscriber
// in the current transcription context.
//
InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
- {
-
+ {
if (as<IRFunc>(origCall->getCallee()))
{
auto origCallee = origCall->getCallee();
@@ -1094,12 +1101,28 @@ struct JVPTranscriber
//
auto primalCallee = origCallee;
- // TODO: If inner is not differentiable, treat as non-differentiable call.
- // Build the differential callee
- IRInst* diffCall = builder->emitJVPDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
- primalCallee);
-
+ IRInst* diffCallee = nullptr;
+
+ if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ // If the user has already provided an differentiated implementation, use that.
+ diffCallee = derivativeReferenceDecor->getJVPFunc();
+ }
+ else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>())
+ {
+ // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
+ // to generate the implementation.
+ diffCallee = builder->emitJVPDifferentiateInst(
+ differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
+ primalCallee);
+ }
+ else
+ {
+ // The callee is non differentiable, just return primal value with null diff value.
+ IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
+ return InstPair(primalCall, nullptr);
+ }
+
List<IRInst*> args;
// Go over the parameter list and create pairs for each input (if required)
for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
@@ -1109,18 +1132,16 @@ struct JVPTranscriber
SLANG_ASSERT(primalArg);
auto primalType = primalArg->getDataType();
+ auto diffArg = findOrTranscribeDiffInst(builder, origArg);
+
+ if (!diffArg)
+ diffArg = getDifferentialZeroOfType(builder, primalType);
+
if (auto pairType = tryGetDiffPairType(builder, primalType))
{
- auto diffArg = findOrTranscribeDiffInst(builder, origArg);
-
- if (!diffArg)
- diffArg = getDifferentialZeroOfType(builder, primalType);
-
// If a pair type can be formed, this must be non-null.
SLANG_RELEASE_ASSERT(diffArg);
-
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
-
args.add(diffPair);
}
else
@@ -1130,17 +1151,19 @@ struct JVPTranscriber
}
}
- auto diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
+ IRType* diffReturnType = nullptr;
+ diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
SLANG_ASSERT(diffReturnType);
auto callInst = builder->emitCallInst(
diffReturnType,
- diffCall,
+ diffCallee,
args);
+
+ IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst);
+ IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst);
- return InstPair(
- pairBuilder->emitPrimalFieldAccess(builder, callInst),
- pairBuilder->emitDiffFieldAccess(builder, callInst));
+ return InstPair(primalResultValue, diffResultValue);
}
else if(as<IRSpecialize>(origCall->getCallee()) ||
as<IRLookupWitnessMethod>(origCall->getCallee()))
@@ -1396,89 +1419,45 @@ struct JVPTranscriber
return InstPair(diffBlock, diffBlock);
}
- InstPair transcribeFieldExtract(IRBuilder* builder, IRFieldExtract* origExtract)
+ InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
{
- IRInst* origBase = origExtract->getBase();
+ SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst));
+
+ IRInst* origBase = originalInst->getOperand(0);
auto primalBase = findOrTranscribePrimalInst(builder, origBase);
- auto diffBase = findOrTranscribeDiffInst(builder, origBase);
+ auto field = originalInst->getOperand(1);
+ auto derivativeRefDecor = field->findDecoration<IRJVPDerivativeMemberReferenceDecoration>();
+ auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType());
- auto primalExtractType = (IRType*)lookupPrimalInst(origExtract->getDataType(), origExtract->getDataType());
-
- IRInst* primalExtract = builder->emitFieldExtract(primalExtractType, primalBase, origExtract->getField());
- IRInst* diffExtract = nullptr;
+ IRInst* primalOperands[] = { primalBase, field };
+ IRInst* primalFieldExtract = builder->emitIntrinsicInst(
+ primalType,
+ originalInst->getOp(),
+ 2,
+ primalOperands);
- if (auto diffExtractType = differentiateType(builder, primalExtractType))
+ if (!derivativeRefDecor)
{
- // Check if we have a getter.
- if (auto getterDecoration = origExtract->findDecoration<IRDifferentialGetterDecoration>())
- {
-
- IRInst* getterFunc = getterDecoration->getGetterFunc();
-
- // Must be a method with a single parameter.
- SLANG_ASSERT(as<IRFuncType>(getterFunc->getDataType())->getParamCount() == 1);
-
- // Our getter func accepts a _pointer_ to the target type
- // So we have to create a variable and store our type into memory
- // here. This will eventually get optimized out in later passes.
- //
- auto diffTempVar = builder->emitVar(
- diffBase->getDataType());
-
- builder->emitStore(diffTempVar, diffBase);
-
- List<IRInst*> args;
- args.add(diffTempVar);
-
- // Emit a call to the getter. The getter will return a reference type.
- // We need to load from this to go to a non-ptr 'solid' type.
- //
- auto diffGetterCall = builder->emitCallInst(
- as<IRFuncType>(getterFunc->getDataType())->getResultType(),
- getterFunc,
- args);
-
- diffExtract = builder->emitLoad(diffGetterCall);
- }
+ return InstPair(primalFieldExtract, nullptr);
}
- return InstPair(primalExtract, diffExtract);
- }
-
- InstPair transcribeFieldAddress(IRBuilder* builder, IRFieldAddress* origAddress)
- {
- IRInst* origBase = origAddress->getBase();
- auto primalBase = findOrTranscribePrimalInst(builder, origBase);
- auto diffBase = findOrTranscribeDiffInst(builder, origBase);
-
- auto primalAddressType = (IRType*)lookupPrimalInst(origAddress->getDataType(), origAddress->getDataType());
+ IRInst* diffFieldExtract = nullptr;
- IRInst* primalAddress = builder->emitFieldAddress(primalAddressType, primalBase, origAddress->getField());
- IRInst* diffAddress = nullptr;
-
- if (auto diffAddressType = differentiateType(builder, primalAddressType))
+ if (auto diffType = differentiateType(builder, primalType))
{
- // If we have a getter associated with this field, we want to use that.
- if (auto getterDecoration = origAddress->findDecoration<IRDifferentialGetterDecoration>())
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
{
- auto getterFunc = getterDecoration->getGetterFunc();
-
- // Add the base differential inst as the argument.
- List<IRInst*> args;
- args.add(diffBase);
-
- diffAddress = builder->emitCallInst(
- as<IRFuncType>(getterFunc->getDataType())->getResultType(),
- getterFunc,
- args);
+ IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() };
+ diffFieldExtract = builder->emitIntrinsicInst(
+ diffType,
+ originalInst->getOp(),
+ 2,
+ diffOperands);
}
-
}
-
- return InstPair(primalAddress, diffAddress);
+ return InstPair(primalFieldExtract, diffFieldExtract);
}
-
InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
{
SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr));
@@ -1514,7 +1493,6 @@ struct JVPTranscriber
return InstPair(primalGetElementPtr, diffGetElementPtr);
}
-
InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
{
// The loop comes with three blocks.. we just need to transcribe each one
@@ -1640,9 +1618,13 @@ struct JVPTranscriber
as<IRFuncType>(origFunc->getFullType()));
diffFunc->setFullType(diffFuncType);
- // TODO(sai): Replace naming scheme
- // if (auto jvpName = this->getJVPFuncName(builder, primalFn))
- // builder->addNameHintDecoration(diffFunc, jvpName);
+ if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
+ {
+ auto originalName = nameHint->getName();
+ StringBuilder newNameSb;
+ newNameSb << "s_jvp_" << originalName;
+ builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ }
// Transcribe children from origFunc into diffFunc
builder->setInsertInto(diffFunc);
@@ -1719,9 +1701,18 @@ struct JVPTranscriber
{
mapPrimalInst(origInst, pair.primal);
mapDifferentialInst(origInst, pair.differential);
+ if (pair.differential)
+ {
+ // Generate name hint for the inst.
+ if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder sb;
+ sb << "s_diff_" << primalNameHint->getName();
+ builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
+ }
+ }
return pair.differential;
}
-
instsInProgress.Remove(origInst);
getSink()->diagnose(origInst->sourceLoc,
@@ -1789,16 +1780,14 @@ struct JVPTranscriber
getSink()->diagnose(origInst->sourceLoc,
Diagnostics::unexpected,
"should not be attempting to differentiate anything specialized here.");
+ return InstPair(nullptr, nullptr);
case kIROp_lookup_interface_method:
return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
case kIROp_FieldExtract:
- return transcribeFieldExtract(builder, as<IRFieldExtract>(origInst));
-
case kIROp_FieldAddress:
- return transcribeFieldAddress(builder, as<IRFieldAddress>(origInst));
-
+ return transcribeFieldExtract(builder, origInst);
case kIROp_getElement:
case kIROp_getElementPtr:
return transcribeGetElement(builder, origInst);
@@ -1942,11 +1931,6 @@ struct JVPDerivativeContext
// Temporary fix: Move generated types, if any, to before their use locations.
(&pairBuilderStorage)->relocateNewTypes(builder);
- // Remove all kIROp_DifferentiableTypeDictionary instructions and
- // kIROp_DifferentialGetterDecoration decorations
- //
- modified |= stripDiffTypeInformation(builder, module->getModuleInst());
-
return modified;
}
@@ -1954,7 +1938,6 @@ struct JVPDerivativeContext
{
if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>())
return jvpDefinition->getJVPFunc();
-
return nullptr;
}
@@ -2166,7 +2149,7 @@ struct JVPDerivativeContext
return modified;
}
- bool stripDiffTypeInformation(IRBuilder* builder, IRInst* parent)
+ bool stripDiffTypeInformation(IRInst* parent)
{
bool modified = false;
@@ -2175,22 +2158,18 @@ struct JVPDerivativeContext
{
auto nextChild = child->getNextInst();
- if (child->getOp() == kIROp_DifferentiableTypeDictionary)
+ switch (child->getOp())
{
+ case kIROp_DifferentiableTypeDictionary:
child->removeAndDeallocate();
child = nextChild;
modified = true;
continue;
}
- if (auto getterDecoration = child->findDecoration<IRDifferentialGetterDecoration>())
- {
- getterDecoration->removeAndDeallocate();
- }
-
if (child->getFirstChild() != nullptr)
{
- modified |= stripDiffTypeInformation(builder, child);
+ modified |= stripDiffTypeInformation(child);
}
child = nextChild;
@@ -2311,8 +2290,30 @@ bool processJVPDerivativeMarkers(
eliminateDeadCode(module, options);
JVPDerivativeContext context(module, sink);
+ bool changed = context.processModule();
+ changed |= context.stripDiffTypeInformation(module->getModuleInst());
+ return changed;
+}
- return context.processModule();
+void stripAutoDiffDecorations(IRModule* module)
+{
+ for (auto inst : module->getGlobalInsts())
+ {
+ for (auto decor = inst->getFirstDecoration(); decor; )
+ {
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
+ {
+ case kIROp_JVPDerivativeReferenceDecoration:
+ case kIROp_JVPDerivativeMemberReferenceDecoration:
+ decor->removeAndDeallocate();
+ break;
+ default:
+ break;
+ }
+ decor = next;
+ }
+ }
}
}