diff options
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 835 |
1 files changed, 589 insertions, 246 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index d0bf8f347..8a4fe23d0 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -7,6 +7,7 @@ #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" #include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" // origX, primalX, diffX // origX -> primalX (cloneEnv) @@ -20,9 +21,19 @@ struct Pair { P primal; D differential; - + Pair() = default; Pair(P primal, D differential) : primal(primal), differential(differential) {} + HashCode getHashCode() const + { + Hasher hasher; + hasher << primal << differential; + return hasher.getResult(); + } + bool operator ==(const Pair& other) const + { + return primal == other.primal && differential == other.differential; + } }; typedef Pair<IRInst*, IRInst*> InstPair; @@ -43,6 +54,11 @@ struct AutoDiffSharedContext // IRStructKey* differentialAssocTypeStructKey = nullptr; + // The struct key for the witness that `Differential` associated type conforms to + // `IDifferential`. + IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; + + // The struct key for the 'zero()' associated type // defined inside IDifferential. We use this to lookup the // implementation of zero() for a given type. @@ -54,6 +70,9 @@ struct AutoDiffSharedContext // implementation of add() for a given type. // IRStructKey* addMethodStructKey = nullptr; + + IRStructKey* mulMethodStructKey = nullptr; + // Modules that don't use differentiable types // won't have the IDifferentiable interface type available. @@ -69,8 +88,10 @@ struct AutoDiffSharedContext if (differentiableInterfaceType) { differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); zeroMethodStructKey = findZeroMethodStructKey(); addMethodStructKey = findAddMethodStructKey(); + mulMethodStructKey = findMulMethodStructKey(); if (differentialAssocTypeStructKey) isInterfaceAvailable = true; @@ -103,22 +124,32 @@ struct AutoDiffSharedContext return getIDifferentiableStructKeyAtIndex(0); } - IRStructKey* findZeroMethodStructKey() + IRStructKey* findDifferentialTypeWitnessStructKey() { return getIDifferentiableStructKeyAtIndex(1); } - IRStructKey* findAddMethodStructKey() + IRStructKey* findZeroMethodStructKey() { return getIDifferentiableStructKeyAtIndex(2); } + IRStructKey* findAddMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(3); + } + + IRStructKey* findMulMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(4); + } + IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index) { if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) { - // Assume for now that IDifferentiable has exactly four fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); + // Assume for now that IDifferentiable has exactly five fields. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) return as<IRStructKey>(entry->getRequirementKey()); else @@ -300,7 +331,16 @@ struct DifferentialPairTypeBuilder IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) { - if (auto basePairStructType = as<IRStructType>(baseInst->getDataType())) + auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); + if (baseTypeInfo.isTrivial) + { + if (key == globalPrimalKey) + return baseInst; + else + return builder->getDifferentialBottom(); + } + + if (auto basePairStructType = as<IRStructType>(baseTypeInfo.loweredType)) { return as<IRFieldExtract>(builder->emitFieldExtract( findField(basePairStructType, key)->getFieldType(), @@ -308,7 +348,7 @@ struct DifferentialPairTypeBuilder key )); } - else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType())) + else if (auto ptrType = as<IRPtrTypeBase>(baseTypeInfo.loweredType)) { if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType())) { @@ -334,7 +374,7 @@ struct DifferentialPairTypeBuilder key)); } } - else if (auto specializedType = as<IRSpecialize>(baseInst->getDataType())) + else if (auto specializedType = as<IRSpecialize>(baseTypeInfo.loweredType)) { // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's // type, emit the specialization type. @@ -420,25 +460,64 @@ struct DifferentialPairTypeBuilder { SLANG_ASSERT(!as<IRParam>(origBaseType)); SLANG_ASSERT(diffType); - auto pairStructType = builder->createStructType(); - builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); - builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType); + if (diffType->getOp() != kIROp_DifferentialBottomType) + { + auto pairStructType = builder->createStructType(); + builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); + builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType); + return pairStructType; + } + return origBaseType; + } - return pairStructType; + struct LoweredPairTypeInfo + { + IRInst* loweredType; + bool isTrivial; + }; + + IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type) + { + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); } - IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType) + IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type) { - if (pairTypeCache.ContainsKey(origBaseType)) - return pairTypeCache[origBaseType]; + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); + } - auto pairType = _createDiffPairType(builder, origBaseType, diffType); - pairTypeCache.Add(origBaseType, pairType); + LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType) + { + LoweredPairTypeInfo result = {}; + + if (pairTypeCache.TryGetValue(originalPairType, result)) + return result; + auto pairType = as<IRDifferentialPairType>(originalPairType); + if (!pairType) + { + result.isTrivial = true; + result.loweredType = originalPairType; + return result; + } + auto primalType = pairType->getValueType(); + if (as<IRParam>(primalType)) + { + result.isTrivial = false; + result.loweredType = nullptr; + return result; + } + + auto diffType = getDiffTypeFromPairType(builder, pairType); + result.loweredType = _createDiffPairType(builder, pairType->getValueType(), (IRType*)diffType); + result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType); + pairTypeCache.Add(originalPairType, result); - return pairType; + return result; } - Dictionary<IRInst*, IRInst*> pairTypeCache; + Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache; IRStructKey* globalPrimalKey = nullptr; @@ -447,6 +526,8 @@ struct DifferentialPairTypeBuilder IRInst* genericDiffPairType = nullptr; List<IRInst*> generatedTypeList; + + AutoDiffSharedContext* sharedContext = nullptr; }; struct JVPTranscriber @@ -474,8 +555,15 @@ struct JVPTranscriber DifferentiableTypeConformanceContext differentiableTypeConformanceContext; - JVPTranscriber(AutoDiffSharedContext* shared) - : differentiableTypeConformanceContext(shared) + List<InstPair> followUpFunctionsToTranscribe; + + SharedIRBuilder* sharedBuilder; + // Witness table that `DifferentialBottom:IDifferential`. + IRWitnessTable* differentialBottomWitness = nullptr; + Dictionary<InstPair, IRInst*> differentialPairTypes; + + JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) + : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) {} DiagnosticSink* getSink() @@ -592,8 +680,75 @@ struct JVPTranscriber return builder->getFuncType(newParameterTypes, diffReturnType); } + IRWitnessTable* getDifferentialBottomWitness() + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); + auto result = + as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( + builder.getDifferentialBottomType())); + SLANG_ASSERT(result); + return result; + } + + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. + IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(inDiffPairType->parent); + auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); + SLANG_ASSERT(diffPairType); + auto result = + as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( + builder.getDifferentialBottomType())); + if (result) + return result; + + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + auto diffType = differentiateType(&builder, diffPairType->getValueType()); + auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness()); + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + return table; + } + + IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* getOrCreateDiffPairType(IRInst* primalType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + auto witness = as<IRWitnessTable>( + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); + if (!witness) + witness = getDifferentialBottomWitness(); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + IRType* differentiateType(IRBuilder* builder, IRType* origType) { + IRInst* diffType = nullptr; + if (!instMapD.TryGetValue(origType, diffType)) + { + diffType = _differentiateTypeImpl(builder, origType); + instMapD[origType] = diffType; + } + return (IRType*)diffType; + } + + IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) + { if (auto ptrType = as<IRPtrTypeBase>(origType)) return builder->getPtrType( origType->getOp(), @@ -628,6 +783,14 @@ struct JVPTranscriber else return nullptr; } + + case kIROp_DifferentialPairType: + { + auto primalPairType = as<IRDifferentialPairType>(primalType); + return getOrCreateDiffPairType( + pairBuilder->getDiffTypeFromPairType(builder, primalPairType), + pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + } case kIROp_FuncType: return differentiateFunctionType(builder, as<IRFuncType>(primalType)); @@ -660,7 +823,7 @@ struct JVPTranscriber return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); } } - + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) { // If this is a PtrType (out, inout, etc..), then create diff pair from @@ -675,7 +838,7 @@ struct JVPTranscriber } auto diffType = differentiateType(builder, primalType); if (diffType) - return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType, diffType); + return (IRType*)getOrCreateDiffPairType(primalType); return nullptr; } @@ -692,7 +855,7 @@ struct JVPTranscriber if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) { - IRParam* diffPairParam = builder->emitParam(diffPairType); + IRInst* diffPairParam = builder->emitParam(diffPairType); auto diffPairVarName = makeDiffPairName(origParam); if (diffPairVarName.getLength() > 0) @@ -700,9 +863,20 @@ struct JVPTranscriber SLANG_ASSERT(diffPairParam); - return InstPair( - pairBuilder->emitPrimalFieldAccess(builder, diffPairParam), - pairBuilder->emitDiffFieldAccess(builder, diffPairParam)); + if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + // If this is an `in/inout DifferentialPair<>` parameter, we can't produce + // its primal and diff parts right now because they would represent a reference + // to a pair field, which doesn't make sense since pair types are considered mutable. + // We encode the result as if the param is non-differentiable, and handle it + // with special care at load/store. + return InstPair(diffPairParam, nullptr); } @@ -826,30 +1000,52 @@ struct JVPTranscriber InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad) { auto origPtr = origLoad->getPtr(); - - auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + auto primalPtr = lookupPrimalInst(origPtr, nullptr); + auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType(); - IRInst* diffLoad = nullptr; + if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType)) + { + // Special case load from an `out` param, which will not have corresponding `diff` and + // `primal` insts yet. + auto load = builder->emitLoad(primalPtr); + auto primalElement = builder->emitDifferentialPairGetPrimal(load); + auto diffElement = builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); + return InstPair(primalElement, diffElement); + } + auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + IRInst* diffLoad = nullptr; if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) { // Default case, we're loading from a known differential inst. diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); - return InstPair(primalLoad, diffLoad); - } - return InstPair(primalLoad, nullptr); + } + return InstPair(primalLoad, diffLoad); } InstPair transcribeStore(IRBuilder* builder, IRStore* origStore) { IRInst* origStoreLocation = origStore->getPtr(); IRInst* origStoreVal = origStore->getVal(); - - auto primalStore = cloneInst(&cloneEnv, builder, origStore); - + auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr); auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); + auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr); auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); + if (!diffStoreLocation) + { + auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType()); + if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType())) + { + auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal); + auto store = builder->emitStore(primalStoreLocation, valToStore); + return InstPair(store, nullptr); + } + } + + auto primalStore = cloneInst(&cloneEnv, builder, origStore); + IRInst* diffStore = nullptr; // If the stored value has a differential version, @@ -1052,8 +1248,9 @@ struct JVPTranscriber if (diffReturnType->getOp() != kIROp_VoidType) { - IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst); - IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst); + IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst); + auto diffType = differentiateType(builder, origCall->getFullType()); + IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst); return InstPair(primalResultValue, diffResultValue); } else @@ -1174,14 +1371,16 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } - InstPair transcribeConst(IRBuilder*, IRInst* origInst) + InstPair transcribeConst(IRBuilder* builder, IRInst* origInst) { switch(origInst->getOp()) { case kIROp_FloatLit: + return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f)); case kIROp_VoidLit: + return InstPair(origInst, origInst); case kIROp_IntLit: - return InstPair(origInst, nullptr); + return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0)); } getSink()->diagnose( @@ -1245,6 +1444,14 @@ struct JVPTranscriber { if (auto diffType = differentiateType(builder, primalType)) { + switch (diffType->getOp()) + { + case kIROp_DifferentialPairType: + return builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); + } // Since primalType has a corresponding differential type, we can lookup the // definition for zero(). auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); @@ -1458,40 +1665,63 @@ struct JVPTranscriber return InstPair(diffLoop, diffLoop); } - // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* builder, IRFunc* origFunc) + InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) { - IRFunc* primalFunc = nullptr; + auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue()); + SLANG_ASSERT(primalVal); + auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue()); + SLANG_ASSERT(diffPrimalVal); + auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue()); + SLANG_ASSERT(primalDiffVal); + auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); + SLANG_ASSERT(diffDiffVal); - differentiableTypeConformanceContext.setFunc(origFunc); + auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal); + auto diffPair = builder->emitMakeDifferentialPair( + differentiateType(builder, origInst->getDataType()), + primalDiffVal, + diffDiffVal); + return InstPair(primalPair, diffPair); + } - auto oldLoc = builder->getInsertLoc(); + InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) + { + SLANG_ASSERT( + origInst->getOp() == kIROp_DifferentialPairGetDifferential || + origInst->getOp() == kIROp_DifferentialPairGetPrimal); - // If this is a top-level function, there is no need to clone it - // since it is visible in all the scopes. - // Otherwise, we need to clone it in case of generic scopes. - // - // TODO(sai): Is this the correct thing to do? Can a function cloned inside a - // generic scope but is not the return value of that generic, be used within - // that scope? Or do we have to call out to the original generic specialized with - // the current generic params? - // - bool isTopLevelFunc = (as<IRModuleInst>(origFunc->parent) != nullptr); - if (isTopLevelFunc) - { - builder->setInsertBefore(origFunc); - primalFunc = origFunc; - } + auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0)); + SLANG_ASSERT(primalVal); + + auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0)); + SLANG_ASSERT(diffVal); + + auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal); + + auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType()); + IRInst* diffResultType = nullptr; + if (origInst->getOp() == kIROp_DifferentialPairGetDifferential) + diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType); else - { - // TODO(sai): this might never be called, and it might never make sense - // to call it either. Potentially remove this. - primalFunc = as<IRFunc>( - cloneInst(&cloneEnv, builder, origFunc)); - } + diffResultType = diffValPairType->getValueType(); + auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal); + return InstPair(primalResult, diffResult); + } + + // Create an empty func to represent the transcribed func of `origFunc`. + InstPair transcribeFuncHeader(IRBuilder* builder, IRFunc* origFunc) + { + auto oldLoc = builder->getInsertLoc(); + + IRFunc* primalFunc = origFunc; + + differentiableTypeConformanceContext.setFunc(origFunc); + + builder->setInsertBefore(origFunc); + primalFunc = origFunc; auto diffFunc = builder->createFunc(); - + SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); IRType* diffFuncType = this->differentiateFunctionType( builder, @@ -1505,10 +1735,33 @@ struct JVPTranscriber newNameSb << "s_jvp_" << originalName; builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); } - + builder->addForwardDerivativeDecoration(origFunc, diffFunc); + + // Mark the generated derivative function itself as differentiable. + builder->addForwardDifferentiableDecoration(diffFunc); + + // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. + if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + { + cloneDecoration(dictDecor, diffFunc); + } + + // Reset builder position + builder->setInsertLoc(oldLoc); + auto result = InstPair(primalFunc, diffFunc); + followUpFunctionsToTranscribe.add(result); + return result; + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) + { + auto oldLoc = builder->getInsertLoc(); + + differentiableTypeConformanceContext.setFunc(primalFunc); // Transcribe children from origFunc into diffFunc builder->setInsertInto(diffFunc); - for (auto block = origFunc->getFirstBlock(); block; block = block->getNextBlock()) + for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) this->transcribe(builder, block); // Reset builder position @@ -1685,6 +1938,11 @@ struct JVPTranscriber case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); + case kIROp_MakeDifferentialPair: + return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst)); + case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferential: + return transcribeDifferentialPairGetElement(builder, origInst); } // If none of the cases have been hit, check if the instruction is a @@ -1722,7 +1980,7 @@ struct JVPTranscriber switch (origInst->getOp()) { case kIROp_Func: - return transcribeFunc(builder, as<IRFunc>(origInst)); + return transcribeFuncHeader(builder, as<IRFunc>(origInst)); case kIROp_Block: return transcribeBlock(builder, as<IRBlock>(origInst)); @@ -1741,45 +1999,7 @@ struct JVPTranscriber } }; -struct IRWorkQueue -{ - // Work list to hold the active set of insts whose children - // need to be looked at. - // - List<IRInst*> workList; - HashSet<IRInst*> workListSet; - - void push(IRInst* inst) - { - if(!inst) return; - if(workListSet.Contains(inst)) return; - - workList.add(inst); - workListSet.Add(inst); - } - - IRInst* pop() - { - if (workList.getCount() != 0) - { - IRInst* topItem = workList.getFirst(); - // TODO(Sai): Repeatedly calling removeAt() can be really slow. - // Consider a specialized data structure or using removeLast() - // - workList.removeAt(0); - workListSet.Remove(topItem); - return topItem; - } - return nullptr; - } - - IRInst* peek() - { - return workList.getFirst(); - } -}; - -struct JVPDerivativeContext +struct JVPDerivativeContext : public InstPassBase { DiagnosticSink* getSink() @@ -1795,6 +2015,7 @@ struct JVPDerivativeContext // SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; @@ -1809,8 +2030,12 @@ struct JVPDerivativeContext // IRDifferentialPairGetPrimal with 'primal' field access, and // IRMakeDifferentialPair with an IRMakeStruct. // + modified |= simplifyDifferentialBottomType(builder); + modified |= processPairTypes(builder, module->getModuleInst()); - + + modified |= eliminateDifferentialBottomType(builder); + return modified; } @@ -1826,121 +2051,92 @@ struct JVPDerivativeContext // bool processReferencedFunctions(IRBuilder* builder) { - IRWorkQueue* workQueue = &(workQueueStorage); + List<IRForwardDifferentiate*> autoDiffWorkList; - // Put the top-level inst into the queue. - workQueue->push(module->getModuleInst()); - - // Keep processing items until the queue is complete. - while (IRInst* workItem = workQueue->pop()) - { - for(auto child = workItem->getFirstChild(); child; child = child->getNextInst()) + for (;;) + { + // Collect all `ForwardDifferentiate` insts from the module. + autoDiffWorkList.clear(); + processInstsOfType<IRForwardDifferentiate>(kIROp_ForwardDifferentiate, [&](IRForwardDifferentiate* fwdDiffInst) { - // Either the child instruction has more children (func/block etc..) - // and we add it to the work list for further processing, or - // it's an ordinary inst in which case we check if it's a ForwardDifferentiate - // instruction. - // - if (child->getFirstChild() != nullptr) - workQueue->push(child); - - if (auto jvpDiffInst = as<IRForwardDifferentiate>(child)) - { - auto baseInst = jvpDiffInst->getBaseFn(); + autoDiffWorkList.add(fwdDiffInst); + }); - IRGlobalValueWithCode* baseFunction = nullptr; + if (autoDiffWorkList.getCount() == 0) + break; - if (auto specializeInst = as<IRSpecialize>(baseInst)) - { - // Certain specialize insts come with a derivative - // reference attached. Skip such instructions. - // - if (lookupJVPReference(specializeInst)) continue; - } - else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst)) + // Process collected `ForwardDifferentiate` insts and replace them with placeholders for + // differentiated functions. + transcriberStorage.followUpFunctionsToTranscribe.clear(); + + for (auto fwdDiffInst : autoDiffWorkList) + { + auto baseInst = fwdDiffInst->getBaseFn(); + if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst)) + { + if (auto existingDiffFunc = lookupJVPReference(baseFunction)) { - baseFunction = globalValWithCode; + fwdDiffInst->replaceUsesWith(existingDiffFunc); + fwdDiffInst->removeAndDeallocate(); } - - SLANG_ASSERT(baseFunction); - - // If the JVP Reference already exists, no need to - // differentiate again. - // - if (lookupJVPReference(baseFunction)) continue; - - if (isMarkedForForwardDifferentiation(baseFunction)) + else if (isMarkedForForwardDifferentiation(baseFunction)) { if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) { - IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction); + IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction); SLANG_ASSERT(diffFunc); - builder->addForwardDerivativeDecoration(baseFunction, diffFunc); - workQueue->push(diffFunc); - } + fwdDiffInst->replaceUsesWith(diffFunc); + fwdDiffInst->removeAndDeallocate(); + } else { // TODO(Sai): This would probably be better with a more specific // error code. - getSink()->diagnose(jvpDiffInst->sourceLoc, + getSink()->diagnose(fwdDiffInst->sourceLoc, Diagnostics::internalCompilerError, "Unexpected instruction. Expected func or generic"); } } - else + else { // TODO(Sai): This would probably be better with a more specific // error code. - getSink()->diagnose(jvpDiffInst->sourceLoc, + getSink()->diagnose(fwdDiffInst->sourceLoc, Diagnostics::internalCompilerError, "Cannot differentiate functions not marked for differentiation"); } } } - } - - return true; - } - - IRInst* lowerPairType(IRBuilder* builder, IRType* type) - { - - if (auto pairType = as<IRDifferentialPairType>(type)) - { - builder->setInsertBefore(pairType); - - if (!as<IRType>(pairType->getValueType())) + // Actually synthesize the derivatives. + List<InstPair> followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe); + for (auto task : followUpWorkList) { - return nullptr; - } - auto witness = pairType->getWitness(); - auto diffType = _lookupWitness(builder, witness, autoDiffSharedContextStorage.differentialAssocTypeStructKey); - if (!diffType) - { - return nullptr; + auto diffFunc = as<IRFunc>(task.differential); + SLANG_ASSERT(diffFunc); + auto primalFunc = as<IRFunc>(task.primal); + SLANG_ASSERT(primalFunc); + + transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc); } - auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( - builder, - pairType->getValueType(), - (IRType*)(diffType)); - pairType->replaceUsesWith(diffPairStructType); - pairType->removeAndDeallocate(); + // Transcribing the function body really shouldn't produce more follow up function body work. + // However it may produce new `ForwardDifferentiate` instructions, which we collect and process + // in the next iteration. + SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0); - return diffPairStructType; - } - else if (auto loweredStructType = as<IRStructType>(type)) - { - // Already lowered to struct. - return loweredStructType; - } - else if (auto specializedStructType = as<IRSpecialize>(type)) - { - // Already lowered to specialized struct. - return specializedStructType; } - - return nullptr; + return true; + } + + IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr) + { + builder->setInsertBefore(pairType); + auto loweredPairTypeInfo = (&pairBuilderStorage)->lowerDiffPairType( + builder, + pairType); + if (isTrivial) + *isTrivial = loweredPairTypeInfo.isTrivial; + return loweredPairTypeInfo.loweredType; } IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) @@ -1948,19 +2144,24 @@ struct JVPDerivativeContext if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) { - if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType())) + bool isTrivial = false; + auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType()); + if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial)) { builder->setInsertBefore(makePairInst); - - List<IRInst*> operands; - operands.add(makePairInst->getPrimalValue()); - operands.add(makePairInst->getDifferentialValue()); - - auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands); - makePairInst->replaceUsesWith(makeStructInst); + IRInst* result = nullptr; + if (isTrivial) + { + result = makePairInst->getPrimalValue(); + } + else + { + IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() }; + result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); + } + makePairInst->replaceUsesWith(result); makePairInst->removeAndDeallocate(); - - return makeStructInst; + return result; } } @@ -1971,11 +2172,11 @@ struct JVPDerivativeContext { if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) { - if (lowerPairType(builder, getDiffInst->getBase()->getDataType())) + if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), nullptr)) { builder->setInsertBefore(getDiffInst); - - auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); + IRInst* diffFieldExtract = nullptr; + diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); getDiffInst->replaceUsesWith(diffFieldExtract); getDiffInst->removeAndDeallocate(); return diffFieldExtract; @@ -1983,14 +2184,14 @@ struct JVPDerivativeContext } else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) { - if (lowerPairType(builder, getPrimalInst->getBase()->getDataType())) + if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), nullptr)) { builder->setInsertBefore(getPrimalInst); - auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + IRInst* primalFieldExtract = nullptr; + primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); getPrimalInst->replaceUsesWith(primalFieldExtract); getPrimalInst->removeAndDeallocate(); - return primalFieldExtract; } } @@ -2001,40 +2202,195 @@ struct JVPDerivativeContext bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; + // Hoist all pair types to global scope when possible. + auto moduleInst = module->getModuleInst(); + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType) + { + if (originalPairType->parent != moduleInst) + { + originalPairType->removeFromParent(); + ShortList<IRInst*> operands; + for (UInt i = 0; i < originalPairType->getOperandCount(); i++) + { + operands.add(originalPairType->getOperand(i)); + } + auto newPairType = builder->findOrEmitHoistableInst( + originalPairType->getFullType(), + originalPairType->getOp(), + originalPairType->getOperandCount(), + operands.getArrayView().getBuffer()); + originalPairType->replaceUsesWith(newPairType); + originalPairType->removeAndDeallocate(); + } + }); - for (auto child = instWithChildren->getFirstChild(); child; ) - { - // Make sure the builder is at the right level. - builder->setInsertInto(instWithChildren); - - auto nextChild = child->getNextInst(); + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); - switch (child->getOp()) + processAllInsts([&](IRInst* inst) { - case kIROp_DifferentialPairType: - lowerPairType(builder, as<IRType>(child)); - break; - + // Make sure the builder is at the right level. + builder->setInsertInto(instWithChildren); + + switch (inst->getOp()) + { case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: - lowerPairAccess(builder, child); + lowerPairAccess(builder, inst); + modified = true; break; - + case kIROp_MakeDifferentialPair: - lowerMakePair(builder, child); + lowerMakePair(builder, inst); + modified = true; break; - + default: - if (child->getFirstChild()) - modified = processPairTypes(builder, child) | modified; - } + break; + } + }); - child = nextChild; + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) + { + if (auto loweredType = lowerPairType(builder, inst)) + { + inst->replaceUsesWith(loweredType); + inst->removeAndDeallocate(); + } + }); + return modified; + } + + bool simplifyDifferentialBottomType(IRBuilder* builder) + { + bool modified = false; + auto diffBottom = builder->getDifferentialBottom(); + + bool changed = true; + List<IRUse*> uses; + while (changed) + { + changed = false; + // Replace all insts whose type is `DifferentialBottomType` to `diffBottom`. + processAllInsts([&](IRInst* inst) + { + if (inst->getDataType() && inst->getDataType()->getOp() == kIROp_DifferentialBottomType) + { + if (inst != diffBottom) + { + inst->replaceUsesWith(diffBottom); + inst->removeAndDeallocate(); + modified = true; + } + } + }); + // Go through all uses of diffBottom and run simplification. + processAllInsts([&](IRInst* inst) + { + if (!inst->hasUses()) + return; + + builder->setInsertBefore(inst); + IRInst* valueToReplace = nullptr; + switch (inst->getOp()) + { + case kIROp_Store: + if (as<IRStore>(inst)->getVal() == diffBottom) + { + inst->removeAndDeallocate(); + changed = true; + } + return; + case kIROp_MakeDifferentialPair: + // Our simplification could lead to a situation where + // bottom is used to make a pair that has a non-bottom differential type, + // in this case we should use zero instead. + if (inst->getOperand(1) == diffBottom) + { + // Only apply if we are the second operand. + auto pairType = as<IRDifferentialPairType>(inst->getDataType()); + if (pairBuilderStorage.getDiffTypeFromPairType(builder, pairType)->getOp() != kIROp_DifferentialBottomType) + { + auto zero = transcriberStorage.getDifferentialZeroOfType(builder, pairType->getValueType()); + inst->setOperand(1, zero); + changed = true; + } + } + return; + case kIROp_DifferentialPairGetDifferential: + if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair) + { + valueToReplace = inst->getOperand(0)->getOperand(1); + } + break; + case kIROp_DifferentialPairGetPrimal: + if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair) + { + valueToReplace = inst->getOperand(0)->getOperand(0); + } + break; + case kIROp_Add: + if (inst->getOperand(0) == diffBottom) + { + valueToReplace = inst->getOperand(1); + } + else if (inst->getOperand(1) == diffBottom) + { + valueToReplace = inst->getOperand(0); + } + break; + case kIROp_Sub: + if (inst->getOperand(0) == diffBottom) + { + // If left is bottom, and right is not bottom, then we should return -right. + // However we can't possibly run into that case since both side of - operator + // must be at the same order of differentiation. + valueToReplace = diffBottom; + } + else if (inst->getOperand(1) == diffBottom) + { + valueToReplace = inst->getOperand(0); + } + break; + case kIROp_Mul: + case kIROp_Div: + if (inst->getOperand(0) == diffBottom) + { + valueToReplace = diffBottom; + } + else if (inst->getOperand(1) == diffBottom) + { + valueToReplace = diffBottom; + } + break; + default: + break; + } + if (valueToReplace) + { + inst->replaceUsesWith(valueToReplace); + changed = true; + } + }); + modified |= changed; } return modified; } + bool eliminateDifferentialBottomType(IRBuilder* builder) + { + simplifyDifferentialBottomType(builder); + + bool modified = false; + auto diffBottom = builder->getDifferentialBottom(); + auto diffBottomType = diffBottom->getDataType(); + diffBottom->replaceUsesWith(builder->getVoidValue()); + diffBottom->removeAndDeallocate(); + diffBottomType->replaceUsesWith(builder->getVoidType()); + + return modified; + } + // Checks decorators to see if the function should // be differentiated (kIROp_ForwardDifferentiableDecoration) // @@ -2074,27 +2430,18 @@ struct JVPDerivativeContext } JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : - module(module), + InstPassBase(module), sink(sink), autoDiffSharedContextStorage(module->getModuleInst()), - transcriberStorage(&autoDiffSharedContextStorage) + transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage) { + pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage; transcriberStorage.sink = sink; transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage); transcriberStorage.pairBuilder = &(pairBuilderStorage); } - protected: - - // This type passes over the module and generates - // forward-mode derivative versions of functions - // that are explicitly marked for it. - // - IRModule* module; - - // Shared builder state for our derivative passes. - SharedIRBuilder sharedBuilderStorage; - +protected: // A transcriber object that handles the main job of // processing instructions while maintaining state. // @@ -2104,10 +2451,6 @@ struct JVPDerivativeContext // error messages. DiagnosticSink* sink; - // Work queue to hold a stream of instructions that need - // to be checked for references to derivative functions. - IRWorkQueue workQueueStorage; - // Context to find and manage the witness tables for types // implementing `IDifferentiable` AutoDiffSharedContext autoDiffSharedContextStorage; |
