diff options
| author | Yong He <yonghe@outlook.com> | 2022-12-21 15:25:38 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-21 15:25:38 -0800 |
| commit | 6dbdb74dbdc20783a0429229c21604a3d08d28f8 (patch) | |
| tree | 910e2dd7b7b296ae5c285dbbb73114b381ef529a /source | |
| parent | 887842933c0734196729d5525de9835eb48b3855 (diff) | |
Further unify the autodiff passes. (#2574)
* Further unify the autodiff passes.
* Fix clang compilation error.
* Rename ForwardDerivativeTranscriber->ForwardDiffTranscriber.
* Remove unused fields from Transcriber classes.
* More small cleanups.
* Cleanup.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 1090 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 146 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 567 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 88 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 847 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 129 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 205 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 38 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 24 |
11 files changed, 1604 insertions, 1577 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index d1e9f91ec..dbf79b5f8 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -11,123 +11,7 @@ namespace Slang { -static IRInst* _unwrapAttributedType(IRInst* type) -{ - while (auto attrType = as<IRAttributedType>(type)) - type = attrType->getBaseType(); - return type; -} - -DiagnosticSink* ForwardDerivativeTranscriber::getSink() -{ - SLANG_ASSERT(sink); - return sink; -} - -void ForwardDerivativeTranscriber::mapDifferentialInst(IRInst* origInst, IRInst* diffInst) -{ - if (hasDifferentialInst(origInst)) - { - if (lookupDiffInst(origInst) != diffInst) - { - SLANG_UNEXPECTED("Inconsistent differential mappings"); - } - } - else - { - instMapD.Add(origInst, diffInst); - } -} - -void ForwardDerivativeTranscriber::mapPrimalInst(IRInst* origInst, IRInst* primalInst) -{ - if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst) - { - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::internalCompilerError, - "inconsistent primal instruction for original"); - } - else - { - cloneEnv.mapOldValToNew[origInst] = primalInst; - } -} - -IRInst* ForwardDerivativeTranscriber::lookupDiffInst(IRInst* origInst) -{ - return instMapD[origInst]; -} - -IRInst* ForwardDerivativeTranscriber::lookupDiffInst(IRInst* origInst, IRInst* defaultInst) -{ - return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst; -} - -bool ForwardDerivativeTranscriber::hasDifferentialInst(IRInst* origInst) -{ - return instMapD.ContainsKey(origInst); -} - -bool ForwardDerivativeTranscriber::shouldUseOriginalAsPrimal(IRInst* origInst) -{ - if (as<IRGlobalValueWithCode>(origInst)) - return true; - if (origInst->parent && origInst->parent->getOp() == kIROp_Module) - return true; - return false; -} - -IRInst* ForwardDerivativeTranscriber::lookupPrimalInst(IRInst* origInst) -{ - if (!origInst) - return nullptr; - if (shouldUseOriginalAsPrimal(origInst)) - return origInst; - return cloneEnv.mapOldValToNew[origInst]; -} - -IRInst* ForwardDerivativeTranscriber::lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) -{ - if (!origInst) - return nullptr; - return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; -} - -bool ForwardDerivativeTranscriber::hasPrimalInst(IRInst* origInst) -{ - if (!origInst) - return true; - if (shouldUseOriginalAsPrimal(origInst)) - return true; - return cloneEnv.mapOldValToNew.ContainsKey(origInst); -} - -IRInst* ForwardDerivativeTranscriber::findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst) -{ - if (!hasDifferentialInst(origInst)) - { - transcribe(builder, origInst); - SLANG_ASSERT(hasDifferentialInst(origInst)); - } - - return lookupDiffInst(origInst); -} - -IRInst* ForwardDerivativeTranscriber::findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst) -{ - if (shouldUseOriginalAsPrimal(origInst)) - return origInst; - - if (!hasPrimalInst(origInst)) - { - transcribe(builder, origInst); - SLANG_ASSERT(hasPrimalInst(origInst)); - } - - return lookupPrimalInst(origInst); -} - -IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) +IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) { List<IRType*> newParameterTypes; IRType* diffReturnType; @@ -135,7 +19,7 @@ IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* b for (UIndex i = 0; i < funcType->getParamCount(); i++) { auto origType = funcType->getParamType(i); - origType = (IRType*) lookupPrimalInst(origType, origType); + origType = (IRType*) findOrTranscribePrimalInst(builder, origType); if (auto diffPairType = tryGetDiffPairType(builder, origType)) newParameterTypes.add(diffPairType); else @@ -145,7 +29,7 @@ IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* b // Transcribe return type to a pair. // This will be void if the primal return type is non-differentiable. // - auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType()); + auto origResultType = (IRType*)findOrTranscribePrimalInst(builder, funcType->getResultType()); if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) diffReturnType = returnPairType; else @@ -154,320 +38,10 @@ IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* b return builder->getFuncType(newParameterTypes, diffReturnType); } -// Get or construct `:IDifferentiable` conformance for a DifferentiablePair. -IRWitnessTable* ForwardDerivativeTranscriber::getDifferentialPairWitness(IRInst* inDiffPairType) -{ - IRBuilder builder(sharedBuilder); - builder.setInsertInto(inDiffPairType->parent); - auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); - SLANG_ASSERT(diffPairType); - - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = differentiateType(&builder, diffPairType); - - // And place it in the synthesized witness table. - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - - // Record this in the context for future lookups - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - - return table; -} - -IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) -{ - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); -} - -IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType) -{ - IRBuilder builder(sharedBuilder); - if (!primalType->next) - builder.setInsertInto(primalType->parent); - else - builder.setInsertBefore(primalType->next); - - IRInst* witness = as<IRWitnessTable>( - differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - - if (!witness) - { - if (auto primalPairType = as<IRDifferentialPairType>(primalType)) - { - witness = getDifferentialPairWitness(primalPairType); - } - else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) - { - differentiateExtractExistentialType(&builder, extractExistential, witness); - } - } - - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); -} - -IRType* ForwardDerivativeTranscriber::differentiateType(IRBuilder* builder, IRType* origType) -{ - IRInst* diffType = nullptr; - if (!instMapD.TryGetValue(origType, diffType)) - { - diffType = _differentiateTypeImpl(builder, origType); - instMapD[origType] = diffType; - } - return (IRType*)diffType; -} - -IRType* ForwardDerivativeTranscriber::_differentiateTypeImpl(IRBuilder* builder, IRType* origType) -{ - if (auto ptrType = as<IRPtrTypeBase>(origType)) - return builder->getPtrType( - origType->getOp(), - differentiateType(builder, ptrType->getValueType())); - - // If there is an explicit primal version of this type in the local scope, load that - // otherwise use the original type. - // - IRInst* primalType = lookupPrimalInst(origType, origType); - - // Special case certain compound types (PtrType, FuncType, etc..) - // otherwise try to lookup a differential definition for the given type. - // If one does not exist, then we assume it's not differentiable. - // - switch (primalType->getOp()) - { - case kIROp_Param: - if (as<IRTypeType>(primalType->getDataType())) - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( - builder, - (IRType*)primalType)); - else if (as<IRWitnessTableType>(primalType->getDataType())) - return (IRType*)primalType; - - case kIROp_ArrayType: - { - auto primalArrayType = as<IRArrayType>(primalType); - if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) - return builder->getArrayType( - diffElementType, - primalArrayType->getElementCount()); - else - return nullptr; - } - - case kIROp_DifferentialPairType: - { - auto primalPairType = as<IRDifferentialPairType>(primalType); - return getOrCreateDiffPairType( - pairBuilder->getDiffTypeFromPairType(builder, primalPairType), - pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); - } - - case kIROp_FuncType: - return differentiateFunctionType(builder, as<IRFuncType>(primalType)); - - case kIROp_OutType: - if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) - return builder->getOutType(diffValueType); - else - return nullptr; - - case kIROp_InOutType: - if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) - return builder->getInOutType(diffValueType); - else - return nullptr; - - case kIROp_ExtractExistentialType: - { - IRInst* wt = nullptr; - return differentiateExtractExistentialType(builder, as<IRExtractExistentialType>(primalType), wt); - } - - case kIROp_TupleType: - { - auto tupleType = as<IRTupleType>(primalType); - List<IRType*> diffTypeList; - // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) - diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); - } - - default: - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); - } -} - - // Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. -bool _findDifferentiableInterfaceLookupPathImpl( - HashSet<IRInst*>& processedTypes, - IRInterfaceType* idiffType, - IRInterfaceType* type, - List<IRInterfaceRequirementEntry*>& currentPath) -{ - if (processedTypes.Contains(type)) - return false; - processedTypes.Add(type); - - List<IRInterfaceRequirementEntry*> lookupKeyPath; - for (UInt i = 0; i < type->getOperandCount(); i++) - { - auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i)); - if (!entry) continue; - if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal())) - { - currentPath.add(entry); - if (wt->getConformanceType() == idiffType) - { - return true; - } - else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType())) - { - if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) - return true; - } - currentPath.removeLast(); - } - } - return false; -} - -List<IRInterfaceRequirementEntry*> _findDifferentiableInterfaceLookupPath( - IRInterfaceType* idiffType, - IRInterfaceType* type) -{ - List<IRInterfaceRequirementEntry*> currentPath; - HashSet<IRInst*> processedTypes; - _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); - return currentPath; -} - -IRType* ForwardDerivativeTranscriber::differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable) -{ - witnessTable = nullptr; - - // Search for IDifferentiable conformance. - auto interfaceType = as<IRInterfaceType>(_unwrapAttributedType(origType->getOperand(0)->getDataType())); - if (!interfaceType) - return nullptr; - List<IRInterfaceRequirementEntry*> lookupKeyPath = _findDifferentiableInterfaceLookupPath( - autoDiffSharedContext->differentiableInterfaceType, interfaceType); - - if (lookupKeyPath.getCount()) - { - // `interfaceType` does conform to `IDifferentiable`. - witnessTable = builder->emitExtractExistentialWitnessTable(origType->getOperand(0)); - for (auto node : lookupKeyPath) - { - witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey()); - } - auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), witnessTable, autoDiffSharedContext->differentialAssocTypeStructKey); - return (IRType*)diffType; - } - return nullptr; -} - -IRType* ForwardDerivativeTranscriber::tryGetDiffPairType(IRBuilder* builder, IRType* primalType) -{ - // If this is a PtrType (out, inout, etc..), then create diff pair from - // value type and re-apply the appropropriate PtrType wrapper. - // - if (auto origPtrType = as<IRPtrTypeBase>(primalType)) - { - if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(primalType->getOp(), diffPairValueType); - else - return nullptr; - } - auto diffType = differentiateType(builder, primalType); - if (diffType) - return (IRType*)getOrCreateDiffPairType(primalType); - return nullptr; -} - -InstPair ForwardDerivativeTranscriber::transcribeParam(IRBuilder* builder, IRParam* origParam) -{ - auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType()); - // Do not differentiate generic type (and witness table) parameters - if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) - { - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } - - // Is this param a phi node or a function parameter? - auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent()); - bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); - if (isFuncParam) - { - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as<IRDifferentialPairType>(diffPairType)) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) - { - auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType()); - - return InstPair( - builder->emitDifferentialPairAddressPrimal(diffPairParam), - builder->emitDifferentialPairAddressDifferential( - builder->getPtrType( - kIROp_PtrType, - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), - diffPairParam)); - } - } - - auto primalInst = cloneInst(&cloneEnv, builder, origParam); - if (auto primalParam = as<IRParam>(primalInst)) - { - SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); - primalParam->removeFromParent(); - builder->getInsertLoc().getBlock()->addParam(primalParam); - } - return InstPair(primalInst, nullptr); - } - else - { - auto primal = cloneInst(&cloneEnv, builder, origParam); - IRInst* diff = nullptr; - if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) - { - diff = builder->emitParam(diffType); - } - return InstPair(primal, diff); - } -} - // Returns "d<var-name>" to use as a name hint for variables and parameters. // If no primal name is available, returns a blank string. // -String ForwardDerivativeTranscriber::getJVPVarName(IRInst* origVar) +String ForwardDiffTranscriber::getJVPVarName(IRInst* origVar) { if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) { @@ -477,20 +51,7 @@ String ForwardDerivativeTranscriber::getJVPVarName(IRInst* origVar) return String(""); } -// Returns "dp<var-name>" to use as a name hint for parameters. -// If no primal name is available, returns a blank string. -// -String ForwardDerivativeTranscriber::makeDiffPairName(IRInst* origVar) -{ - if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) - { - return ("dp" + String(namehintDecoration->getName())); - } - - return String(""); -} - -InstPair ForwardDerivativeTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar) +InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar) { if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) { @@ -507,7 +68,7 @@ InstPair ForwardDerivativeTranscriber::transcribeVar(IRBuilder* builder, IRVar* return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr); } -InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith) +InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith) { SLANG_ASSERT(origArith->getOperandCount() == 2); @@ -587,7 +148,7 @@ InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder, return InstPair(primalArith, nullptr); } -InstPair ForwardDerivativeTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) +InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) { SLANG_ASSERT(origLogic->getOperandCount() == 2); @@ -604,7 +165,7 @@ InstPair ForwardDerivativeTranscriber::transcribeBinaryLogic(IRBuilder* builder, SLANG_UNEXPECTED("Logical operation with non-boolean result"); } -InstPair ForwardDerivativeTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad) +InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad) { auto origPtr = origLoad->getPtr(); auto primalPtr = lookupPrimalInst(origPtr, nullptr); @@ -637,7 +198,7 @@ InstPair ForwardDerivativeTranscriber::transcribeLoad(IRBuilder* builder, IRLoad return InstPair(primalLoad, diffLoad); } -InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRStore* origStore) +InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* origStore) { IRInst* origStoreLocation = origStore->getPtr(); IRInst* origStoreVal = origStore->getVal(); @@ -679,67 +240,18 @@ InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRSto return InstPair(primalStore, nullptr); } -InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRReturn* origReturn) -{ - IRInst* origReturnVal = origReturn->getVal(); - - auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType()); - if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal)) - { - // If the return value is itself a function, generic or a struct then this - // is likely to be a generic scope. In this case, we lookup the differential - // and return that. - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); - - // Neither of these should be nullptr. - SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); - IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal)); - builder->markInstAsMixedDifferential(diffReturn, nullptr); - - return InstPair(diffReturn, diffReturn); - } - else if (auto pairType = tryGetDiffPairType(builder, returnDataType)) - { - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); - if(!diffReturnVal) - diffReturnVal = getDifferentialZeroOfType(builder, returnDataType); - - // If the pair type can be formed, this must be non-null. - SLANG_RELEASE_ASSERT(diffReturnVal); - - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); - builder->markInstAsMixedDifferential(diffPair, pairType); - - IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); - builder->markInstAsMixedDifferential(pairReturn, pairType); - - return InstPair(pairReturn, pairReturn); - } - else - { - // If the return type is not differentiable, emit the primal value only. - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - - IRInst* primalReturn = builder->emitReturn(primalReturnVal); - return InstPair(primalReturn, nullptr); - - } -} - // Since int/float literals are sometimes nested inside an IRConstructor // instruction, we check to make sure that the nested instr is a constant // and then return nullptr. Literals do not need to be differentiated. // -InstPair ForwardDerivativeTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) +InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) { IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); // Check if the output type can be differentiated. If it cannot be // differentiated, don't differentiate the inst // - auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType()); + auto primalConstructType = (IRType*)findOrTranscribePrimalInst(builder, origConstruct->getDataType()); if (auto diffConstructType = differentiateType(builder, primalConstructType)) { UCount operandCount = origConstruct->getOperandCount(); @@ -755,7 +267,7 @@ InstPair ForwardDerivativeTranscriber::transcribeConstruct(IRBuilder* builder, I else { auto operandDataType = origConstruct->getOperand(ii)->getDataType(); - operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType); + operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType); diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); } } @@ -778,7 +290,7 @@ InstPair ForwardDerivativeTranscriber::transcribeConstruct(IRBuilder* builder, I // an appropriate call list based on whichever parameters have differentials // in the current transcription context. // -InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall* origCall) +InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* origCall) { IRInst* origCallee = origCall->getCallee(); @@ -902,7 +414,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall } } -InstPair ForwardDerivativeTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) +InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) { IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle); @@ -924,7 +436,7 @@ InstPair ForwardDerivativeTranscriber::transcribeSwizzle(IRBuilder* builder, IRS return InstPair(primalSwizzle, nullptr); } -InstPair ForwardDerivativeTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDiffTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst) { IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst); @@ -953,7 +465,7 @@ InstPair ForwardDerivativeTranscriber::transcribeByPassthrough(IRBuilder* builde diffOperands.getBuffer())); } -InstPair ForwardDerivativeTranscriber::transcribeControlFlow(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRInst* origInst) { switch(origInst->getOp()) { @@ -1018,7 +530,7 @@ InstPair ForwardDerivativeTranscriber::transcribeControlFlow(IRBuilder* builder, return InstPair(nullptr, nullptr); } -InstPair ForwardDerivativeTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst) { switch(origInst->getOp()) { @@ -1038,7 +550,7 @@ InstPair ForwardDerivativeTranscriber::transcribeConst(IRBuilder* builder, IRIns return InstPair(nullptr, nullptr); } -IRInst* ForwardDerivativeTranscriber::findInterfaceRequirement(IRInterfaceType* type, IRInst* key) +IRInst* ForwardDiffTranscriber::findInterfaceRequirement(IRInterfaceType* type, IRInst* key) { for (UInt i = 0; i < type->getOperandCount(); i++) { @@ -1051,7 +563,7 @@ IRInst* ForwardDerivativeTranscriber::findInterfaceRequirement(IRInterfaceType* return nullptr; } -InstPair ForwardDerivativeTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) +InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) { auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase()); List<IRInst*> primalArgs; @@ -1120,126 +632,7 @@ InstPair ForwardDerivativeTranscriber::transcribeSpecialize(IRBuilder* builder, } } -InstPair ForwardDerivativeTranscriber::transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst) -{ - auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable()); - auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey()); - auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType()); - auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey); - - auto interfaceType = as<IRInterfaceType>(_unwrapAttributedType(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType())); - if (!interfaceType) - { - return InstPair(primal, nullptr); - } - auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); - if (!dict) - { - return InstPair(primal, nullptr); - } - - for (auto child : dict->getChildren()) - { - if (auto item = as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child)) - { - if (item->getOperand(0) == lookupInst->getRequirementKey()) - { - auto diffKey = item->getOperand(1); - if (auto diffType = findInterfaceRequirement(interfaceType, diffKey)) - { - auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey); - return InstPair(primal, diff); - } - break; - } - } - } - return InstPair(primal, nullptr); -} - -// In differential computation, the 'default' differential value is always zero. -// This is a consequence of differential computing being inherently linear. As a -// result, it's useful to have a method to generate zero literals of any (arithmetic) type. -// The current implementation requires that types are defined linearly. -// -IRInst* ForwardDerivativeTranscriber::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) -{ - if (auto diffType = differentiateType(builder, primalType)) - { - switch (diffType->getOp()) - { - case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); - } - // Since primalType has a corresponding differential type, we can lookup the - // definition for zero(). - auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - if (!zeroMethod) - { - // if the differential type itself comes from a witness lookup, we can just lookup the - // zero method from the same witness table. - if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType)) - { - auto wt = lookupInterface->getWitnessTable(); - zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); - } - } - SLANG_RELEASE_ASSERT(zeroMethod); - - auto emptyArgList = List<IRInst*>(); - - auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); - builder->markInstAsDifferential(callInst, primalType); - - return callInst; - } - else - { - if (isScalarIntegerType(primalType)) - { - return builder->getIntValue(primalType, 0); - } - - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; - } -} - -InstPair ForwardDerivativeTranscriber::transcribeBlock(IRBuilder* builder, IRBlock* origBlock) -{ - IRBuilder subBuilder(builder->getSharedBuilder()); - subBuilder.setInsertLoc(builder->getInsertLoc()); - - IRInst* diffBlock = subBuilder.emitBlock(); - - // Note: for blocks, we setup the mapping _before_ - // processing the children since we could encounter - // a lookup while processing the children. - // - mapPrimalInst(origBlock, diffBlock); - mapDifferentialInst(origBlock, diffBlock); - - subBuilder.setInsertInto(diffBlock); - - // First transcribe every parameter in the block. - for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - this->transcribe(&subBuilder, param); - - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - this->transcribe(&subBuilder, child); - - return InstPair(diffBlock, diffBlock); -} - -InstPair ForwardDerivativeTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) +InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) { SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst)); @@ -1247,7 +640,7 @@ InstPair ForwardDerivativeTranscriber::transcribeFieldExtract(IRBuilder* builder auto primalBase = findOrTranscribePrimalInst(builder, origBase); auto field = originalInst->getOperand(1); auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>(); - auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType()); IRInst* primalOperands[] = { primalBase, field }; IRInst* primalFieldExtract = builder->emitIntrinsicInst( @@ -1278,7 +671,7 @@ InstPair ForwardDerivativeTranscriber::transcribeFieldExtract(IRBuilder* builder return InstPair(primalFieldExtract, diffFieldExtract); } -InstPair ForwardDerivativeTranscriber::transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) +InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) { SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr)); @@ -1286,7 +679,7 @@ InstPair ForwardDerivativeTranscriber::transcribeGetElement(IRBuilder* builder, auto primalBase = findOrTranscribePrimalInst(builder, origBase); auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1)); - auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType()); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origGetElementPtr->getDataType()); IRInst* primalOperands[] = {primalBase, primalIndex}; IRInst* primalGetElementPtr = builder->emitIntrinsicInst( @@ -1313,7 +706,7 @@ InstPair ForwardDerivativeTranscriber::transcribeGetElement(IRBuilder* builder, return InstPair(primalGetElementPtr, diffGetElementPtr); } -InstPair ForwardDerivativeTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop) +InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop) { // The loop comes with three blocks.. we just need to transcribe each one // and assemble the new loop instruction. @@ -1351,7 +744,7 @@ InstPair ForwardDerivativeTranscriber::transcribeLoop(IRBuilder* builder, IRLoop return InstPair(diffLoop, diffLoop); } -InstPair ForwardDerivativeTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) +InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) { // IfElse Statements come with 4 blocks. We transcribe each block into it's // linear form, and then wire them up in the same way as the original if-else @@ -1395,7 +788,7 @@ InstPair ForwardDerivativeTranscriber::transcribeIfElse(IRBuilder* builder, IRIf return InstPair(diffLoop, diffLoop); } -InstPair ForwardDerivativeTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) +InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) { auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue()); SLANG_ASSERT(primalVal); @@ -1406,21 +799,16 @@ InstPair ForwardDerivativeTranscriber::transcribeMakeDifferentialPair(IRBuilder* auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); SLANG_ASSERT(diffDiffVal); - auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal); + auto primalPair = builder->emitMakeDifferentialPair( + tryGetDiffPairType(builder, primalVal->getDataType()), primalVal, diffPrimalVal); auto diffPair = builder->emitMakeDifferentialPair( - differentiateType(builder, origInst->getDataType()), + tryGetDiffPairType(builder, differentiateType(builder, primalVal->getDataType())), primalDiffVal, diffDiffVal); return InstPair(primalPair, diffPair); } -InstPair ForwardDerivativeTranscriber::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst) -{ - auto primal = cloneInst(&cloneEnv, builder, origInst); - return InstPair(primal, nullptr); -} - -InstPair ForwardDerivativeTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) { SLANG_ASSERT( origInst->getOp() == kIROp_DifferentialPairGetDifferential || @@ -1444,11 +832,88 @@ InstPair ForwardDerivativeTranscriber::transcribeDifferentialPairGetElement(IRBu return InstPair(primalResult, diffResult); } +InstPair ForwardDiffTranscriber::transcribeSingleOperandInst(IRBuilder* builder, IRInst* origInst) +{ + IRInst* origBase = origInst->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); + + IRInst* primalResult = builder->emitIntrinsicInst( + primalType, + origInst->getOp(), + 1, + &primalBase); + + IRInst* diffResult = nullptr; + + if (auto diffType = differentiateType(builder, primalType)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + diffResult = builder->emitIntrinsicInst( + diffType, + origInst->getOp(), + 1, + &diffBase); + } + } + return InstPair(primalResult, diffResult); +} + +InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, IRInst* origInst) +{ + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); + + List<IRInst*> primalArgs; + for (UInt i = 0; i < origInst->getOperandCount(); i++) + { + auto primalArg = findOrTranscribePrimalInst(builder, origInst->getOperand(i)); + primalArgs.add(primalArg); + } + + IRInst* primalResult = builder->emitIntrinsicInst( + primalType, + origInst->getOp(), + primalArgs.getCount(), + primalArgs.getBuffer()); + + IRInst* diffResult = nullptr; + + if (auto diffType = differentiateType(builder, primalType)) + { + List<IRInst*> diffArgs; + for (UInt i = 0; i < origInst->getOperandCount(); i++) + { + auto arg = findOrTranscribeDiffInst(builder, origInst->getOperand(i)); + if (arg) + { + diffArgs.add(arg); + } + else if (i == 0) + { + // If we can't diff the first operand (base), abort now. + break; + } + } + if (diffArgs.getCount()) + { + diffResult = builder->emitIntrinsicInst( + diffType, + origInst->getOp(), + diffArgs.getCount(), + diffArgs.getBuffer()); + } + } + return InstPair(primalResult, diffResult); +} + // Create an empty func to represent the transcribed func of `origFunc`. -InstPair ForwardDerivativeTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) +InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origFunc); + if (auto bwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>()) + return InstPair(origFunc, bwdDecor->getForwardDerivativeFunc()); + + IRBuilder builder = *inBuilder; IRFunc* primalFunc = origFunc; @@ -1482,13 +947,17 @@ InstPair ForwardDerivativeTranscriber::transcribeFuncHeader(IRBuilder* inBuilder cloneDecoration(dictDecor, diffFunc); } - auto result = InstPair(primalFunc, diffFunc); - followUpFunctionsToTranscribe.add(result); - return result; + FuncBodyTranscriptionTask task; + task.type = FuncBodyTranscriptionTaskType::Forward; + task.originalFunc = primalFunc; + task.resultFunc = diffFunc; + autoDiffSharedContext->followUpFunctionsToTranscribe.add(task); + + return InstPair(primalFunc, diffFunc); } // Transcribe a function definition. -InstPair ForwardDerivativeTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) +InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) { IRBuilder builder(inBuilder->getSharedBuilder()); builder.setInsertInto(diffFunc); @@ -1502,7 +971,7 @@ InstPair ForwardDerivativeTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFu } // Transcribe a generic definition -InstPair ForwardDerivativeTranscriber::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) +InstPair ForwardDiffTranscriber::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) { auto innerVal = findInnerMostGenericReturnVal(origGeneric); if (auto innerFunc = as<IRFunc>(innerVal)) @@ -1546,69 +1015,7 @@ InstPair ForwardDerivativeTranscriber::transcribeGeneric(IRBuilder* inBuilder, I return InstPair(primalGeneric, diffGeneric); } -IRInst* ForwardDerivativeTranscriber::transcribe(IRBuilder* builder, IRInst* origInst) -{ - // If a differential intstruction is already mapped for - // this original inst, return that. - // - if (auto diffInst = lookupDiffInst(origInst, nullptr)) - { - SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. - return diffInst; - } - - // Otherwise, dispatch to the appropriate method - // depending on the op-code. - // - instsInProgress.Add(origInst); - InstPair pair = transcribeInst(builder, origInst); - instsInProgress.Remove(origInst); - - if (auto primalInst = pair.primal) - { - mapPrimalInst(origInst, pair.primal); - mapDifferentialInst(origInst, pair.differential); - if (pair.differential) - { - switch (pair.differential->getOp()) - { - case kIROp_Func: - case kIROp_Generic: - case kIROp_Block: - // Don't generate again for these. - // Functions already have their names generated in `transcribeFuncHeader`. - break; - default: - // Generate name hint for the inst. - if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) - { - StringBuilder sb; - sb << "s_diff_" << primalNameHint->getName(); - builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); - } - - // Tag the differential inst using a decoration (if it doesn't have one) - if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() && - !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>()) - { - // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential - // instead. - // - builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType())); - } - - break; - } - } - return pair.differential; - } - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::internalCompilerError, - "failed to transcibe instruction"); - return nullptr; -} - -InstPair ForwardDerivativeTranscriber::transcribeInst(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst) { // Handle common SSA-style operations switch (origInst->getOp()) @@ -1695,252 +1102,35 @@ InstPair ForwardDerivativeTranscriber::transcribeInst(IRBuilder* builder, IRInst case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferential: return transcribeDifferentialPairGetElement(builder, origInst); - case kIROp_ExtractExistentialWitnessTable: - case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialValue: - case kIROp_WrapExistential: case kIROp_MakeExistential: - case kIROp_MakeExistentialWithRTTI: + return transcribeSingleOperandInst(builder, origInst); + case kIROp_ExtractExistentialType: + { + IRInst* witnessTable; + return InstPair( + maybeCloneForPrimalInst(builder, origInst), + differentiateExtractExistentialType( + builder, as<IRExtractExistentialType>(origInst), witnessTable)); + } + case kIROp_ExtractExistentialWitnessTable: + return transcribeExtractExistentialWitnessTable(builder, origInst); + case kIROp_WrapExistential: + return transcribeWrapExistential(builder, origInst); + case kIROp_CreateExistentialObject: + // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, + // so we treat this inst as non differentiable. + // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value. return trascribeNonDiffInst(builder, origInst); case kIROp_StructKey: return InstPair(origInst, nullptr); - } - // If none of the cases have been hit, check if the instruction is a - // type. Only need to explicitly differentiate types if they appear inside a block. - // - if (auto origType = as<IRType>(origInst)) - { - // If this is a generic type, transcibe the parent - // generic and derive the type from the transcribed generic's - // return value. - // - if (as<IRGeneric>(origType->getParent()->getParent()) && - findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType && - !instsInProgress.Contains(origType->getParent()->getParent())) - { - auto origGenericType = origType->getParent()->getParent(); - auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType); - auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType)); - return InstPair( - origGenericType, - innerDiffGenericType - ); - } - else if (as<IRBlock>(origType->getParent())) - return InstPair( - cloneInst(&cloneEnv, builder, origType), - differentiateType(builder, origType)); - else - return InstPair( - cloneInst(&cloneEnv, builder, origType), - nullptr); - } - - // Handle instructions with children - switch (origInst->getOp()) - { - case kIROp_Func: - return transcribeFuncHeader(builder, as<IRFunc>(origInst)); - - case kIROp_Block: - return transcribeBlock(builder, as<IRBlock>(origInst)); - - case kIROp_Generic: - return transcribeGeneric(builder, as<IRGeneric>(origInst)); + case kIROp_MakeExistentialWithRTTI: + SLANG_UNEXPECTED("MakeExistentialWithRTTI inst is not expected in autodiff pass."); + break; } - // If we reach this statement, the instruction type is likely unhandled. - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::unimplemented, - "this instruction cannot be differentiated"); - return InstPair(nullptr, nullptr); } -struct ForwardDerivativePass : public InstPassBase -{ - - DiagnosticSink* getSink() - { - return sink; - } - - bool processModule() - { - // TODO(sai): Move this call. - transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); - - IRBuilder builderStorage(this->autodiffContext->sharedBuilder); - IRBuilder* builder = &builderStorage; - - // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by - // generating derivative code for the referenced function. - // - bool modified = processReferencedFunctions(builder); - - return modified; - } - - IRInst* lookupJVPReference(IRInst* primalFunction) - { - if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) - return jvpDefinition->getForwardDerivativeFunc(); - return nullptr; - } - - // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate), - // then check that the referenced function is marked correctly for differentiation. - // - bool processReferencedFunctions(IRBuilder* builder) - { - bool changed = false; - List<IRInst*> autoDiffWorkList; - for (;;) - { - // Collect all `ForwardDifferentiate` insts from the module. - autoDiffWorkList.clear(); - processAllInsts([&](IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_ForwardDifferentiate: - // Only process now if the operand is a materialized function. - switch (inst->getOperand(0)->getOp()) - { - case kIROp_Func: - case kIROp_Specialize: - case kIROp_LookupWitness: - autoDiffWorkList.add(inst); - break; - default: - break; - } - break; - default: - break; - } - }); - - if (autoDiffWorkList.getCount() == 0) - break; - - // Process collected `ForwardDifferentiate` insts and replace them with placeholders for - // differentiated functions. - - transcriberStorage.followUpFunctionsToTranscribe.clear(); - - for (auto differentiateInst : autoDiffWorkList) - { - IRInst* baseInst = differentiateInst->getOperand(0); - if (as<IRForwardDifferentiate>(differentiateInst)) - { - if (auto existingDiffFunc = lookupJVPReference(baseInst)) - { - differentiateInst->replaceUsesWith(existingDiffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - IRBuilder subBuilder(*builder); - subBuilder.setInsertBefore(differentiateInst); - IRInst* diffFunc = transcriberStorage.transcribe(&subBuilder, baseInst); - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - changed = true; - } - } - // Actually synthesize the derivatives. - List<InstPair> followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe); - for (auto task : followUpWorkList) - { - auto diffFunc = as<IRFunc>(task.differential); - SLANG_ASSERT(diffFunc); - auto primalFunc = as<IRFunc>(task.primal); - SLANG_ASSERT(primalFunc); - - transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc); - } - - // Transcribing the function body really shouldn't produce more follow up function body work. - // However it may produce new `ForwardDifferentiate` instructions, which we collect and process - // in the next iteration. - SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0); - - } - return changed; - } - - // Checks decorators to see if the function should - // be differentiated (kIROp_ForwardDifferentiableDecoration) - // - bool isMarkedForForwardDifferentiation(IRInst* callable) - { - if (auto gen = as<IRGeneric>(callable)) - callable = findGenericReturnVal(gen); - return callable->findDecoration<IRForwardDifferentiableDecoration>() != nullptr; - } - - IRStringLit* getForwardDerivativeFuncName(IRInst* func) - { - IRBuilder builder(&sharedBuilderStorage); - builder.setInsertBefore(func); - - IRStringLit* name = nullptr; - if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) - { - name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice()); - } - else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) - { - name = builder.getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice()); - } - - return name; - } - - ForwardDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) : - InstPassBase(context->moduleInst->getModule()), - sink(sink), - transcriberStorage(context, context->sharedBuilder), - pairBuilderStorage(context), - autodiffContext(context) - { - transcriberStorage.sink = sink; - transcriberStorage.autoDiffSharedContext = context; - transcriberStorage.pairBuilder = &(pairBuilderStorage); - } - -protected: - // A transcriber object that handles the main job of - // processing instructions while maintaining state. - // - ForwardDerivativeTranscriber transcriberStorage; - - // Diagnostic object from the compile request for - // error messages. - DiagnosticSink* sink; - - // Shared context. - AutoDiffSharedContext* autodiffContext; - - // Builder for dealing with differential pair types. - DifferentialPairTypeBuilder pairBuilderStorage; - -}; - -// Set up context and call main process method. -// -bool processForwardDerivativeCalls( - AutoDiffSharedContext* autodiffContext, - DiagnosticSink* sink, - ForwardDerivativePassOptions const&) -{ - ForwardDerivativePass fwdPass(autodiffContext, sink); - bool changed = fwdPass.processModule(); - return changed; -} - } diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 678677625..22ebf9d95 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -1,117 +1,18 @@ // slang-ir-autodiff-fwd.h #pragma once -#include "slang-ir.h" -#include "slang-ir-insts.h" -#include "slang-compiler.h" +#include "slang-ir-autodiff-transcriber-base.h" namespace Slang { - template<typename P, typename D> - struct DiffInstPair - { - P primal; - D differential; - DiffInstPair() = default; - DiffInstPair(P primal, D differential) : primal(primal), differential(differential) - {} - HashCode getHashCode() const - { - Hasher hasher; - hasher << primal << differential; - return hasher.getResult(); - } - bool operator ==(const DiffInstPair& other) const - { - return primal == other.primal && differential == other.differential; - } - }; - - typedef DiffInstPair<IRInst*, IRInst*> InstPair; - - -struct ForwardDerivativeTranscriber +struct ForwardDiffTranscriber : AutoDiffTranscriberBase { - - // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent - // their differential values. - Dictionary<IRInst*, IRInst*> instMapD; - - // Set of insts currently being transcribed. Used to avoid infinite loops. - HashSet<IRInst*> instsInProgress; - - // Cloning environment to hold mapping from old to new copies for the primal - // instructions. - IRCloneEnv cloneEnv; - - // Diagnostic sink for error messages. - DiagnosticSink* sink; - - // Type conformance information. - AutoDiffSharedContext* autoDiffSharedContext; - - // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct - DifferentialPairTypeBuilder* pairBuilder; - - DifferentiableTypeConformanceContext differentiableTypeConformanceContext; - - List<InstPair> followUpFunctionsToTranscribe; - - SharedIRBuilder* sharedBuilder; - // Witness table that `DifferentialBottom:IDifferential`. - IRWitnessTable* differentialBottomWitness = nullptr; - Dictionary<InstPair, IRInst*> differentialPairTypes; - - ForwardDerivativeTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) - : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) + ForwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + : AutoDiffTranscriberBase(shared, inSharedBuilder, inSink) { - } - DiagnosticSink* getSink(); - - void mapDifferentialInst(IRInst* origInst, IRInst* diffInst); - - void mapPrimalInst(IRInst* origInst, IRInst* primalInst); - - IRInst* lookupDiffInst(IRInst* origInst); - - IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst); - - bool hasDifferentialInst(IRInst* origInst); - - bool shouldUseOriginalAsPrimal(IRInst* origInst); - - IRInst* lookupPrimalInst(IRInst* origInst); - - IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst); - - bool hasPrimalInst(IRInst* origInst); - - IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst); - - IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst); - - IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType); - - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. - IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType); - - IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness); - - IRType* getOrCreateDiffPairType(IRInst* primalType); - - IRType* differentiateType(IRBuilder* builder, IRType* origType); - - IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType); - - IRType* differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable); - - IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType); - - InstPair transcribeParam(IRBuilder* builder, IRParam* origParam); - // Returns "d<var-name>" to use as a name hint for variables and parameters. // If no primal name is available, returns a blank string. // @@ -132,8 +33,6 @@ struct ForwardDerivativeTranscriber InstPair transcribeStore(IRBuilder* builder, IRStore* origStore); - InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn); - // Since int/float literals are sometimes nested inside an IRConstructor // instruction, we check to make sure that the nested instr is a constant // and then return nullptr. Literals do not need to be differentiated. @@ -158,17 +57,6 @@ struct ForwardDerivativeTranscriber InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); - InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst); - - // In differential computation, the 'default' differential value is always zero. - // This is a consequence of differential computing being inherently linear. As a - // result, it's useful to have a method to generate zero literals of any (arithmetic) type. - // The current implementation requires that types are defined linearly. - // - IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType); - - InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock); - InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst); InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr); @@ -179,12 +67,13 @@ struct ForwardDerivativeTranscriber InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst); - InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst); - InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); - // Create an empty func to represent the transcribed func of `origFunc`. - InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc); + InstPair transcribeSingleOperandInst(IRBuilder* builder, IRInst* origInst); + + InstPair transcribeWrapExistential(IRBuilder* builder, IRInst* origInst); + + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) override; // Transcribe a function definition. InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc); @@ -192,19 +81,16 @@ struct ForwardDerivativeTranscriber // Transcribe a generic definition InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric); - IRInst* transcribe(IRBuilder* builder, IRInst* origInst); + // Create an empty func to represent the transcribed func of `origFunc`. + virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; - InstPair transcribeInst(IRBuilder* builder, IRInst* origInst); -}; + virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override; - struct ForwardDerivativePassOptions + virtual IROp getDifferentiableMethodDictionaryItemOp() override { - // Nothing for now.. - }; + return kIROp_ForwardDifferentiableMethodRequirementDictionaryItem; + } - bool processForwardDerivativeCalls( - AutoDiffSharedContext* autodiffContext, - DiagnosticSink* sink, - ForwardDerivativePassOptions const& options = ForwardDerivativePassOptions()); +}; } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 56002231a..cfee49eb1 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -7,83 +7,11 @@ #include "slang-ir-inst-pass-base.h" #include "slang-ir-autodiff-fwd.h" -#include "slang-ir-autodiff-propagate.h" -#include "slang-ir-autodiff-unzip.h" -#include "slang-ir-autodiff-transpose.h" namespace Slang { -struct BackwardDiffTranscriber -{ - // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent - // their differential values. - Dictionary<IRInst*, IRInst*> orginalToTranscribed; - - // Set of insts currently being transcribed. Used to avoid infinite loops. - HashSet<IRInst*> instsInProgress; - - // Cloning environment to hold mapping from old to new copies for the primal - // instructions. - IRCloneEnv cloneEnv; - - // Diagnostic sink for error messages. - DiagnosticSink* sink; - - // Type conformance information. - AutoDiffSharedContext* autoDiffSharedContext; - - // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct - DifferentialPairTypeBuilder* pairBuilder; - - DifferentiableTypeConformanceContext differentiableTypeConformanceContext; - - List<InstPair> followUpFunctionsToTranscribe; - - // Map that stores the upper gradient given an IRInst* - Dictionary<IRInst*, List<IRInst*>> upperGradients; - Dictionary<IRInst*, IRInst*> primalToDiffPair; - - SharedIRBuilder* sharedBuilder; - // Witness table that `DifferentialBottom:IDifferential`. - IRWitnessTable* differentialBottomWitness = nullptr; - Dictionary<InstPair, IRInst*> differentialPairTypes; - - // References to other passes that for reverse-mode transcription. - ForwardDerivativeTranscriber *fwdDiffTranscriber; - DiffTransposePass *diffTransposePass; - DiffPropagationPass *diffPropagationPass; - DiffUnzipPass *diffUnzipPass; - - // Allocate space for the passes. - ForwardDerivativeTranscriber fwdDiffTranscriberStorage; - DiffTransposePass diffTransposePassStorage; - DiffPropagationPass diffPropagationPassStorage; - DiffUnzipPass diffUnzipPassStorage; - - - BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) - : autoDiffSharedContext(shared) - , sink(inSink) - , differentiableTypeConformanceContext(shared) - , sharedBuilder(inSharedBuilder) - , fwdDiffTranscriberStorage(shared, inSharedBuilder) - , diffTransposePassStorage(shared) - , diffPropagationPassStorage(shared) - , diffUnzipPassStorage(shared) - , fwdDiffTranscriber(&fwdDiffTranscriberStorage) - , diffTransposePass(&diffTransposePassStorage) - , diffPropagationPass(&diffPropagationPassStorage) - , diffUnzipPass(&diffUnzipPassStorage) - { } - - DiagnosticSink* getSink() - { - SLANG_ASSERT(sink); - return sink; - } - - IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + IRFuncType* BackwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) { List<IRType*> newParameterTypes; IRType* diffReturnType; @@ -123,198 +51,46 @@ struct BackwardDiffTranscriber return builder->getFuncType(newParameterTypes, diffReturnType); } - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. - IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(inDiffPairType->parent); - auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); - SLANG_ASSERT(diffPairType); - - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - auto diffType = differentiateType(&builder, diffPairType->getValueType()); - auto differentialType = builder.getDifferentialPairType(diffType, nullptr); - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - return table; - } - - IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } - - IRType* getOrCreateDiffPairType(IRInst* primalType) + InstPair BackwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst) { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - auto witness = as<IRWitnessTable>( - differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } - - IRType* differentiateType(IRBuilder* builder, IRType* origType) - { - IRInst* diffType = nullptr; - if (!orginalToTranscribed.TryGetValue(origType, diffType)) - { - diffType = _differentiateTypeImpl(builder, origType); - orginalToTranscribed[origType] = diffType; - } - return (IRType*)diffType; - } - - IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) - { - if (auto ptrType = as<IRPtrTypeBase>(origType)) - return builder->getPtrType( - origType->getOp(), - differentiateType(builder, ptrType->getValueType())); - - // If there is an explicit primal version of this type in the local scope, load that - // otherwise use the original type. - // - IRInst* primalType = origType; - - // Special case certain compound types (PtrType, FuncType, etc..) - // otherwise try to lookup a differential definition for the given type. - // If one does not exist, then we assume it's not differentiable. - // - switch (primalType->getOp()) + switch (origInst->getOp()) { case kIROp_Param: - if (as<IRTypeType>(primalType->getDataType())) - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( - builder, - (IRType*)primalType)); - else if (as<IRWitnessTableType>(primalType->getDataType())) - return (IRType*)primalType; - - case kIROp_ArrayType: - { - auto primalArrayType = as<IRArrayType>(primalType); - if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) - return builder->getArrayType( - diffElementType, - primalArrayType->getElementCount()); - else - return nullptr; - } - - case kIROp_DifferentialPairType: - { - auto primalPairType = as<IRDifferentialPairType>(primalType); - return getOrCreateDiffPairType( - pairBuilder->getDiffTypeFromPairType(builder, primalPairType), - pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); - } - - case kIROp_FuncType: - return differentiateFunctionType(builder, as<IRFuncType>(primalType)); - - case kIROp_OutType: - if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) - return builder->getOutType(diffValueType); - else - return nullptr; - - case kIROp_InOutType: - if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) - return builder->getInOutType(diffValueType); - else - return nullptr; - - case kIROp_TupleType: - { - auto tupleType = as<IRTupleType>(primalType); - List<IRType*> diffTypeList; - // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) - diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); - } + return transcribeParam(builder, as<IRParam>(origInst)); - default: - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); - } - } + case kIROp_Return: + return transcribeReturn(builder, as<IRReturn>(origInst)); - IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) - { - // If this is a PtrType (out, inout, etc..), then create diff pair from - // value type and re-apply the appropropriate PtrType wrapper. - // - if (auto origPtrType = as<IRPtrTypeBase>(primalType)) - { - if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(primalType->getOp(), diffPairValueType); - else - return nullptr; - } - auto diffType = differentiateType(builder, primalType); - if (diffType) - return (IRType*)getOrCreateDiffPairType(primalType); - return nullptr; - } + case kIROp_LookupWitness: + return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); - InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) - { - auto primalDataType = origParam->getDataType(); - // Do not differentiate generic type (and witness table) parameters - if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) - { - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } + case kIROp_Specialize: + return transcribeSpecialize(builder, as<IRSpecialize>(origInst)); - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); + case kIROp_MakeVectorFromScalar: + case kIROp_MakeTuple: + case kIROp_FloatLit: + case kIROp_IntLit: + case kIROp_VoidLit: + case kIROp_ExtractExistentialWitnessTable: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_WrapExistential: + case kIROp_MakeExistential: + case kIROp_MakeExistentialWithRTTI: + return trascribeNonDiffInst(builder, origInst); - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - // If this is an `in/inout DifferentialPair<>` parameter, we can't produce - // its primal and diff parts right now because they would represent a reference - // to a pair field, which doesn't make sense since pair types are considered mutable. - // We encode the result as if the param is non-differentiable, and handle it - // with special care at load/store. - return InstPair(diffPairParam, nullptr); + case kIROp_StructKey: + return InstPair(origInst, nullptr); } - - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); + return InstPair(nullptr, nullptr); } // Returns "dp<var-name>" to use as a name hint for parameters. // If no primal name is available, returns a blank string. // - String makeDiffPairName(IRInst* origVar) + String BackwardDiffTranscriber::makeDiffPairName(IRInst* origVar) { if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) { @@ -330,7 +106,7 @@ struct BackwardDiffTranscriber // result, it's useful to have a method to generate zero literals of any (arithmetic) type. // The current implementation requires that types are defined linearly. // - IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) + IRInst* BackwardDiffTranscriber::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) { if (auto diffType = differentiateType(builder, primalType)) { @@ -364,7 +140,7 @@ struct BackwardDiffTranscriber } } - InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) + InstPair BackwardDiffTranscriber::transposeBlock(IRBuilder* builder, IRBlock* origBlock) { IRBuilder subBuilder(builder->getSharedBuilder()); subBuilder.setInsertLoc(builder->getInsertLoc()); @@ -401,10 +177,10 @@ struct BackwardDiffTranscriber { sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]); } - this->transcribeInstBackward(&subBuilder, child, sumGrad); + this->transposeInstBackward(&subBuilder, child, sumGrad); } else - this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst()); + this->transposeInstBackward(&subBuilder, child, upperGrads->getFirst()); } subBuilder.emitReturn(); @@ -412,9 +188,20 @@ struct BackwardDiffTranscriber return InstPair(diffBlock, diffBlock); } + static bool isMarkedForBackwardDifferentiation(IRInst* callable) + { + return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr; + } + // Create an empty func to represent the transcribed func of `origFunc`. - InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) + InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { + if (auto bwdDecor = origFunc->findDecoration<IRBackwardDerivativeDecoration>()) + return InstPair(origFunc, bwdDecor->getBackwardDerivativeFunc()); + + if (!isMarkedForBackwardDifferentiation(origFunc)) + return InstPair(nullptr, nullptr); + IRBuilder builder(inBuilder->getSharedBuilder()); builder.setInsertBefore(origFunc); @@ -450,13 +237,17 @@ struct BackwardDiffTranscriber cloneDecoration(dictDecor, diffFunc); } - auto result = InstPair(primalFunc, diffFunc); - followUpFunctionsToTranscribe.add(result); - return result; + FuncBodyTranscriptionTask task; + task.originalFunc = primalFunc; + task.resultFunc = diffFunc; + task.type = FuncBodyTranscriptionTaskType::Backward; + autoDiffSharedContext->followUpFunctionsToTranscribe.add(task); + + return InstPair(primalFunc, diffFunc); } // Puts parameters into their own block. - void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func) + void BackwardDiffTranscriber::makeParameterBlock(IRBuilder* inBuilder, IRFunc* func) { IRBuilder builder(inBuilder->getSharedBuilder()); @@ -491,7 +282,7 @@ struct BackwardDiffTranscriber builder.emitBranch(firstBlock); } - void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType) + void BackwardDiffTranscriber::cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType) { IRStructType* structType = as<IRStructType>(intermediateType); if (!structType) @@ -584,7 +375,7 @@ struct BackwardDiffTranscriber } // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) + InstPair BackwardDiffTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) { SLANG_ASSERT(primalFunc); SLANG_ASSERT(diffFunc); @@ -592,15 +383,17 @@ struct BackwardDiffTranscriber // TODO(sai): Fill in documentation. // Generate a temporary forward derivative function as an intermediate step. - IRFunc* fwdDiffFunc = as<IRFunc>(fwdDiffTranscriber->transcribeFuncHeader(builder, (IRFunc*)primalFunc).differential); + IRBuilder tempBuilder = *builder; + tempBuilder.setInsertBefore(diffFunc); + IRFunc* fwdDiffFunc = as<IRFunc>(fwdDiffTranscriber->transcribeFuncHeader(&tempBuilder, (IRFunc*)primalFunc).differential); SLANG_ASSERT(fwdDiffFunc); // Transcribe the body of the primal function into it's linear (fwd-diff) form. // TODO(sai): Handle the case when we already have a user-defined fwd-derivative function. - fwdDiffTranscriber->transcribeFunc(builder, primalFunc, as<IRFunc>(fwdDiffFunc)); + fwdDiffTranscriber->transcribeFunc(&tempBuilder, primalFunc, as<IRFunc>(fwdDiffFunc)); // Split first block into a paramter block. - this->makeParameterBlock(builder, as<IRFunc>(fwdDiffFunc)); + this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc)); // This steps adds a decoration to instructions that are computing the differential. // TODO: This is disabled for now because fwd-mode already adds differential decorations @@ -642,7 +435,7 @@ struct BackwardDiffTranscriber auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType); // Transpose the first block (parameter block) - transcribeParameterBlock(builder, diffFunc); + transposeParameterBlock(builder, diffFunc); builder->setInsertInto(diffFunc); @@ -663,7 +456,7 @@ struct BackwardDiffTranscriber return InstPair(primalFunc, diffFunc); } - void transcribeParameterBlock(IRBuilder* builder, IRFunc* diffFunc) + void BackwardDiffTranscriber::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc) { IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock(); @@ -715,7 +508,7 @@ struct BackwardDiffTranscriber builder->emitParam(dOutParamType); } - IRInst* copyParam(IRBuilder* builder, IRParam* origParam) + IRInst* BackwardDiffTranscriber::copyParam(IRBuilder* builder, IRParam* origParam) { auto primalDataType = origParam->getDataType(); @@ -737,11 +530,10 @@ struct BackwardDiffTranscriber return diffParam; } - return cloneInst(&cloneEnv, builder, origParam); } - InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith) + InstPair BackwardDiffTranscriber::copyBinaryArith(IRBuilder* builder, IRInst* origArith) { SLANG_ASSERT(origArith->getOperandCount() == 2); @@ -785,7 +577,7 @@ struct BackwardDiffTranscriber return InstPair(newInst, nullptr); } - IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) + IRInst* BackwardDiffTranscriber::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) { SLANG_ASSERT(origArith->getOperandCount() == 2); @@ -853,7 +645,7 @@ struct BackwardDiffTranscriber return nullptr; } - InstPair copyInst(IRBuilder* builder, IRInst* origInst) + InstPair BackwardDiffTranscriber::copyInst(IRBuilder* builder, IRInst* origInst) { // Handle common SSA-style operations switch (origInst->getOp()) @@ -878,7 +670,7 @@ struct BackwardDiffTranscriber return InstPair(nullptr, nullptr); } - IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) + IRInst* BackwardDiffTranscriber::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) { IRInOutType* inoutParam = as<IRInOutType>(param->getDataType()); auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType()); @@ -895,19 +687,19 @@ struct BackwardDiffTranscriber return store; } - IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) + IRInst* BackwardDiffTranscriber::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) { // Handle common SSA-style operations switch (origInst->getOp()) { case kIROp_Param: - return transcribeParamBackward(builder, as<IRParam>(origInst), grad); + return transposeParamBackward(builder, as<IRParam>(origInst), grad); case kIROp_Add: case kIROp_Mul: case kIROp_Sub: case kIROp_Div: - return transcribeBinaryArithBackward(builder, origInst, grad); + return transposeBinaryArithBackward(builder, origInst, grad); case kIROp_DifferentialPairGetPrimal: { @@ -935,191 +727,72 @@ struct BackwardDiffTranscriber return nullptr; } - -}; - -struct ReverseDerivativePass : public InstPassBase -{ - DiagnosticSink* getSink() - { - return sink; - } - - bool processModule() + InstPair BackwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) { + auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase()); + List<IRInst*> primalArgs; + for (UInt i = 0; i < origSpecialize->getArgCount(); i++) + { + primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i))); + } + auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType()); + auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst( + (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer()); - IRBuilder builderStorage(autodiffContext->sharedBuilder); - IRBuilder* builder = &builderStorage; + IRInst* diffBase = nullptr; + if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase)) + { + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + { + args.add(primalSpecialize->getArg(i)); + } + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } - // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by - // generating derivative code for the referenced function. + auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); + // Look for an IRBackwardDerivativeDecoration on the specialize inst. + // (Normally, this would be on the inner IRFunc, but in this case only the JVP func + // can be specialized, so we put a decoration on the IRSpecialize) // - bool modified = processReferencedFunctions(builder); - - return modified; - } - - IRInst* lookupJVPReference(IRInst* primalFunction) - { - if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) - return jvpDefinition->getForwardDerivativeFunc(); - return nullptr; - } - - // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate), - // then check that the referenced function is marked correctly for differentiation. - // - bool processReferencedFunctions(IRBuilder* builder) - { - bool changed = false; - - List<IRInst*> autoDiffWorkList; - - for (;;) + if (auto backDecor = origSpecialize->findDecoration<IRBackwardDerivativeDecoration>()) { - // Collect all `ForwardDifferentiate` insts from the module. - autoDiffWorkList.clear(); - processAllInsts([&](IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_BackwardDifferentiate: - // Only process now if the operand is a materialized function. - switch (inst->getOperand(0)->getOp()) - { - case kIROp_Func: - case kIROp_Specialize: - autoDiffWorkList.add(inst); - break; - default: - break; - } - break; - default: - break; - } - }); - - if (autoDiffWorkList.getCount() == 0) - break; - - // Process collected `ForwardDifferentiate` insts and replace them with placeholders for - // differentiated functions. + auto derivativeFunc = backDecor->getBackwardDerivativeFunc(); - backwardTranscriberStorage.followUpFunctionsToTranscribe.clear(); + // Make sure this isn't itself a specialize . + SLANG_RELEASE_ASSERT(!as<IRSpecialize>(derivativeFunc)); - for (auto differentiateInst : autoDiffWorkList) - { - IRInst* baseInst = differentiateInst->getOperand(0); - if (as<IRBackwardDifferentiate>(differentiateInst)) - { - if (isMarkedForBackwardDifferentiation(baseInst)) - { - if (as<IRFunc>(baseInst)) - { - IRInst* diffFunc = - backwardTranscriberStorage - .transcribeFuncHeader(builder, (IRFunc*)baseInst) - .differential; - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - changed = true; - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } - } - } - } - - auto followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe); - for (auto task : followUpWorkList) + return InstPair(primalSpecialize, derivativeFunc); + } + else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRBackwardDerivativeDecoration>()) + { + diffBase = derivativeDecoration->getBackwardDerivativeFunc(); + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) { - auto diffFunc = as<IRFunc>(task.differential); - SLANG_ASSERT(diffFunc); - auto primalFunc = as<IRFunc>(task.primal); - SLANG_ASSERT(primalFunc); - - backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc); + args.add(primalSpecialize->getArg(i)); } - - // Transcribing the function body really shouldn't produce more follow up function body work. - // However it may produce new `ForwardDifferentiate` instructions, which we collect and process - // in the next iteration. - SLANG_RELEASE_ASSERT(backwardTranscriberStorage.followUpFunctionsToTranscribe.getCount() == 0); - + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); } - return changed; - } - - // Checks decorators to see if the function should - // be differentiated (kIROp_ForwardDifferentiableDecoration) - // - bool isMarkedForBackwardDifferentiation(IRInst* callable) - { - return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr; - } - - IRStringLit* getBackwardDerivativeFuncName(IRInst* func) - { - IRBuilder builder(&sharedBuilderStorage); - builder.setInsertBefore(func); - - IRStringLit* name = nullptr; - if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) + else if (auto diffDecor = genericInnerVal->findDecoration<IRBackwardDifferentiableDecoration>()) { - name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice()); + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + { + args.add(primalSpecialize->getArg(i)); + } + diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); } - else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) + else { - name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice()); + return InstPair(primalSpecialize, nullptr); } - - return name; - } - - ReverseDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) : - InstPassBase(context->moduleInst->getModule()), - sink(sink), - backwardTranscriberStorage(context, context->sharedBuilder, sink), - autodiffContext(context), - pairBuilderStorage(context) - { - backwardTranscriberStorage.pairBuilder = &pairBuilderStorage; - backwardTranscriberStorage.fwdDiffTranscriberStorage.sink = sink; - backwardTranscriberStorage.fwdDiffTranscriberStorage.autoDiffSharedContext = context; - backwardTranscriberStorage.fwdDiffTranscriberStorage.pairBuilder = &(pairBuilderStorage); } - -protected: - // A transcriber object that handles the main job of - // processing instructions while maintaining state. - // - BackwardDiffTranscriber backwardTranscriberStorage; - - // Diagnostic object from the compile request for - // error messages. - DiagnosticSink* sink; - - // Builder for dealing with differential pair types. - DifferentialPairTypeBuilder pairBuilderStorage; - - // Autodiff Shared Context - AutoDiffSharedContext* autodiffContext; -}; - -bool processReverseDerivativeCalls( - AutoDiffSharedContext* autodiffContext, - DiagnosticSink* sink, - IRReverseDerivativePassOptions const&) -{ - ReverseDerivativePass revPass(autodiffContext, sink); - bool changed = revPass.processModule(); - return changed; -} - } diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index c3d31e2a9..f9ca6110c 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -7,6 +7,10 @@ #include "slang-ir-autodiff.h" #include "slang-ir-autodiff-fwd.h" +#include "slang-ir-autodiff-transcriber-base.h" +#include "slang-ir-autodiff-propagate.h" +#include "slang-ir-autodiff-unzip.h" +#include "slang-ir-autodiff-transpose.h" namespace Slang { @@ -16,10 +20,84 @@ struct IRReverseDerivativePassOptions // Nothing for now.. }; -bool processReverseDerivativeCalls( - AutoDiffSharedContext* autodiffContext, - DiagnosticSink* sink, - IRReverseDerivativePassOptions const& options = IRReverseDerivativePassOptions()); +struct BackwardDiffTranscriber : AutoDiffTranscriberBase +{ + // Map that stores the upper gradient given an IRInst* + Dictionary<IRInst*, List<IRInst*>> upperGradients; + Dictionary<IRInst*, IRInst*> primalToDiffPair; + Dictionary<IRInst*, IRInst*> orginalToTranscribed; + + // References to other passes that for reverse-mode transcription. + ForwardDiffTranscriber* fwdDiffTranscriber; + DiffTransposePass* diffTransposePass; + DiffPropagationPass* diffPropagationPass; + DiffUnzipPass* diffUnzipPass; + + // Allocate space for the passes. + DiffTransposePass diffTransposePassStorage; + DiffPropagationPass diffPropagationPassStorage; + DiffUnzipPass diffUnzipPassStorage; + + BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + : AutoDiffTranscriberBase(shared, inSharedBuilder, inSink) + , diffTransposePassStorage(shared) + , diffPropagationPassStorage(shared) + , diffUnzipPassStorage(shared) + , diffTransposePass(&diffTransposePassStorage) + , diffPropagationPass(&diffPropagationPassStorage) + , diffUnzipPass(&diffUnzipPassStorage) + { } + + // Returns "dp<var-name>" to use as a name hint for parameters. + // If no primal name is available, returns a blank string. + // + String makeDiffPairName(IRInst* origVar); + + // In differential computation, the 'default' differential value is always zero. + // This is a consequence of differential computing being inherently linear. As a + // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // The current implementation requires that types are defined linearly. + // + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType); + + InstPair transposeBlock(IRBuilder* builder, IRBlock* origBlock); + + // Puts parameters into their own block. + void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func); + + void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType); + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc); + + void transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc); + IRInst* copyParam(IRBuilder* builder, IRParam* origParam); + + InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith); + + IRInst* transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad); + + InstPair copyInst(IRBuilder* builder, IRInst* origInst); + + IRInst* transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad); + + IRInst* transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad); + + InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); + + // Create an empty func to represent the transcribed func of `origFunc`. + virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; + + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) override; + + virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override; + + virtual IROp getDifferentiableMethodDictionaryItemOp() override + { + return kIROp_ForwardDifferentiableMethodRequirementDictionaryItem; + } + +}; -}
\ No newline at end of file +} diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp new file mode 100644 index 000000000..da7762908 --- /dev/null +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -0,0 +1,847 @@ +// slang-ir-autodiff-trascriber-base.cpp +#include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-transcriber-base.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ + +DiagnosticSink* AutoDiffTranscriberBase::getSink() +{ + SLANG_ASSERT(sink); + return sink; +} + +String AutoDiffTranscriberBase::makeDiffPairName(IRInst* origVar) +{ + if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) + { + return ("dp" + String(namehintDecoration->getName())); + } + + return String(""); +} + +void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diffInst) +{ + if (hasDifferentialInst(origInst)) + { + if (lookupDiffInst(origInst) != diffInst) + { + SLANG_UNEXPECTED("Inconsistent differential mappings"); + } + } + else + { + instMapD.Add(origInst, diffInst); + } +} + +void AutoDiffTranscriberBase::mapPrimalInst(IRInst* origInst, IRInst* primalInst) +{ + if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst) + { + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::internalCompilerError, + "inconsistent primal instruction for original"); + } + else + { + cloneEnv.mapOldValToNew[origInst] = primalInst; + } +} + +IRInst* AutoDiffTranscriberBase::lookupDiffInst(IRInst* origInst) +{ + return instMapD[origInst]; +} + +IRInst* AutoDiffTranscriberBase::lookupDiffInst(IRInst* origInst, IRInst* defaultInst) +{ + if (auto lookupResult = instMapD.TryGetValue(origInst)) + return *lookupResult; + return defaultInst; +} + +bool AutoDiffTranscriberBase::hasDifferentialInst(IRInst* origInst) +{ + if (!origInst) + return false; + return instMapD.ContainsKey(origInst); +} + +bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* origInst) +{ + if (as<IRGlobalValueWithCode>(origInst)) + return true; + if (origInst->parent && origInst->parent->getOp() == kIROp_Module) + return true; + return false; +} + +IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* origInst) +{ + if (!origInst) + return nullptr; + if (shouldUseOriginalAsPrimal(origInst)) + return origInst; + return cloneEnv.mapOldValToNew[origInst]; +} + +IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) +{ + if (!origInst) + return nullptr; + return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; +} + +bool AutoDiffTranscriberBase::hasPrimalInst(IRInst* origInst) +{ + if (!origInst) + return false; + if (shouldUseOriginalAsPrimal(origInst)) + return true; + return cloneEnv.mapOldValToNew.ContainsKey(origInst); +} + +IRInst* AutoDiffTranscriberBase::findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst) +{ + if (!hasDifferentialInst(origInst)) + { + transcribe(builder, origInst); + SLANG_ASSERT(hasDifferentialInst(origInst)); + } + + return lookupDiffInst(origInst); +} + +IRInst* AutoDiffTranscriberBase::findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst) +{ + if (!origInst) + return origInst; + + if (shouldUseOriginalAsPrimal(origInst)) + return origInst; + + if (!hasPrimalInst(origInst)) + { + transcribe(builder, origInst); + SLANG_ASSERT(hasPrimalInst(origInst)); + } + + return lookupPrimalInst(origInst); +} + +IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst) +{ + IRInst* primal = lookupPrimalInst(inst, inst); + + if (primal == inst && + !isChildInstOf(builder->getInsertLoc().getParent(), inst->getParent())) + primal = cloneInst(&cloneEnv, builder, inst); + + return primal; +} + +// Get or construct `:IDifferentiable` conformance for a DifferentiablePair. +IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRInst* inDiffPairType) +{ + IRBuilder builder(sharedBuilder); + builder.setInsertInto(inDiffPairType->parent); + auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); + SLANG_ASSERT(diffPairType); + + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = differentiateType(&builder, diffPairType); + + // And place it in the synthesized witness table. + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + + // Record this in the context for future lookups + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + + return table; +} + +IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) +{ + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); +} + +IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType) +{ + IRBuilder builder(sharedBuilder); + if (!primalType->next) + builder.setInsertInto(primalType->parent); + else + builder.setInsertBefore(primalType->next); + + IRInst* witness = as<IRWitnessTable>( + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); + + if (!witness) + { + if (auto primalPairType = as<IRDifferentialPairType>(primalType)) + { + witness = getDifferentialPairWitness(primalPairType); + } + else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) + { + differentiateExtractExistentialType(&builder, extractExistential, witness); + } + } + + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); +} + +IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) +{ + return (IRType*)transcribe(builder, origType); +} + +IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRType* origType) +{ + if (auto ptrType = as<IRPtrTypeBase>(origType)) + return builder->getPtrType( + origType->getOp(), + differentiateType(builder, ptrType->getValueType())); + + // If there is an explicit primal version of this type in the local scope, load that + // otherwise use the original type. + // + IRInst* primalType = lookupPrimalInst(origType, origType); + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. + // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as<IRTypeType>(primalType->getDataType())) + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)); + else if (as<IRWitnessTableType>(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: + { + auto primalArrayType = as<IRArrayType>(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } + + case kIROp_DifferentialPairType: + { + auto primalPairType = as<IRDifferentialPairType>(primalType); + return getOrCreateDiffPairType( + pairBuilder->getDiffTypeFromPairType(builder, primalPairType), + pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_FuncType: + return differentiateFunctionType(builder, as<IRFuncType>(primalType)); + + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; + + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else + return nullptr; + + case kIROp_ExtractExistentialType: + { + IRInst* wt = nullptr; + return differentiateExtractExistentialType(builder, as<IRExtractExistentialType>(primalType), wt); + } + + case kIROp_TupleType: + { + auto tupleType = as<IRTupleType>(primalType); + List<IRType*> diffTypeList; + // TODO: what if we have type parameters here? + for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); + + return builder->getTupleType(diffTypeList); + } + + default: + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); + } +} + +// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. +static bool _findDifferentiableInterfaceLookupPathImpl( + HashSet<IRInst*>& processedTypes, + IRInterfaceType* idiffType, + IRInterfaceType* type, + List<IRInterfaceRequirementEntry*>& currentPath) +{ + if (processedTypes.Contains(type)) + return false; + processedTypes.Add(type); + + List<IRInterfaceRequirementEntry*> lookupKeyPath; + for (UInt i = 0; i < type->getOperandCount(); i++) + { + auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i)); + if (!entry) continue; + if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal())) + { + currentPath.add(entry); + if (wt->getConformanceType() == idiffType) + { + return true; + } + else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType())) + { + if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) + return true; + } + currentPath.removeLast(); + } + } + return false; +} + +List<IRInterfaceRequirementEntry*> AutoDiffTranscriberBase::findDifferentiableInterfaceLookupPath( + IRInterfaceType* idiffType, + IRInterfaceType* type) +{ + List<IRInterfaceRequirementEntry*> currentPath; + HashSet<IRInst*> processedTypes; + _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); + return currentPath; +} + +InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) +{ + IRInst* witnessTable = nullptr; + + IRInst* origBase = origInst->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); + + IRInst* primalResult = builder->emitIntrinsicInst( + primalType, + origInst->getOp(), + 1, + &primalBase); + + // Search for IDifferentiable conformance. + auto interfaceType = as<IRInterfaceType>( + unwrapAttributedType(cast<IRWitnessTableType>(origInst->getDataType())->getConformanceType())); + if (!interfaceType) + return InstPair(primalResult, nullptr); + List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( + autoDiffSharedContext->differentiableInterfaceType, interfaceType); + + if (lookupKeyPath.getCount()) + { + // `interfaceType` does conform to `IDifferentiable`. + witnessTable = primalResult; + for (auto node : lookupKeyPath) + { + witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey()); + } + return InstPair(primalResult, witnessTable); + } + return InstPair(primalResult, nullptr); +} + + +IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& outWitnessTable) +{ + outWitnessTable = nullptr; + + // Search for IDifferentiable conformance. + auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType())); + if (!interfaceType) + return nullptr; + List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( + autoDiffSharedContext->differentiableInterfaceType, interfaceType); + + if (lookupKeyPath.getCount()) + { + // `interfaceType` does conform to `IDifferentiable`. + outWitnessTable = builder->emitExtractExistentialWitnessTable(origType->getOperand(0)); + for (auto node : lookupKeyPath) + { + outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey()); + } + auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, autoDiffSharedContext->differentialAssocTypeStructKey); + return (IRType*)diffType; + } + return nullptr; +} + +IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* primalType) +{ + // If this is a PtrType (out, inout, etc..), then create diff pair from + // value type and re-apply the appropropriate PtrType wrapper. + // + if (auto origPtrType = as<IRPtrTypeBase>(primalType)) + { + if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) + return builder->getPtrType(primalType->getOp(), diffPairValueType); + else + return nullptr; + } + auto diffType = differentiateType(builder, primalType); + if (diffType) + return (IRType*)getOrCreateDiffPairType(primalType); + return nullptr; +} + +IRInst* AutoDiffTranscriberBase::findInterfaceRequirement(IRInterfaceType* type, IRInst* key) +{ + for (UInt i = 0; i < type->getOperandCount(); i++) + { + if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i))) + { + if (req->getRequirementKey() == key) + return req->getRequirementVal(); + } + } + return nullptr; +} + +InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* origParam) +{ + auto primalDataType = findOrTranscribePrimalInst(builder, origParam->getDataType()); + // Do not differentiate generic type (and witness table) parameters + if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) + { + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + // Is this param a phi node or a function parameter? + auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent()); + bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); + if (isFuncParam) + { + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffPairParam); + + if (auto pairType = as<IRDifferentialPairType>(diffPairType)) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) + { + auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType()); + + return InstPair( + builder->emitDifferentialPairAddressPrimal(diffPairParam), + builder->emitDifferentialPairAddressDifferential( + builder->getPtrType( + kIROp_PtrType, + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), + diffPairParam)); + } + } + + auto primalInst = cloneInst(&cloneEnv, builder, origParam); + if (auto primalParam = as<IRParam>(primalInst)) + { + SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); + primalParam->removeFromParent(); + builder->getInsertLoc().getBlock()->addParam(primalParam); + } + return InstPair(primalInst, nullptr); + } + else + { + auto primal = cloneInst(&cloneEnv, builder, origParam); + IRInst* diff = nullptr; + if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) + { + diff = builder->emitParam(diffType); + } + return InstPair(primal, diff); + } +} + +InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst) +{ + auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable()); + auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey()); + auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType()); + auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey); + + auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType())); + if (!interfaceType) + { + return InstPair(primal, nullptr); + } + auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); + if (!dict) + { + return InstPair(primal, nullptr); + } + + for (auto child : dict->getChildren()) + { + if (auto item = as<IRDifferentiableMethodRequirementDictionaryItem>(child)) + { + if (item->getOp() == getDifferentiableMethodDictionaryItemOp()) + { + if (item->getOperand(0) == lookupInst->getRequirementKey()) + { + auto diffKey = item->getOperand(1); + if (auto diffType = findInterfaceRequirement(interfaceType, diffKey)) + { + auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey); + return InstPair(primal, diff); + } + break; + } + } + } + } + return InstPair(primal, nullptr); +} + +// In differential computation, the 'default' differential value is always zero. +// This is a consequence of differential computing being inherently linear. As a +// result, it's useful to have a method to generate zero literals of any (arithmetic) type. +// The current implementation requires that types are defined linearly. +// +IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) +{ + if (auto diffType = differentiateType(builder, primalType)) + { + switch (diffType->getOp()) + { + case kIROp_DifferentialPairType: + return builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); + } + // Since primalType has a corresponding differential type, we can lookup the + // definition for zero(). + auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); + if (!zeroMethod) + { + // if the differential type itself comes from a witness lookup, we can just lookup the + // zero method from the same witness table. + if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType)) + { + auto wt = lookupInterface->getWitnessTable(); + zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); + } + } + SLANG_RELEASE_ASSERT(zeroMethod); + + auto emptyArgList = List<IRInst*>(); + + auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + builder->markInstAsDifferential(callInst, primalType); + + return callInst; + } + else + { + if (isScalarIntegerType(primalType)) + { + return builder->getIntValue(primalType, 0); + } + + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; + } +} + +InstPair AutoDiffTranscriberBase::transcribeBlock(IRBuilder* builder, IRBlock* origBlock) +{ + IRBuilder subBuilder(builder->getSharedBuilder()); + subBuilder.setInsertLoc(builder->getInsertLoc()); + + IRInst* diffBlock = subBuilder.emitBlock(); + + // Note: for blocks, we setup the mapping _before_ + // processing the children since we could encounter + // a lookup while processing the children. + // + mapPrimalInst(origBlock, diffBlock); + mapDifferentialInst(origBlock, diffBlock); + + subBuilder.setInsertInto(diffBlock); + + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + this->transcribe(&subBuilder, param); + + // Then, run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + this->transcribe(&subBuilder, child); + + return InstPair(diffBlock, diffBlock); +} + +InstPair AutoDiffTranscriberBase::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst) +{ + auto primal = cloneInst(&cloneEnv, builder, origInst); + return InstPair(primal, nullptr); +} + +InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn* origReturn) +{ + IRInst* origReturnVal = origReturn->getVal(); + + auto returnDataType = (IRType*)findOrTranscribePrimalInst(builder, origReturnVal->getDataType()); + if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal)) + { + // If the return value is itself a function, generic or a struct then this + // is likely to be a generic scope. In this case, we lookup the differential + // and return that. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + + // Neither of these should be nullptr. + SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); + IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal)); + builder->markInstAsMixedDifferential(diffReturn, nullptr); + + return InstPair(diffReturn, diffReturn); + } + else if (auto pairType = tryGetDiffPairType(builder, returnDataType)) + { + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + if (!diffReturnVal) + diffReturnVal = getDifferentialZeroOfType(builder, returnDataType); + + // If the pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffReturnVal); + + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); + builder->markInstAsMixedDifferential(diffPair, pairType); + + IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); + builder->markInstAsMixedDifferential(pairReturn, pairType); + + return InstPair(pairReturn, pairReturn); + } + else + { + // If the return type is not differentiable, emit the primal value only. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + + IRInst* primalReturn = builder->emitReturn(primalReturnVal); + return InstPair(primalReturn, nullptr); + + } +} + +// Transcribe a generic definition +InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) +{ + auto innerVal = findInnerMostGenericReturnVal(origGeneric); + if (auto innerFunc = as<IRFunc>(innerVal)) + { + differentiableTypeConformanceContext.setFunc(innerFunc); + } + else if (auto funcType = as<IRFuncType>(innerVal)) + { + } + else + { + return InstPair(origGeneric, nullptr); + } + + IRGeneric* primalGeneric = origGeneric; + + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origGeneric); + + auto diffGeneric = builder.emitGeneric(); + + // Process type of generic. If the generic is a function, then it's type will also be a + // generic and this logic will transcribe that generic first before continuing with the + // function itself. + // + auto primalType = primalGeneric->getFullType(); + + IRType* diffType = nullptr; + if (primalType) + { + diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType); + } + + diffGeneric->setFullType(diffType); + + // Transcribe children from origFunc into diffFunc. + builder.setInsertInto(diffGeneric); + for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(&builder, block); + + return InstPair(primalGeneric, diffGeneric); +} + +IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) +{ + // If a differential intstruction is already mapped for + // this original inst, return that. + // + if (auto diffInst = lookupDiffInst(origInst, nullptr)) + { + SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. + return diffInst; + } + + // Otherwise, dispatch to the appropriate method + // depending on the op-code. + // + instsInProgress.Add(origInst); + + InstPair pair = transcribeInst(builder, origInst); + + instsInProgress.Remove(origInst); + + if (auto primalInst = pair.primal) + { + mapPrimalInst(origInst, pair.primal); + mapDifferentialInst(origInst, pair.differential); + if (pair.differential) + { + switch (pair.differential->getOp()) + { + case kIROp_Func: + case kIROp_Generic: + case kIROp_Block: + // Don't generate again for these. + // Functions already have their names generated in `transcribeFuncHeader`. + break; + default: + // Generate name hint for the inst. + if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) + { + StringBuilder sb; + sb << "s_diff_" << primalNameHint->getName(); + builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); + } + + // Tag the differential inst using a decoration (if it doesn't have one) + if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() && + !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>()) + { + // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential + // instead. + // + builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType())); + } + + break; + } + } + return pair.differential; + } + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::internalCompilerError, + "failed to transcibe instruction"); + return nullptr; +} + +InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* origInst) +{ + // Handle instructions with children + switch (origInst->getOp()) + { + case kIROp_Func: + return transcribeFuncHeader(builder, as<IRFunc>(origInst)); + + case kIROp_Block: + return transcribeBlock(builder, as<IRBlock>(origInst)); + + case kIROp_Generic: + return transcribeGeneric(builder, as<IRGeneric>(origInst)); + } + + auto result = transcribeInstImpl(builder, origInst); + + if (result.primal == nullptr && result.differential == nullptr) + { + if (auto origType = as<IRType>(origInst)) + { + // If this is a generic type, transcibe the parent + // generic and derive the type from the transcribed generic's + // return value. + // + if (as<IRGeneric>(origType->getParent()->getParent()) && + findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType && + !instsInProgress.Contains(origType->getParent()->getParent())) + { + auto origGenericType = origType->getParent()->getParent(); + auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType); + auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType)); + result = InstPair( + origGenericType, + innerDiffGenericType + ); + } + else + { + auto diffType = _differentiateTypeImpl(builder, origType); + IRInst* primal = maybeCloneForPrimalInst(builder, origType); + result = InstPair(primal, diffType); + } + } + } + + if (result.primal == nullptr && result.differential == nullptr) + { + // If we reach this statement, the instruction type is likely unhandled. + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::unimplemented, + "this instruction cannot be differentiated"); + } + + return result; +} + +} diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h new file mode 100644 index 000000000..8e4b7a901 --- /dev/null +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -0,0 +1,129 @@ +// slang-ir-autodiff-transcriber-base.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" +#include "slang-ir-autodiff.h" + +namespace Slang +{ + +struct AutoDiffTranscriberBase +{ + // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent + // their differential values. + Dictionary<IRInst*, IRInst*> instMapD; + + // Set of insts currently being transcribed. Used to avoid infinite loops. + HashSet<IRInst*> instsInProgress; + + // Cloning environment to hold mapping from old to new copies for the primal + // instructions. + IRCloneEnv cloneEnv; + + // Diagnostic sink for error messages. + DiagnosticSink* sink; + + // Type conformance information. + AutoDiffSharedContext* autoDiffSharedContext; + + // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct + DifferentialPairTypeBuilder* pairBuilder; + + DifferentiableTypeConformanceContext differentiableTypeConformanceContext; + + SharedIRBuilder* sharedBuilder; + + AutoDiffTranscriberBase(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + : autoDiffSharedContext(shared) + , differentiableTypeConformanceContext(shared) + , sharedBuilder(inSharedBuilder) + , sink(inSink) + { + + } + + DiagnosticSink* getSink(); + + // Returns "dp<var-name>" to use as a name hint for parameters. + // If no primal name is available, returns a blank string. + // + String makeDiffPairName(IRInst* origVar); + + void mapDifferentialInst(IRInst* origInst, IRInst* diffInst); + + void mapPrimalInst(IRInst* origInst, IRInst* primalInst); + + IRInst* lookupDiffInst(IRInst* origInst); + + IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst); + + bool hasDifferentialInst(IRInst* origInst); + + bool shouldUseOriginalAsPrimal(IRInst* origInst); + + IRInst* lookupPrimalInst(IRInst* origInst); + + IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst); + + bool hasPrimalInst(IRInst* origInst); + + IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst); + + IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst); + + IRInst* maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst); + + List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath( + IRInterfaceType* idiffType, IRInterfaceType* type); + + InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst); + + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. + IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType); + + IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness); + + IRType* getOrCreateDiffPairType(IRInst* primalType); + + IRType* differentiateType(IRBuilder* builder, IRType* origType); + + IRType* differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable); + + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType); + + IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key); + + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType); + + InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst); + + InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn); + + InstPair transcribeParam(IRBuilder* builder, IRParam* origParam); + + InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst); + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock); + + // Transcribe a generic definition + InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric); + + IRInst* transcribe(IRBuilder* builder, IRInst* origInst); + + InstPair transcribeInst(IRBuilder* builder, IRInst* origInst); + + IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType); + + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) = 0; + + // Create an empty func to represent the transcribed func of `origFunc`. + virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) = 0; + + virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) = 0; + + virtual IROp getDifferentiableMethodDictionaryItemOp() = 0; +}; + +} diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 8dfedcb94..546d5a6ec 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -64,35 +64,6 @@ struct ExtractPrimalFuncContext return intermediateType; } - // Specialize `genericToSpecialize` with the generic parameters defined in `userGeneric`. - // For example: - // ``` - // int f<T>(T a); - // ``` - // will be extended into - // ``` - // struct IntermediateFor_f<T> { T t0; } - // int f_primal<T>(T a, IntermediateFor_f<T> imm); - // ``` - // Given a user generic `f_primal<T>` and a used value parameterized on the same set of generic parameters - // `IntermediateFor_f`, `genericToSpecialize` constructs `IntermediateFor_f<T>` (using the parameter list - // from user generic). - // - IRInst* specializeWithGeneric( - IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric) - { - List<IRInst*> genArgs; - for (auto param : userGeneric->getFirstBlock()->getParams()) - { - genArgs.add(param); - } - return builder.emitSpecializeInst( - builder.getTypeKind(), - genericToSpecialize, - (UInt)genArgs.getCount(), - genArgs.getBuffer()); - } - IRInst* generatePrimalFuncType( IRGlobalValueWithCode* destFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType) { @@ -505,8 +476,8 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc( { innerFunc = as<IRFunc>(findGenericReturnVal(genFunc)); builder.setInsertBefore(innerFunc); - specializedIntermediateType = context.specializeWithGeneric(builder, intermediateType, genFunc); - specializedPrimalFunc = context.specializeWithGeneric(builder, primalFunc, genFunc); + specializedIntermediateType = specializeWithGeneric(builder, intermediateType, genFunc); + specializedPrimalFunc = specializeWithGeneric(builder, primalFunc, genFunc); } SLANG_RELEASE_ASSERT(innerFunc); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 3d42f2922..f0ec1542e 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -5,9 +5,7 @@ namespace Slang { - -// TODO: Put into a nameless namespace. -IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +static IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) { if (auto witnessTable = as<IRWitnessTable>(witness)) { @@ -41,6 +39,13 @@ bool isNoDiffType(IRType* paramType) return false; } +IRInst* lookupForwardDerivativeReference(IRInst* primalFunction) +{ + if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) + return jvpDefinition->getForwardDerivativeFunc(); + return nullptr; +} + IRStructField* DifferentialPairTypeBuilder::findField(IRInst* type, IRStructKey* key) { if (auto irStructType = as<IRStructType>(type)) @@ -277,7 +282,6 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( return result; } - AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst) : moduleInst(inModuleInst) { @@ -331,8 +335,6 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde return nullptr; } - - void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; @@ -385,7 +387,6 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() } } - void stripAutoDiffDecorationsFromChildren(IRInst* parent) { for (auto inst : parent->getChildren()) @@ -398,6 +399,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_ForwardDerivativeDecoration: case kIROp_DerivativeMemberDecoration: case kIROp_DifferentiableTypeDictionaryDecoration: + case kIROp_DifferentialInstDecoration: + case kIROp_MixedDifferentialInstDecoration: decor->removeAndDeallocate(); break; default: @@ -448,6 +451,187 @@ void stripNoDiffTypeAttribute(IRModule* module) pass.processModule(); } +struct AutoDiffPass : public InstPassBase +{ + DiagnosticSink* getSink() + { + return sink; + } + + bool processModule() + { + // TODO(sai): Move this call. + forwardTranscriber.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); + + IRBuilder builderStorage(this->autodiffContext->sharedBuilder); + IRBuilder* builder = &builderStorage; + + // Process all ForwardDifferentiate and BackwardDifferentiate instructions by + // generating derivative code for the referenced function. + // + bool modified = processReferencedFunctions(builder); + + return modified; + } + + // Process all differentiate calls, and recursively generate code for forward and backward + // derivative functions. + // + bool processReferencedFunctions(IRBuilder* builder) + { + bool hasChanges = false; + for (;;) + { + bool changed = false; + List<IRInst*> autoDiffWorkList; + // Collect all `ForwardDifferentiate` insts from the module. + autoDiffWorkList.clear(); + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + // Only process now if the operand is a materialized function. + switch (inst->getOperand(0)->getOp()) + { + case kIROp_Func: + case kIROp_Specialize: + case kIROp_LookupWitness: + autoDiffWorkList.add(inst); + break; + default: + break; + } + break; + default: + break; + } + }); + + // Process collected differentiate insts and replace them with placeholders for + // differentiated functions. + + for (auto differentiateInst : autoDiffWorkList) + { + if (auto diffInst = as<IRForwardDifferentiate>(differentiateInst)) + { + IRBuilder subBuilder(*builder); + subBuilder.setInsertBefore(differentiateInst); + if (auto diffFunc = forwardTranscriber.transcribe(&subBuilder, diffInst->getBaseFn())) + { + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + changed = true; + } + } + else if (auto backDiffInst = as<IRBackwardDifferentiate>(differentiateInst)) + { + auto baseInst = backDiffInst->getBaseFn(); + if (auto diffFunc = backwardTranscriber.transcribe(builder, (IRFunc*)baseInst)) + { + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + changed = true; + } + } + } + + // Run transcription logic to generate the body of forward/backward derivatives functions. + // While doing so, we may discover new functions to differentiate, so we keep running until + // the worklist goes dry. + while (autodiffContext->followUpFunctionsToTranscribe.getCount() != 0) + { + changed = true; + auto followUpWorkList = _Move(autodiffContext->followUpFunctionsToTranscribe); + for (auto task : followUpWorkList) + { + auto diffFunc = as<IRFunc>(task.resultFunc); + SLANG_ASSERT(diffFunc); + auto primalFunc = as<IRFunc>(task.originalFunc); + SLANG_ASSERT(primalFunc); + switch (task.type) + { + case FuncBodyTranscriptionTaskType::Forward: + forwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc); + break; + case FuncBodyTranscriptionTaskType::Backward: + backwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc); + break; + default: + break; + } + } + } + if (!changed) + break; + hasChanges |= changed; + } + return hasChanges; + } + + IRStringLit* getDerivativeFuncName(IRInst* func, const char* postFix) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(func); + + IRStringLit* name = nullptr; + if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) + { + name = builder.getStringValue((String(linkageDecoration->getMangledName()) + postFix).getUnownedSlice()); + } + else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) + { + name = builder.getStringValue((String(namehintDecoration->getName()) + postFix).getUnownedSlice()); + } + + return name; + } + + IRStringLit* getForwardDerivativeFuncName(IRInst* func) + { + return getDerivativeFuncName(func, "_fwd_diff"); + } + + IRStringLit* getBackwardDerivativeFuncName(IRInst* func) + { + return getDerivativeFuncName(func, "_bwd_diff"); + } + + AutoDiffPass(AutoDiffSharedContext* context, DiagnosticSink* sink) : + InstPassBase(context->moduleInst->getModule()), + sink(sink), + forwardTranscriber(context, context->sharedBuilder, sink), + backwardTranscriber(context, context->sharedBuilder, sink), + pairBuilderStorage(context), + autodiffContext(context) + { + forwardTranscriber.pairBuilder = &pairBuilderStorage; + backwardTranscriber.pairBuilder = &pairBuilderStorage; + backwardTranscriber.fwdDiffTranscriber = &forwardTranscriber; + } + +protected: + // A transcriber object that handles the main job of + // processing instructions while maintaining state. + // + ForwardDiffTranscriber forwardTranscriber; + + BackwardDiffTranscriber backwardTranscriber; + + // Diagnostic object from the compile request for + // error messages. + DiagnosticSink* sink; + + // Shared context. + AutoDiffSharedContext* autodiffContext; + + // Builder for dealing with differential pair types. + DifferentialPairTypeBuilder pairBuilderStorage; + +}; + bool processAutodiffCalls( IRModule* module, DiagnosticSink* sink, @@ -468,11 +652,9 @@ bool processAutodiffCalls( autodiffContext.sharedBuilder = &sharedBuilder; - // Process forward derivative calls. - modified |= processForwardDerivativeCalls(&autodiffContext, sink); + AutoDiffPass pass(&autodiffContext, sink); - // Process reverse derivative calls. - modified |= processReverseDerivativeCalls(&autodiffContext, sink); + modified |= pass.processModule(); return modified; } @@ -505,5 +687,4 @@ bool finalizeAutoDiffPass(IRModule* module) return false; } - } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 25cbe16f4..e0508cef7 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -13,6 +13,39 @@ namespace Slang { +template<typename P, typename D> +struct DiffInstPair +{ + P primal; + D differential; + DiffInstPair() = default; + DiffInstPair(P primal, D differential) : primal(primal), differential(differential) + {} + HashCode getHashCode() const + { + Hasher hasher; + hasher << primal << differential; + return hasher.getResult(); + } + bool operator ==(const DiffInstPair& other) const + { + return primal == other.primal && differential == other.differential; + } +}; + +typedef DiffInstPair<IRInst*, IRInst*> InstPair; + +enum class FuncBodyTranscriptionTaskType +{ + Forward, Backward, Primal +}; + +struct FuncBodyTranscriptionTask +{ + FuncBodyTranscriptionTaskType type; + IRFunc* originalFunc; + IRFunc* resultFunc; +}; struct AutoDiffSharedContext { @@ -58,6 +91,7 @@ struct AutoDiffSharedContext // bool isInterfaceAvailable = false; + List<FuncBodyTranscriptionTask> followUpFunctionsToTranscribe; AutoDiffSharedContext(IRModuleInst* inModuleInst); @@ -195,10 +229,10 @@ struct DifferentialPairTypeBuilder void stripAutoDiffDecorations(IRModule* module); -IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); - bool isNoDiffType(IRType* paramType); +IRInst* lookupForwardDerivativeReference(IRInst* primalFunction); + struct IRAutodiffPassOptions { // Nothing for now... diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 5c4590abe..81b5d636a 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -129,4 +129,18 @@ IROp getTypeStyle(BaseType op) } } +IRInst* specializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric) +{ + List<IRInst*> genArgs; + for (auto param : userGeneric->getFirstBlock()->getParams()) + { + genArgs.add(param); + } + return builder.emitSpecializeInst( + builder.getTypeKind(), + genericToSpecialize, + (UInt)genArgs.getCount(), + genArgs.getBuffer()); +} + } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 385d05b28..2087ee4a7 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -44,6 +44,30 @@ inline bool isChildInstOf(IRInst* inst, IRInst* parent) return false; } + // Specialize `genericToSpecialize` with the generic parameters defined in `userGeneric`. + // For example: + // ``` + // int f<T>(T a); + // ``` + // will be extended into + // ``` + // struct IntermediateFor_f<T> { T t0; } + // int f_primal<T>(T a, IntermediateFor_f<T> imm); + // ``` + // Given a user generic `f_primal<T>` and a used value parameterized on the same set of generic parameters + // `IntermediateFor_f`, `genericToSpecialize` constructs `IntermediateFor_f<T>` (using the parameter list + // from user generic). + // +IRInst* specializeWithGeneric( + IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric); + + +inline IRInst* unwrapAttributedType(IRInst* type) +{ + while (auto attrType = as<IRAttributedType>(type)) + type = attrType->getBaseType(); + return type; +} } #endif |
