summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-25 14:48:01 -0800
committerGitHub <noreply@github.com>2023-01-25 14:48:01 -0800
commitaa6814be1f7dea20597ae34d477e79e53d4a543f (patch)
tree15b8ad69e2c4169e12a0ad6e970fe511daa4beb7
parentae11538f5d667b11d3b3191a827093f3727eed1b (diff)
Cleanup IR representation of interface member derivative. (#2610)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h4
-rw-r--r--source/slang/slang-ir-autodiff-rev.h12
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp27
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp20
-rw-r--r--source/slang/slang-ir-inst-defs.h13
-rw-r--r--source/slang/slang-ir-insts.h20
-rw-r--r--source/slang/slang-lower-to-ir.cpp32
8 files changed, 38 insertions, 92 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index f8186a96e..53577f40e 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -97,9 +97,9 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override;
- virtual IROp getDifferentiableMethodDictionaryItemOp() override
+ virtual IROp getInterfaceRequirementDerivativeDecorationOp() override
{
- return kIROp_ForwardDifferentiableMethodRequirementDictionaryItem;
+ return kIROp_ForwardDerivativeDecoration;
}
};
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 228bcf588..7aa6c2441 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -100,9 +100,9 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) = 0;
virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) = 0;
- virtual IROp getDifferentiableMethodDictionaryItemOp() override
+ virtual IROp getInterfaceRequirementDerivativeDecorationOp() override
{
- return kIROp_BackwardDifferentiableMethodRequirementDictionaryItem;
+ return kIROp_BackwardDerivativeDecoration;
}
};
@@ -130,6 +130,10 @@ struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase
{
builder->addBackwardDerivativePrimalDecoration(inst, diffFunc);
}
+ virtual IROp getInterfaceRequirementDerivativeDecorationOp() override
+ {
+ return kIROp_BackwardDerivativePrimalDecoration;
+ }
};
struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase
@@ -164,6 +168,10 @@ struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase
{
builder->addBackwardDerivativePropagateDecoration(inst, diffFunc);
}
+ virtual IROp getInterfaceRequirementDerivativeDecorationOp() override
+ {
+ return kIROp_BackwardDerivativePropagateDecoration;
+ }
};
// A backward derivative function combines both primal + propagate functions and accepts no
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 89adbe6a0..f43206333 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -579,30 +579,19 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
{
return InstPair(primal, nullptr);
}
- auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
- if (!dict)
+ auto decor =
+ lookupInst->getRequirementKey()->findDecorationImpl(
+ getInterfaceRequirementDerivativeDecorationOp());
+ if (!decor)
{
return InstPair(primal, nullptr);
}
- for (auto child : dict->getChildren())
+ auto diffKey = decor->getOperand(0);
+ if (auto diffType = findInterfaceRequirement(interfaceType, diffKey))
{
- if (auto item = as<IRDifferentiableMethodRequirementDictionaryItem>(child))
- {
- if (item->getOp() == getDifferentiableMethodDictionaryItemOp())
- {
- if (item->getOperand(0) == lookupInst->getRequirementKey())
- {
- auto diffKey = item->getOperand(1);
- if (auto diffType = findInterfaceRequirement(interfaceType, diffKey))
- {
- auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey);
- return InstPair(primal, diff);
- }
- break;
- }
- }
- }
+ auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey);
+ return InstPair(primal, diff);
}
return InstPair(primal, nullptr);
}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index b0397069b..a870dc815 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -147,7 +147,7 @@ struct AutoDiffTranscriberBase
virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) = 0;
- virtual IROp getDifferentiableMethodDictionaryItemOp() = 0;
+ virtual IROp getInterfaceRequirementDerivativeDecorationOp() = 0;
};
}
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index f8d70c8ed..8cefa6a04 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -131,22 +131,10 @@ public:
return true;
if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType)
return true;
- auto dictDecor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
- if (!dictDecor)
- return false;
- for (auto child : dictDecor->getChildren())
- {
- if (auto entry = as<IRDifferentiableMethodRequirementDictionaryItem>(child))
- {
- if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey())
- {
- if (as<IRBackwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Backward)
- return true;
- if (as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Forward)
- return true;
- }
- }
- }
+ if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRBackwardDerivativeDecoration>())
+ return true;
+ if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRForwardDerivativeDecoration>())
+ return level == DifferentiableLevel::Forward;
}
for (; func; func = func->parent)
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index c2a1886fb..817edaa83 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -776,9 +776,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/* Differentiable Type Dictionary */
INST(DifferentiableTypeDictionaryDecoration, DifferentiableTypeDictionaryDecoration, 0, PARENT)
- /// Decorates an interface type and stores the mapping from a normal function requirement key to its derivative requirement key.
- INST(DifferentiableMethodRequirementDictionaryDecoration, DifferentiableMethodRequirementDictionaryDecoration, 0, PARENT)
-
/// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0)
@@ -899,16 +896,6 @@ INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDicti
/* Differentiable Type Dictionary */
INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0)
-/* DifferentiableMethodRequirementDictionaryItem */
- INST(ForwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0)
- INST(BackwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0)
- INST(BackwardDifferentiablePrimalMethodRequirementDictionaryItem, DifferentiablePrimalMethodRequirementDictionaryItem, 0, 0)
- INST(BackwardDifferentiablePropagateMethodRequirementDictionaryItem, DifferentiablePropagateMethodRequirementDictionaryItem, 0, 0)
- INST(BackwardDifferentiableIntermediateTypeRequirementDictionaryItem, DifferentiableIntermediateTypeRequirementDictionaryItem, 0, 0)
-
-
-INST_RANGE(DifferentiableMethodRequirementDictionaryItem, ForwardDifferentiableMethodRequirementDictionaryItem, BackwardDifferentiableMethodRequirementDictionaryItem)
-
#undef PARENT
#undef USE_OTHER
#undef INST_RANGE
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 4887b1c79..10c490f3c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -825,26 +825,6 @@ struct IRDifferentiableTypeDictionaryDecoration : IRDecoration
IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration)
};
-struct IRDifferentiableMethodRequirementDictionaryDecoration : IRDecoration
-{
- IR_LEAF_ISA(DifferentiableMethodRequirementDictionaryDecoration)
-};
-
-struct IRDifferentiableMethodRequirementDictionaryItem : IRInst
-{
- IR_PARENT_ISA(DifferentiableMethodRequirementDictionaryItem)
-};
-
-struct IRForwardDifferentiableMethodRequirementDictionaryItem : IRDifferentiableMethodRequirementDictionaryItem
-{
- IR_LEAF_ISA(ForwardDifferentiableMethodRequirementDictionaryItem)
-};
-
-struct IRBackwardDifferentiableMethodRequirementDictionaryItem : IRDifferentiableMethodRequirementDictionaryItem
-{
- IR_LEAF_ISA(BackwardDifferentiableMethodRequirementDictionaryItem)
-};
-
// An instruction that specializes another IR value
// (representing a generic) to a particular set of generic arguments
// (instructions representing types, witness tables, etc.)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d0527eef8..605ac62db 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -6846,36 +6846,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo::simple(assocType);
}
- void insertRequirementKeyAssociation(IRInterfaceType* interfaceType, Decl* requirementDecl, IRInst* originalKey, IRInst* associatedKey)
+ void insertRequirementKeyAssociation(Decl* requirementDecl, IRInst* originalKey, IRInst* associatedKey)
{
- auto decor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
- if (!decor)
- {
- decor =
- (IRDifferentiableMethodRequirementDictionaryDecoration*)
- context->irBuilder->addDecoration(
- interfaceType, kIROp_DifferentiableMethodRequirementDictionaryDecoration);
- }
- IROp op = kIROp_ForwardDifferentiableMethodRequirementDictionaryItem;
+ IROp op = kIROp_Nop;
if (as<BackwardDerivativeRequirementDecl>(requirementDecl))
{
- op = kIROp_BackwardDifferentiableMethodRequirementDictionaryItem;
+ op = kIROp_BackwardDerivativeDecoration;
}
else if (as<BackwardDerivativePropagateRequirementDecl>(requirementDecl))
{
- op = kIROp_BackwardDifferentiablePropagateMethodRequirementDictionaryItem;
+ op = kIROp_BackwardDerivativePropagateDecoration;
}
else if (as<BackwardDerivativePrimalRequirementDecl>(requirementDecl))
{
- op = kIROp_BackwardDifferentiablePrimalMethodRequirementDictionaryItem;
+ op = kIROp_BackwardDerivativePrimalDecoration;
+ }
+ else if (as<ForwardDerivativeRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_ForwardDerivativeDecoration;
}
- else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(requirementDecl))
+ else
{
- op = kIROp_BackwardDifferentiableIntermediateTypeRequirementDictionaryItem;
+ return;
}
- IRInst* args[] = {originalKey, associatedKey};
- auto assoc = context->irBuilder->emitIntrinsicInst(nullptr, op, 2, args);
- assoc->insertAtEnd(decor);
+ context->irBuilder->addDecoration(originalKey, op, associatedKey);
}
LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl)
@@ -7005,7 +6999,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>())
{
auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl);
- insertRequirementKeyAssociation(irInterface, diffDecl->referencedDecl, requirementKey, diffKey);
+ insertRequirementKeyAssociation(diffDecl->referencedDecl, requirementKey, diffKey);
}
}
// Add lowered requirement entry to current decl mapping to prevent