diff options
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 51 |
1 files changed, 35 insertions, 16 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 65e880868..edea3847d 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -44,6 +44,18 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK return nullptr; } +static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); +} + +static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); +} + bool isNoDiffType(IRType* paramType) { while (auto ptrType = as<IRPtrTypeBase>(paramType)) @@ -266,25 +278,13 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I return pairStructType; } -IRInst* DifferentialPairTypeBuilder::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type) -{ - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); -} - -IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type) -{ - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); -} - IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( IRBuilder* builder, IRType* originalPairType) { IRInst* result = nullptr; if (pairTypeCache.TryGetValue(originalPairType, result)) return result; - auto pairType = as<IRDifferentialPairType>(originalPairType); + auto pairType = as<IRDifferentialPairTypeBase>(originalPairType); if (!pairType) { result = originalPairType; @@ -297,7 +297,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( return result; } - auto diffType = getDiffTypeFromPairType(builder, pairType); + auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType); if (!diffType) return result; result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); @@ -406,18 +406,28 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b } IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairType( - IRBuilder* builder, IRDifferentialPairType* diffPairType) + IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType) { auto witness = diffPairType->getWitness(); SLANG_RELEASE_ASSERT(witness); return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); } +IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + return _getDiffTypeFromPairType(sharedContext, builder, type); +} + +IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + return _getDiffTypeWitnessFromPairType(sharedContext, builder, type); +} + void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { for (auto globalInst : sharedContext->moduleInst->getChildren()) { - if (auto pairType = as<IRDifferentialPairType>(globalInst)) + if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst)) { differentiableWitnessDictionary.AddIfNotExists(pairType->getValueType(), pairType->getWitness()); } @@ -505,6 +515,7 @@ void stripTempDecorations(IRInst* inst) case kIROp_AutoDiffOriginalValueDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: case kIROp_PrimalValueStructKeyDecoration: + case kIROp_PrimalElementTypeDecoration: decor->removeAndDeallocate(); break; default: @@ -578,6 +589,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_TupleType: case kIROp_ArrayType: case kIROp_DifferentialPairType: + case kIROp_DifferentialPairUserCodeType: case kIROp_InterfaceType: case kIROp_AnyValueType: case kIROp_ClassType: @@ -832,6 +844,13 @@ struct AutoDiffPass : public InstPassBase if (!changed) break; + + // We have done transcribing the functions, now it is time to demote all DifferentialPair types + // and their operations down to DifferentialPairUserCodeType and *UserCode operations so they + // can be treated just like normal types with no special semantics in future processing, and won't + // be confused with the semantics of a DifferentialPair type during future autodiff code gen. + rewriteDifferentialPairToUserCode(module); + hasChanges |= changed; } |
