summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp824
1 files changed, 778 insertions, 46 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 4c7a132d0..7f5979a87 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -2017,6 +2017,664 @@ struct JVPTranscriber
}
};
+
+struct BackwardDiffTranscriber
+{
+
+ // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
+ // their differential values.
+ Dictionary<IRInst*, IRInst*> orginalToTranscribed;
+
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet<IRInst*> instsInProgress;
+
+ // Cloning environment to hold mapping from old to new copies for the primal
+ // instructions.
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ // Type conformance information.
+ AutoDiffSharedContext* autoDiffSharedContext;
+
+ // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ List<InstPair> followUpFunctionsToTranscribe;
+
+ // Map that stores the upper gradient given an IRInst*
+ Dictionary<IRInst*, List<IRInst*>> upperGradients;
+ Dictionary<IRInst*, IRInst*> primalToDiffPair;
+
+ SharedIRBuilder* sharedBuilder;
+ // Witness table that `DifferentialBottom:IDifferential`.
+ IRWitnessTable* differentialBottomWitness = nullptr;
+ Dictionary<InstPair, IRInst*> differentialPairTypes;
+
+ BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : autoDiffSharedContext(shared)
+ , sink(inSink)
+ , differentiableTypeConformanceContext(shared)
+ , sharedBuilder(inSharedBuilder)
+ {}
+
+ DiagnosticSink* getSink()
+ {
+ SLANG_ASSERT(sink);
+ return sink;
+ }
+
+ IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ {
+ List<IRType*> newParameterTypes;
+ IRType* diffReturnType;
+
+ for (UIndex i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto origType = funcType->getParamType(i);
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ newParameterTypes.add(inoutDiffPairType);
+ }
+ else
+ newParameterTypes.add(origType);
+ }
+
+ newParameterTypes.add(funcType->getResultType());
+
+ diffReturnType = builder->getVoidType();
+
+ return builder->getFuncType(newParameterTypes, diffReturnType);
+ }
+
+ IRWitnessTable* getDifferentialBottomWitness()
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
+ auto result =
+ as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
+ builder.getDifferentialBottomType()));
+ SLANG_ASSERT(result);
+ return result;
+ }
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(inDiffPairType->parent);
+ auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
+ SLANG_ASSERT(diffPairType);
+ auto result =
+ as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
+ builder.getDifferentialBottomType()));
+ if (result)
+ return result;
+
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+ auto diffType = differentiateType(&builder, diffPairType->getValueType());
+ auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness());
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ return table;
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ auto witness = as<IRWitnessTable>(
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+ if (!witness)
+ witness = getDifferentialBottomWitness();
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* differentiateType(IRBuilder* builder, IRType* origType)
+ {
+ IRInst* diffType = nullptr;
+ if (!orginalToTranscribed.TryGetValue(origType, diffType))
+ {
+ diffType = _differentiateTypeImpl(builder, origType);
+ orginalToTranscribed[origType] = diffType;
+ }
+ return (IRType*)diffType;
+ }
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = origType;
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as<IRTypeType>(primalType->getDataType()))
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as<IRWitnessTableType>(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as<IRArrayType>(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as<IRDifferentialPairType>(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as<IRFuncType>(primalType));
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as<IRTupleType>(primalType);
+ List<IRType*> diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
+ }
+ }
+
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
+ {
+ // If this is a PtrType (out, inout, etc..), then create diff pair from
+ // value type and re-apply the appropropriate PtrType wrapper.
+ //
+ if (auto origPtrType = as<IRPtrTypeBase>(primalType))
+ {
+ if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
+ else
+ return nullptr;
+ }
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)getOrCreateDiffPairType(primalType);
+ return nullptr;
+ }
+
+ InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = origParam->getDataType();
+ // Do not differentiate generic type (and witness table) parameters
+ if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
+ {
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffPairParam);
+
+ if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
+ // its primal and diff parts right now because they would represent a reference
+ // to a pair field, which doesn't make sense since pair types are considered mutable.
+ // We encode the result as if the param is non-differentiable, and handle it
+ // with special care at load/store.
+ return InstPair(diffPairParam, nullptr);
+ }
+
+
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ // Returns "dp<var-name>" to use as a name hint for parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String makeDiffPairName(IRInst* origVar)
+ {
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
+ {
+ return ("dp" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+ }
+
+
+ // In differential computation, the 'default' differential value is always zero.
+ // This is a consequence of differential computing being inherently linear. As a
+ // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
+ //
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+ {
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ switch (diffType->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ }
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
+ SLANG_ASSERT(zeroMethod);
+
+ auto emptyArgList = List<IRInst*>();
+ return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ }
+ else
+ {
+ if (isScalarIntegerType(primalType))
+ {
+ return builder->getIntValue(primalType, 0);
+ }
+
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
+ }
+ }
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ {
+ IRBuilder subBuilder(builder->getSharedBuilder());
+ subBuilder.setInsertLoc(builder->getInsertLoc());
+
+ IRBlock* diffBlock = subBuilder.emitBlock();
+
+ subBuilder.setInsertInto(diffBlock);
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->copyParam(&subBuilder, param);
+
+ // The extra param for input gradient
+ auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType());
+
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->copyInst(&subBuilder, child);
+
+ auto lastInst = diffBlock->getLastOrdinaryInst();
+ List<IRInst*> grads = { gradParam };
+ upperGradients.Add(lastInst, grads);
+ for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst())
+ {
+ auto upperGrads = upperGradients.TryGetValue(child);
+ if (!upperGrads)
+ continue;
+ if (upperGrads->getCount() > 1)
+ {
+ auto sumGrad = upperGrads->getFirst();
+ for (auto i = 1; i < upperGrads->getCount(); i++)
+ {
+ sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]);
+ }
+ this->transcribeInstBackward(&subBuilder, child, sumGrad);
+ }
+ else
+ this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst());
+ }
+
+ subBuilder.emitReturn();
+
+ return InstPair(diffBlock, diffBlock);
+ }
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origFunc);
+
+ IRFunc* primalFunc = origFunc;
+
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
+ primalFunc = origFunc;
+
+ auto diffFunc = builder.createFunc();
+
+ SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
+ IRType* diffFuncType = this->differentiateFunctionType(
+ &builder,
+ as<IRFuncType>(origFunc->getFullType()));
+ diffFunc->setFullType(diffFuncType);
+
+ if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
+ {
+ auto originalName = nameHint->getName();
+ StringBuilder newNameSb;
+ newNameSb << "s_bwd_" << originalName;
+ builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ }
+ builder.addBackwardDerivativeDecoration(origFunc, diffFunc);
+
+ // Mark the generated derivative function itself as differentiable.
+ builder.addBackwardDifferentiableDecoration(diffFunc);
+
+ // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
+ if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ {
+ cloneDecoration(dictDecor, diffFunc);
+ }
+
+ auto result = InstPair(primalFunc, diffFunc);
+ followUpFunctionsToTranscribe.add(result);
+ return result;
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertInto(diffFunc);
+
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ // Transcribe children from origFunc into diffFunc
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribeBlock(&builder, block);
+
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ IRInst* copyParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = origParam->getDataType();
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ IRInst* diffParam = builder->emitParam(inoutDiffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffParam);
+ auto paramValue = builder->emitLoad(diffParam);
+ auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
+ orginalToTranscribed.Add(origParam, primal);
+ primalToDiffPair.Add(primal, diffParam);
+
+ return diffParam;
+ }
+
+
+ return cloneInst(&cloneEnv, builder, origParam);
+ }
+
+ InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ auto origLeft = origArith->getOperand(0);
+ auto origRight = origArith->getOperand(1);
+
+ IRInst* primalLeft;
+ if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft))
+ {
+ primalLeft = origLeft;
+ }
+ IRInst* primalRight;
+ if (!orginalToTranscribed.TryGetValue(origRight, primalRight))
+ {
+ primalRight = origRight;
+ }
+
+ auto resultType = origArith->getDataType();
+ IRInst* newInst;
+ switch (origArith->getOp())
+ {
+ case kIROp_Add:
+ newInst = builder->emitAdd(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Mul:
+ newInst = builder->emitMul(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Sub:
+ newInst = builder->emitSub(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Div:
+ newInst = builder->emitDiv(resultType, primalLeft, primalRight);
+ break;
+ default:
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+ orginalToTranscribed.Add(origArith, newInst);
+ return InstPair(newInst, nullptr);
+ }
+
+ IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ auto lhs = origArith->getOperand(0);
+ auto rhs = origArith->getOperand(1);
+
+ if (as<IRInOutType>(lhs->getDataType()))
+ {
+ lhs = builder->emitLoad(lhs);
+ lhs = builder->emitDifferentialPairGetPrimal(lhs);
+ }
+ if (as<IRInOutType>(rhs->getDataType()))
+ {
+ rhs = builder->emitLoad(rhs);
+ rhs = builder->emitDifferentialPairGetPrimal(rhs);
+ }
+
+ IRInst* leftGrad;
+ IRInst* rightGrad;
+
+
+ switch (origArith->getOp())
+ {
+ case kIROp_Add:
+ leftGrad = grad;
+ rightGrad = grad;
+ break;
+ case kIROp_Mul:
+ leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
+ rightGrad = builder->emitMul(grad->getDataType(), lhs, grad);
+ break;
+ case kIROp_Sub:
+ leftGrad = grad;
+ rightGrad = builder->emitNeg(grad->getDataType(), grad);
+ break;
+ case kIROp_Div:
+ leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
+ rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad
+ break;
+ default:
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+
+ lhs = origArith->getOperand(0);
+ rhs = origArith->getOperand(1);
+ if (auto leftGrads = upperGradients.TryGetValue(lhs))
+ {
+ leftGrads->add(leftGrad);
+ }
+ else
+ {
+ upperGradients.Add(lhs, leftGrad);
+ }
+ if (auto rightGrads = upperGradients.TryGetValue(rhs))
+ {
+ rightGrads->add(rightGrad);
+ }
+ else
+ {
+ upperGradients.Add(rhs, rightGrad);
+ }
+
+ return nullptr;
+ }
+
+ InstPair copyInst(IRBuilder* builder, IRInst* origInst)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParam(builder, as<IRParam>(origInst));
+
+ case kIROp_Return:
+ return InstPair(nullptr, nullptr);
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return copyBinaryArith(builder, origInst);
+
+ default:
+ // Not yet implemented
+ SLANG_ASSERT(0);
+ }
+
+ return InstPair(nullptr, nullptr);
+ }
+
+ IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
+ {
+ IRInOutType* inoutParam = as<IRInOutType>(param->getDataType());
+ auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType());
+ auto paramValue = builder->emitLoad(param);
+ auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
+ auto diff = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ paramValue
+ );
+ auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad);
+ auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff);
+ auto store = builder->emitStore(param, updatedParam);
+
+ return store;
+ }
+
+ IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParamBackward(builder, as<IRParam>(origInst), grad);
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return transcribeBinaryArithBackward(builder, origInst, grad);
+
+ case kIROp_DifferentialPairGetPrimal:
+ {
+ if (auto param = primalToDiffPair.TryGetValue(origInst))
+ {
+ if (auto leftGrads = upperGradients.TryGetValue(*param))
+ {
+ leftGrads->add(grad);
+ }
+ else
+ {
+ upperGradients.Add(*param, grad);
+ }
+ }
+ else
+ SLANG_ASSERT(0);
+ return nullptr;
+ }
+
+ default:
+ // Not yet implemented
+ SLANG_ASSERT(0);
+ }
+
+ return nullptr;
+ }
+};
+
+
struct JVPDerivativeContext : public InstPassBase
{
@@ -2034,7 +2692,7 @@ struct JVPDerivativeContext : public InstPassBase
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
-
+
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
@@ -2059,7 +2717,7 @@ struct JVPDerivativeContext : public InstPassBase
IRInst* lookupJVPReference(IRInst* primalFunction)
{
- if(auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
+ if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
return jvpDefinition->getForwardDerivativeFunc();
return nullptr;
}
@@ -2069,16 +2727,24 @@ struct JVPDerivativeContext : public InstPassBase
//
bool processReferencedFunctions(IRBuilder* builder)
{
- List<IRForwardDifferentiate*> autoDiffWorkList;
+ List<IRInst*> autoDiffWorkList;
for (;;)
{
// Collect all `ForwardDifferentiate` insts from the module.
autoDiffWorkList.clear();
- processInstsOfType<IRForwardDifferentiate>(kIROp_ForwardDifferentiate, [&](IRForwardDifferentiate* fwdDiffInst)
- {
- autoDiffWorkList.add(fwdDiffInst);
- });
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ autoDiffWorkList.add(inst);
+ break;
+ default:
+ break;
+ }
+ });
if (autoDiffWorkList.getCount() == 0)
break;
@@ -2086,42 +2752,59 @@ struct JVPDerivativeContext : public InstPassBase
// Process collected `ForwardDifferentiate` insts and replace them with placeholders for
// differentiated functions.
transcriberStorage.followUpFunctionsToTranscribe.clear();
+ backwardTranscriberStorage.followUpFunctionsToTranscribe.clear();
- for (auto fwdDiffInst : autoDiffWorkList)
+ for (auto differentiateInst : autoDiffWorkList)
{
- auto baseInst = fwdDiffInst->getBaseFn();
+ IRInst* baseInst = differentiateInst->getOperand(0);
+
if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst))
{
- if (auto existingDiffFunc = lookupJVPReference(baseFunction))
+ if (as<IRForwardDifferentiate>(differentiateInst))
{
- fwdDiffInst->replaceUsesWith(existingDiffFunc);
- fwdDiffInst->removeAndDeallocate();
- }
- else if (isMarkedForForwardDifferentiation(baseFunction))
- {
- if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
+ if (auto existingDiffFunc = lookupJVPReference(baseFunction))
{
- IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction);
- SLANG_ASSERT(diffFunc);
- fwdDiffInst->replaceUsesWith(diffFunc);
- fwdDiffInst->removeAndDeallocate();
+ differentiateInst->replaceUsesWith(existingDiffFunc);
+ differentiateInst->removeAndDeallocate();
}
- else
+ else if (isMarkedForForwardDifferentiation(baseFunction))
{
- // TODO(Sai): This would probably be better with a more specific
- // error code.
- getSink()->diagnose(fwdDiffInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Unexpected instruction. Expected func or generic");
+ if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
+ {
+ IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction);
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
}
}
- else
+ else if (as<IRBackwardDifferentiate>(differentiateInst))
{
- // TODO(Sai): This would probably be better with a more specific
- // error code.
- getSink()->diagnose(fwdDiffInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Cannot differentiate functions not marked for differentiation");
+ if (isMarkedForBackwardDifferentiation(baseFunction))
+ {
+ if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
+ {
+ IRInst* diffFunc =
+ backwardTranscriberStorage
+ .transcribeFuncHeader(builder, (IRFunc*)baseFunction)
+ .differential;
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
+ }
}
}
}
@@ -2136,6 +2819,16 @@ struct JVPDerivativeContext : public InstPassBase
transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
}
+ followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe);
+ for (auto task : followUpWorkList)
+ {
+ auto diffFunc = as<IRFunc>(task.differential);
+ SLANG_ASSERT(diffFunc);
+ auto primalFunc = as<IRFunc>(task.primal);
+ SLANG_ASSERT(primalFunc);
+
+ backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
+ }
// Transcribing the function body really shouldn't produce more follow up function body work.
// However it may produce new `ForwardDifferentiate` instructions, which we collect and process
@@ -2159,7 +2852,7 @@ struct JVPDerivativeContext : public InstPassBase
IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
{
-
+
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
bool isTrivial = false;
@@ -2182,7 +2875,7 @@ struct JVPDerivativeContext : public InstPassBase
return result;
}
}
-
+
return nullptr;
}
@@ -2213,7 +2906,7 @@ struct JVPDerivativeContext : public InstPassBase
return primalFieldExtract;
}
}
-
+
return nullptr;
}
@@ -2426,7 +3119,7 @@ struct JVPDerivativeContext : public InstPassBase
//
bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable)
{
- for(auto decoration = callable->getFirstDecoration();
+ for (auto decoration = callable->getFirstDecoration();
decoration;
decoration = decoration->getNextDecoration())
{
@@ -2438,11 +3131,11 @@ struct JVPDerivativeContext : public InstPassBase
return false;
}
- IRStringLit* getForwardDerivativeFuncName(IRInst* func)
+ IRStringLit* getForwardDerivativeFuncName(IRInst* func)
{
IRBuilder builder(&sharedBuilderStorage);
builder.setInsertBefore(func);
-
+
IRStringLit* name = nullptr;
if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
{
@@ -2456,17 +3149,54 @@ struct JVPDerivativeContext : public InstPassBase
return name;
}
+ // Checks decorators to see if the function should
+ // be differentiated (kIROp_ForwardDifferentiableDecoration)
+ //
+ bool isMarkedForBackwardDifferentiation(IRGlobalValueWithCode* callable)
+ {
+ for (auto decoration = callable->getFirstDecoration();
+ decoration;
+ decoration = decoration->getNextDecoration())
+ {
+ if (decoration->getOp() == kIROp_BackwardDifferentiableDecoration)
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ IRStringLit* getBackwardDerivativeFuncName(IRInst* func)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ {
+ name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ {
+ name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice());
+ }
+
+ return name;
+ }
+
JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
InstPassBase(module),
sink(sink),
autoDiffSharedContextStorage(module->getModuleInst()),
- transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage)
+ transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage),
+ backwardTranscriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage, sink)
{
autoDiffSharedContextStorage.sharedBuilder = &sharedBuilderStorage;
pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage;
transcriberStorage.sink = sink;
transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage);
transcriberStorage.pairBuilder = &(pairBuilderStorage);
+ backwardTranscriberStorage.pairBuilder = &pairBuilderStorage;
}
protected:
@@ -2474,10 +3204,12 @@ protected:
// processing instructions while maintaining state.
//
JVPTranscriber transcriberStorage;
-
+
+ BackwardDiffTranscriber backwardTranscriberStorage;
+
// Diagnostic object from the compile request for
// error messages.
- DiagnosticSink* sink;
+ DiagnosticSink* sink;
// Context to find and manage the witness tables for types
// implementing `IDifferentiable`
@@ -2490,11 +3222,11 @@ protected:
// Set up context and call main process method.
//
-bool processForwardDifferentiableFuncs(
- IRModule* module,
- DiagnosticSink* sink,
- IRJVPDerivativePassOptions const&)
-{
+bool processDifferentiableFuncs(
+ IRModule* module,
+ DiagnosticSink* sink,
+ IRJVPDerivativePassOptions const&)
+{
// Simplify module to remove dead code.
IRDeadCodeEliminationOptions options;
options.keepExportsAlive = true;