diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-25 14:48:01 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-25 14:48:01 -0800 |
| commit | aa6814be1f7dea20597ae34d477e79e53d4a543f (patch) | |
| tree | 15b8ad69e2c4169e12a0ad6e970fe511daa4beb7 | |
| parent | ae11538f5d667b11d3b3191a827093f3727eed1b (diff) | |
Cleanup IR representation of interface member derivative. (#2610)
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 20 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 32 |
8 files changed, 38 insertions, 92 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index f8186a96e..53577f40e 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -97,9 +97,9 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override; - virtual IROp getDifferentiableMethodDictionaryItemOp() override + virtual IROp getInterfaceRequirementDerivativeDecorationOp() override { - return kIROp_ForwardDifferentiableMethodRequirementDictionaryItem; + return kIROp_ForwardDerivativeDecoration; } }; diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 228bcf588..7aa6c2441 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -100,9 +100,9 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) = 0; virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) = 0; - virtual IROp getDifferentiableMethodDictionaryItemOp() override + virtual IROp getInterfaceRequirementDerivativeDecorationOp() override { - return kIROp_BackwardDifferentiableMethodRequirementDictionaryItem; + return kIROp_BackwardDerivativeDecoration; } }; @@ -130,6 +130,10 @@ struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase { builder->addBackwardDerivativePrimalDecoration(inst, diffFunc); } + virtual IROp getInterfaceRequirementDerivativeDecorationOp() override + { + return kIROp_BackwardDerivativePrimalDecoration; + } }; struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase @@ -164,6 +168,10 @@ struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase { builder->addBackwardDerivativePropagateDecoration(inst, diffFunc); } + virtual IROp getInterfaceRequirementDerivativeDecorationOp() override + { + return kIROp_BackwardDerivativePropagateDecoration; + } }; // A backward derivative function combines both primal + propagate functions and accepts no diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 89adbe6a0..f43206333 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -579,30 +579,19 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui { return InstPair(primal, nullptr); } - auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); - if (!dict) + auto decor = + lookupInst->getRequirementKey()->findDecorationImpl( + getInterfaceRequirementDerivativeDecorationOp()); + if (!decor) { return InstPair(primal, nullptr); } - for (auto child : dict->getChildren()) + auto diffKey = decor->getOperand(0); + if (auto diffType = findInterfaceRequirement(interfaceType, diffKey)) { - if (auto item = as<IRDifferentiableMethodRequirementDictionaryItem>(child)) - { - if (item->getOp() == getDifferentiableMethodDictionaryItemOp()) - { - if (item->getOperand(0) == lookupInst->getRequirementKey()) - { - auto diffKey = item->getOperand(1); - if (auto diffType = findInterfaceRequirement(interfaceType, diffKey)) - { - auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey); - return InstPair(primal, diff); - } - break; - } - } - } + auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey); + return InstPair(primal, diff); } return InstPair(primal, nullptr); } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index b0397069b..a870dc815 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -147,7 +147,7 @@ struct AutoDiffTranscriberBase virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) = 0; - virtual IROp getDifferentiableMethodDictionaryItemOp() = 0; + virtual IROp getInterfaceRequirementDerivativeDecorationOp() = 0; }; } diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index f8d70c8ed..8cefa6a04 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -131,22 +131,10 @@ public: return true; if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType) return true; - auto dictDecor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); - if (!dictDecor) - return false; - for (auto child : dictDecor->getChildren()) - { - if (auto entry = as<IRDifferentiableMethodRequirementDictionaryItem>(child)) - { - if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey()) - { - if (as<IRBackwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Backward) - return true; - if (as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Forward) - return true; - } - } - } + if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRBackwardDerivativeDecoration>()) + return true; + if (lookupInterfaceMethod->getRequirementKey()->findDecoration<IRForwardDerivativeDecoration>()) + return level == DifferentiableLevel::Forward; } for (; func; func = func->parent) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c2a1886fb..817edaa83 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -776,9 +776,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /* Differentiable Type Dictionary */ INST(DifferentiableTypeDictionaryDecoration, DifferentiableTypeDictionaryDecoration, 0, PARENT) - /// Decorates an interface type and stores the mapping from a normal function requirement key to its derivative requirement key. - INST(DifferentiableMethodRequirementDictionaryDecoration, DifferentiableMethodRequirementDictionaryDecoration, 0, PARENT) - /// Marks a struct type as being used as a structured buffer block. /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0) @@ -899,16 +896,6 @@ INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDicti /* Differentiable Type Dictionary */ INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0) -/* DifferentiableMethodRequirementDictionaryItem */ - INST(ForwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0) - INST(BackwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0) - INST(BackwardDifferentiablePrimalMethodRequirementDictionaryItem, DifferentiablePrimalMethodRequirementDictionaryItem, 0, 0) - INST(BackwardDifferentiablePropagateMethodRequirementDictionaryItem, DifferentiablePropagateMethodRequirementDictionaryItem, 0, 0) - INST(BackwardDifferentiableIntermediateTypeRequirementDictionaryItem, DifferentiableIntermediateTypeRequirementDictionaryItem, 0, 0) - - -INST_RANGE(DifferentiableMethodRequirementDictionaryItem, ForwardDifferentiableMethodRequirementDictionaryItem, BackwardDifferentiableMethodRequirementDictionaryItem) - #undef PARENT #undef USE_OTHER #undef INST_RANGE diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4887b1c79..10c490f3c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -825,26 +825,6 @@ struct IRDifferentiableTypeDictionaryDecoration : IRDecoration IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration) }; -struct IRDifferentiableMethodRequirementDictionaryDecoration : IRDecoration -{ - IR_LEAF_ISA(DifferentiableMethodRequirementDictionaryDecoration) -}; - -struct IRDifferentiableMethodRequirementDictionaryItem : IRInst -{ - IR_PARENT_ISA(DifferentiableMethodRequirementDictionaryItem) -}; - -struct IRForwardDifferentiableMethodRequirementDictionaryItem : IRDifferentiableMethodRequirementDictionaryItem -{ - IR_LEAF_ISA(ForwardDifferentiableMethodRequirementDictionaryItem) -}; - -struct IRBackwardDifferentiableMethodRequirementDictionaryItem : IRDifferentiableMethodRequirementDictionaryItem -{ - IR_LEAF_ISA(BackwardDifferentiableMethodRequirementDictionaryItem) -}; - // An instruction that specializes another IR value // (representing a generic) to a particular set of generic arguments // (instructions representing types, witness tables, etc.) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d0527eef8..605ac62db 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6846,36 +6846,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo::simple(assocType); } - void insertRequirementKeyAssociation(IRInterfaceType* interfaceType, Decl* requirementDecl, IRInst* originalKey, IRInst* associatedKey) + void insertRequirementKeyAssociation(Decl* requirementDecl, IRInst* originalKey, IRInst* associatedKey) { - auto decor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); - if (!decor) - { - decor = - (IRDifferentiableMethodRequirementDictionaryDecoration*) - context->irBuilder->addDecoration( - interfaceType, kIROp_DifferentiableMethodRequirementDictionaryDecoration); - } - IROp op = kIROp_ForwardDifferentiableMethodRequirementDictionaryItem; + IROp op = kIROp_Nop; if (as<BackwardDerivativeRequirementDecl>(requirementDecl)) { - op = kIROp_BackwardDifferentiableMethodRequirementDictionaryItem; + op = kIROp_BackwardDerivativeDecoration; } else if (as<BackwardDerivativePropagateRequirementDecl>(requirementDecl)) { - op = kIROp_BackwardDifferentiablePropagateMethodRequirementDictionaryItem; + op = kIROp_BackwardDerivativePropagateDecoration; } else if (as<BackwardDerivativePrimalRequirementDecl>(requirementDecl)) { - op = kIROp_BackwardDifferentiablePrimalMethodRequirementDictionaryItem; + op = kIROp_BackwardDerivativePrimalDecoration; + } + else if (as<ForwardDerivativeRequirementDecl>(requirementDecl)) + { + op = kIROp_ForwardDerivativeDecoration; } - else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(requirementDecl)) + else { - op = kIROp_BackwardDifferentiableIntermediateTypeRequirementDictionaryItem; + return; } - IRInst* args[] = {originalKey, associatedKey}; - auto assoc = context->irBuilder->emitIntrinsicInst(nullptr, op, 2, args); - assoc->insertAtEnd(decor); + context->irBuilder->addDecoration(originalKey, op, associatedKey); } LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl) @@ -7005,7 +6999,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>()) { auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl); - insertRequirementKeyAssociation(irInterface, diffDecl->referencedDecl, requirementKey, diffKey); + insertRequirementKeyAssociation(diffDecl->referencedDecl, requirementKey, diffKey); } } // Add lowered requirement entry to current decl mapping to prevent |
