diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-18 12:37:27 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-18 12:37:27 -0800 |
| commit | d58e08f8237a1888ceaad53402d534679ea83b1a (patch) | |
| tree | e66838e0dc31fc12ebd7c1acecbb5060e8808366 /source/slang/slang-ir-diff-jvp.cpp | |
| parent | 0a050a439fa91b66f2020421d4fec3e60aed4112 (diff) | |
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 345 |
1 files changed, 104 insertions, 241 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 152601dbd..4ee16aafc 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -1,8 +1,6 @@ // slang-ir-diff-jvp.cpp #include "slang-ir-diff-jvp.h" -#include "slang-ir.h" -#include "slang-ir-insts.h" #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" @@ -16,154 +14,6 @@ namespace Slang { -template<typename P, typename D> -struct Pair -{ - P primal; - D differential; - Pair() = default; - Pair(P primal, D differential) : primal(primal), differential(differential) - {} - HashCode getHashCode() const - { - Hasher hasher; - hasher << primal << differential; - return hasher.getResult(); - } - bool operator ==(const Pair& other) const - { - return primal == other.primal && differential == other.differential; - } -}; - -typedef Pair<IRInst*, IRInst*> InstPair; - -struct AutoDiffSharedContext -{ - IRModuleInst* moduleInst = nullptr; - - SharedIRBuilder* sharedBuilder = nullptr; - - // A reference to the builtin IDifferentiable interface type. - // We use this to look up all the other types (and type exprs) - // that conform to a base type. - // - IRInterfaceType* differentiableInterfaceType = nullptr; - - // The struct key for the 'Differential' associated type - // defined inside IDifferential. We use this to lookup the differential - // type in the conformance table associated with the concrete type. - // - IRStructKey* differentialAssocTypeStructKey = nullptr; - - // The struct key for the witness that `Differential` associated type conforms to - // `IDifferential`. - IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; - - - // The struct key for the 'zero()' associated type - // defined inside IDifferential. We use this to lookup the - // implementation of zero() for a given type. - // - IRStructKey* zeroMethodStructKey = nullptr; - - // The struct key for the 'add()' associated type - // defined inside IDifferential. We use this to lookup the - // implementation of add() for a given type. - // - IRStructKey* addMethodStructKey = nullptr; - - IRStructKey* mulMethodStructKey = nullptr; - - - // Modules that don't use differentiable types - // won't have the IDifferentiable interface type available. - // Set to false to indicate that we are uninitialized. - // - bool isInterfaceAvailable = false; - - - AutoDiffSharedContext(IRModuleInst* inModuleInst) - : moduleInst(inModuleInst) - { - differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); - if (differentiableInterfaceType) - { - differentialAssocTypeStructKey = findDifferentialTypeStructKey(); - differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); - zeroMethodStructKey = findZeroMethodStructKey(); - addMethodStructKey = findAddMethodStructKey(); - mulMethodStructKey = findMulMethodStructKey(); - - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; - } - } - - private: - - IRInst* findDifferentiableInterface() - { - if (auto module = as<IRModuleInst>(moduleInst)) - { - for (auto globalInst : module->getGlobalInsts()) - { - // TODO: This seems like a particularly dangerous way to look for an interface. - // See if we can lower IDifferentiable to a separate IR inst. - // - if (globalInst->getOp() == kIROp_InterfaceType && - as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable") - { - return globalInst; - } - } - } - return nullptr; - } - - IRStructKey* findDifferentialTypeStructKey() - { - return getIDifferentiableStructKeyAtIndex(0); - } - - IRStructKey* findDifferentialTypeWitnessStructKey() - { - return getIDifferentiableStructKeyAtIndex(1); - } - - IRStructKey* findZeroMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(2); - } - - IRStructKey* findAddMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(3); - } - - IRStructKey* findMulMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(4); - } - - IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index) - { - if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) - { - // Assume for now that IDifferentiable has exactly five fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); - if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) - return as<IRStructKey>(entry->getRequirementKey()); - else - { - SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); - } - } - - return nullptr; - } -}; - namespace { @@ -189,97 +39,6 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK } -struct DifferentiableTypeConformanceContext -{ - AutoDiffSharedContext* sharedContext; - - IRGlobalValueWithCode* parentFunc = nullptr; - Dictionary<IRType*, IRInst*> differentiableWitnessDictionary; - - DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) - : sharedContext(shared) - {} - - void setFunc(IRGlobalValueWithCode* func) - { - parentFunc = func; - - auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); - SLANG_RELEASE_ASSERT(decor); - - // Build lookup dictionary for type witnesses. - for (auto child = decor->getFirstChild(); child; child = child->next) - { - if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) - { - auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); - if (existingItem) - { - if (auto witness = as<IRWitnessTable>(item->getWitness())) - { - if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType) - continue; - } - *existingItem = item->getWitness(); - } - else - { - differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); - } - } - } - } - - - // Lookup a witness table for the concreteType. One should exist if concreteType - // inherits (successfully) from IDifferentiable. - // - IRInst* lookUpConformanceForType(IRInst* type) - { - IRInst* foundResult = nullptr; - differentiableWitnessDictionary.TryGetValue(type, foundResult); - return foundResult; - } - - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) - { - if (auto conformance = lookUpConformanceForType(origType)) - { - return _lookupWitness(builder, conformance, key); - } - return nullptr; - } - - // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. - // Note that inside a generic block, this will be a witness table lookup instruction - // that gets resolved during the specialization pass. - // - IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) - { - switch (origType->getOp()) - { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - case kIROp_VectorType: - return origType; - } - return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); - } - - IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); - } - - IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); - } - -}; - struct DifferentialPairTypeBuilder { @@ -3275,4 +3034,108 @@ void stripAutoDiffDecorations(IRModule* module) stripAutoDiffDecorationsFromChildren(module->getModuleInst()); } +AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst) + : moduleInst(inModuleInst) +{ + differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); + if (differentiableInterfaceType) + { + differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); + zeroMethodStructKey = findZeroMethodStructKey(); + addMethodStructKey = findAddMethodStructKey(); + mulMethodStructKey = findMulMethodStructKey(); + + if (differentialAssocTypeStructKey) + isInterfaceAvailable = true; + } +} + +IRInst* AutoDiffSharedContext::findDifferentiableInterface() +{ + if (auto module = as<IRModuleInst>(moduleInst)) + { + for (auto globalInst : module->getGlobalInsts()) + { + // TODO: This seems like a particularly dangerous way to look for an interface. + // See if we can lower IDifferentiable to a separate IR inst. + // + if (globalInst->getOp() == kIROp_InterfaceType && + as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable") + { + return globalInst; + } + } + } + return nullptr; +} + +IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) +{ + if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) + { + // Assume for now that IDifferentiable has exactly five fields. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); + if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) + return as<IRStructKey>(entry->getRequirementKey()); + else + { + SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); + } + } + + return nullptr; +} + +void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) +{ + parentFunc = func; + + auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); + SLANG_RELEASE_ASSERT(decor); + + // Build lookup dictionary for type witnesses. + for (auto child = decor->getFirstChild(); child; child = child->next) + { + if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) + { + auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); + if (existingItem) + { + if (auto witness = as<IRWitnessTable>(item->getWitness())) + { + if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType) + continue; + } + *existingItem = item->getWitness(); + } + else + { + differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); + } + } + } +} + + +// Lookup a witness table for the concreteType. One should exist if concreteType +// inherits (successfully) from IDifferentiable. +// + +IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) +{ + IRInst* foundResult = nullptr; + differentiableWitnessDictionary.TryGetValue(type, foundResult); + return foundResult; +} + +IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) +{ + if (auto conformance = lookUpConformanceForType(origType)) + { + return _lookupWitness(builder, conformance, key); + } + return nullptr; +} + } |
