summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-18 12:37:27 -0800
committerGitHub <noreply@github.com>2022-11-18 12:37:27 -0800
commitd58e08f8237a1888ceaad53402d534679ea83b1a (patch)
treee66838e0dc31fc12ebd7c1acecbb5060e8808366 /source/slang/slang-ir-diff-jvp.cpp
parent0a050a439fa91b66f2020421d4fec3e60aed4112 (diff)
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp345
1 files changed, 104 insertions, 241 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 152601dbd..4ee16aafc 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -1,8 +1,6 @@
// slang-ir-diff-jvp.cpp
#include "slang-ir-diff-jvp.h"
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
@@ -16,154 +14,6 @@
namespace Slang
{
-template<typename P, typename D>
-struct Pair
-{
- P primal;
- D differential;
- Pair() = default;
- Pair(P primal, D differential) : primal(primal), differential(differential)
- {}
- HashCode getHashCode() const
- {
- Hasher hasher;
- hasher << primal << differential;
- return hasher.getResult();
- }
- bool operator ==(const Pair& other) const
- {
- return primal == other.primal && differential == other.differential;
- }
-};
-
-typedef Pair<IRInst*, IRInst*> InstPair;
-
-struct AutoDiffSharedContext
-{
- IRModuleInst* moduleInst = nullptr;
-
- SharedIRBuilder* sharedBuilder = nullptr;
-
- // A reference to the builtin IDifferentiable interface type.
- // We use this to look up all the other types (and type exprs)
- // that conform to a base type.
- //
- IRInterfaceType* differentiableInterfaceType = nullptr;
-
- // The struct key for the 'Differential' associated type
- // defined inside IDifferential. We use this to lookup the differential
- // type in the conformance table associated with the concrete type.
- //
- IRStructKey* differentialAssocTypeStructKey = nullptr;
-
- // The struct key for the witness that `Differential` associated type conforms to
- // `IDifferential`.
- IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
-
-
- // The struct key for the 'zero()' associated type
- // defined inside IDifferential. We use this to lookup the
- // implementation of zero() for a given type.
- //
- IRStructKey* zeroMethodStructKey = nullptr;
-
- // The struct key for the 'add()' associated type
- // defined inside IDifferential. We use this to lookup the
- // implementation of add() for a given type.
- //
- IRStructKey* addMethodStructKey = nullptr;
-
- IRStructKey* mulMethodStructKey = nullptr;
-
-
- // Modules that don't use differentiable types
- // won't have the IDifferentiable interface type available.
- // Set to false to indicate that we are uninitialized.
- //
- bool isInterfaceAvailable = false;
-
-
- AutoDiffSharedContext(IRModuleInst* inModuleInst)
- : moduleInst(inModuleInst)
- {
- differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
- if (differentiableInterfaceType)
- {
- differentialAssocTypeStructKey = findDifferentialTypeStructKey();
- differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
- zeroMethodStructKey = findZeroMethodStructKey();
- addMethodStructKey = findAddMethodStructKey();
- mulMethodStructKey = findMulMethodStructKey();
-
- if (differentialAssocTypeStructKey)
- isInterfaceAvailable = true;
- }
- }
-
- private:
-
- IRInst* findDifferentiableInterface()
- {
- if (auto module = as<IRModuleInst>(moduleInst))
- {
- for (auto globalInst : module->getGlobalInsts())
- {
- // TODO: This seems like a particularly dangerous way to look for an interface.
- // See if we can lower IDifferentiable to a separate IR inst.
- //
- if (globalInst->getOp() == kIROp_InterfaceType &&
- as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable")
- {
- return globalInst;
- }
- }
- }
- return nullptr;
- }
-
- IRStructKey* findDifferentialTypeStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(0);
- }
-
- IRStructKey* findDifferentialTypeWitnessStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(1);
- }
-
- IRStructKey* findZeroMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(2);
- }
-
- IRStructKey* findAddMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(3);
- }
-
- IRStructKey* findMulMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(4);
- }
-
- IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index)
- {
- if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
- {
- // Assume for now that IDifferentiable has exactly five fields.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
- if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
- return as<IRStructKey>(entry->getRequirementKey());
- else
- {
- SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
- }
- }
-
- return nullptr;
- }
-};
-
namespace
{
@@ -189,97 +39,6 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
}
-struct DifferentiableTypeConformanceContext
-{
- AutoDiffSharedContext* sharedContext;
-
- IRGlobalValueWithCode* parentFunc = nullptr;
- Dictionary<IRType*, IRInst*> differentiableWitnessDictionary;
-
- DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
- : sharedContext(shared)
- {}
-
- void setFunc(IRGlobalValueWithCode* func)
- {
- parentFunc = func;
-
- auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- SLANG_RELEASE_ASSERT(decor);
-
- // Build lookup dictionary for type witnesses.
- for (auto child = decor->getFirstChild(); child; child = child->next)
- {
- if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
- {
- auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
- if (existingItem)
- {
- if (auto witness = as<IRWitnessTable>(item->getWitness()))
- {
- if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType)
- continue;
- }
- *existingItem = item->getWitness();
- }
- else
- {
- differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
- }
- }
- }
- }
-
-
- // Lookup a witness table for the concreteType. One should exist if concreteType
- // inherits (successfully) from IDifferentiable.
- //
- IRInst* lookUpConformanceForType(IRInst* type)
- {
- IRInst* foundResult = nullptr;
- differentiableWitnessDictionary.TryGetValue(type, foundResult);
- return foundResult;
- }
-
- IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
- {
- if (auto conformance = lookUpConformanceForType(origType))
- {
- return _lookupWitness(builder, conformance, key);
- }
- return nullptr;
- }
-
- // Lookup and return the 'Differential' type declared in the concrete type
- // in order to conform to the IDifferentiable interface.
- // Note that inside a generic block, this will be a witness table lookup instruction
- // that gets resolved during the specialization pass.
- //
- IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
- {
- switch (origType->getOp())
- {
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- case kIROp_VectorType:
- return origType;
- }
- return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
- }
-
- IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
- }
-
- IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
- }
-
-};
-
struct DifferentialPairTypeBuilder
{
@@ -3275,4 +3034,108 @@ void stripAutoDiffDecorations(IRModule* module)
stripAutoDiffDecorationsFromChildren(module->getModuleInst());
}
+AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst)
+ : moduleInst(inModuleInst)
+{
+ differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
+ if (differentiableInterfaceType)
+ {
+ differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
+ zeroMethodStructKey = findZeroMethodStructKey();
+ addMethodStructKey = findAddMethodStructKey();
+ mulMethodStructKey = findMulMethodStructKey();
+
+ if (differentialAssocTypeStructKey)
+ isInterfaceAvailable = true;
+ }
+}
+
+IRInst* AutoDiffSharedContext::findDifferentiableInterface()
+{
+ if (auto module = as<IRModuleInst>(moduleInst))
+ {
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ // TODO: This seems like a particularly dangerous way to look for an interface.
+ // See if we can lower IDifferentiable to a separate IR inst.
+ //
+ if (globalInst->getOp() == kIROp_InterfaceType &&
+ as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable")
+ {
+ return globalInst;
+ }
+ }
+ }
+ return nullptr;
+}
+
+IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index)
+{
+ if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
+ {
+ // Assume for now that IDifferentiable has exactly five fields.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
+ if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
+ return as<IRStructKey>(entry->getRequirementKey());
+ else
+ {
+ SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
+ }
+ }
+
+ return nullptr;
+}
+
+void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
+{
+ parentFunc = func;
+
+ auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+ SLANG_RELEASE_ASSERT(decor);
+
+ // Build lookup dictionary for type witnesses.
+ for (auto child = decor->getFirstChild(); child; child = child->next)
+ {
+ if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
+ {
+ auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+ if (existingItem)
+ {
+ if (auto witness = as<IRWitnessTable>(item->getWitness()))
+ {
+ if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType)
+ continue;
+ }
+ *existingItem = item->getWitness();
+ }
+ else
+ {
+ differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
+ }
+ }
+ }
+}
+
+
+// Lookup a witness table for the concreteType. One should exist if concreteType
+// inherits (successfully) from IDifferentiable.
+//
+
+IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
+{
+ IRInst* foundResult = nullptr;
+ differentiableWitnessDictionary.TryGetValue(type, foundResult);
+ return foundResult;
+}
+
+IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+{
+ if (auto conformance = lookUpConformanceForType(origType))
+ {
+ return _lookupWitness(builder, conformance, key);
+ }
+ return nullptr;
+}
+
}