diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-01-10 03:16:24 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-09 13:46:24 -0800 |
| commit | 87f00a36a123e36b415eeea82e02a8366cc5b881 (patch) | |
| tree | 719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-transcriber-base.cpp | |
| parent | 6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff) | |
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch
* More fixes
* remove intermediate dumps
* Update slang-ast-type.h
* More fixes + add a workaround for existential no-diff
* Update reverse-control-flow-3.slang
* remove dumps
* remove more dumps
* Delete working-reverse-control-flow-3.hlsl
* Cleanup comments + unused variables
* More comment cleanup
* Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)`
* Fix array of issues in Falcor tests.
* Update slang-ir-autodiff-pairs.cpp
* More fixes for Falcor image tests
* Small fixups.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 112 |
1 files changed, 61 insertions, 51 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 1b3825a7d..38a7a18bb 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -256,7 +256,7 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o return nullptr; // Special-case for differentiable existential types. - if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType)) + if (as<IRInterfaceType>(origType)) { if (differentiableTypeConformanceContext.lookUpConformanceForType( origType, @@ -269,6 +269,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o else return nullptr; } + else if (as<IRAssociatedType>(origType)) + { + SLANG_UNEXPECTED("unexpected associated type during auto-diff"); + } auto primalType = lookupPrimalInst(builder, origType, origType); if (primalType->getOp() == kIROp_Param && primalType->getParent() && @@ -324,9 +328,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy auto primalPairType = as<IRDifferentialPairTypeBase>(primalType); return getOrCreateDiffPairType( builder, - differentiableTypeConformanceContext.getDiffTypeFromPairType( - builder, - primalPairType), + differentiateType(builder, primalPairType->getValueType()), differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType( builder, primalPairType)); @@ -336,9 +338,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy { auto primalPairType = as<IRDifferentialPairUserCodeType>(primalType); return builder->getDifferentialPairUserCodeType( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( - builder, - primalPairType), + differentiateType(builder, primalPairType->getValueType()), differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType( builder, primalPairType)); @@ -406,6 +406,7 @@ bool AutoDiffTranscriberBase::isExistentialType(IRType* type) case kIROp_ExtractExistentialType: case kIROp_InterfaceType: case kIROp_AssociatedType: + case kIROp_LookupWitness: return true; default: return false; @@ -460,47 +461,34 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative IRBuilder* builder, IRInst* origFunc) { - auto decor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); - if (decor) - return; - // A differentiable func must have `IRDifferentiableTypeDictionaryDecoration`, except it has a - // `IRUserDefinedBackwardDerivativeDecoration`. - auto udfDecor = origFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>(); - SLANG_RELEASE_ASSERT(udfDecor); - // We need to migrate the dictionary from the backward derivative func so we can properly - // differentiate the function header. - IRBuilder subBuilder = *builder; - subBuilder.setInsertBefore(origFunc); - - auto derivative = udfDecor->getBackwardDerivativeFunc(); - if (auto specialize = as<IRSpecialize>(derivative)) - { - auto derivativeGeneric = cast<IRGeneric>(specialize->getBase()); - GenericChildrenMigrationContext migrationContext; - migrationContext.init( - derivativeGeneric, - cast<IRGeneric>(findOuterGeneric(origFunc)), - origFunc); - auto derivativeFunc = findGenericReturnVal(derivativeGeneric); - auto derivativeBlock = cast<IRBlock>(derivativeFunc->getParent()); - for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc; - dInst = dInst->getNextInst()) - { - migrationContext.cloneInst(&subBuilder, dInst); - } - auto udfDictDecor = - derivativeFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); - SLANG_RELEASE_ASSERT(udfDictDecor); - subBuilder.setInsertBefore(origFunc->getFirstDecorationOrChild()); - migrationContext.cloneInst(&subBuilder, udfDictDecor); - eliminateDeadCode(origFunc->getParent()); - } - else + // There's one corner case where our function may not have the differentiable type annotations. + // If the function is not declared differentiable, but has a custom derivative, we need to copy + // over any IRDifferentiableTypeAnnotation insts + if (auto udfDecor = origFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) { - auto udfDictDecor = derivative->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); - if (udfDictDecor) + // We need to migrate the dictionary from the backward derivative func so we can properly + // differentiate the function header. + IRBuilder subBuilder = *builder; + subBuilder.setInsertBefore(origFunc); + + auto derivative = udfDecor->getBackwardDerivativeFunc(); + if (auto specialize = as<IRSpecialize>(derivative)) { - cloneDecoration(udfDictDecor, origFunc); + auto derivativeGeneric = cast<IRGeneric>(specialize->getBase()); + + GenericChildrenMigrationContext migrationContext; + migrationContext.init( + derivativeGeneric, + cast<IRGeneric>(findOuterGeneric(origFunc)), + origFunc); + auto derivativeFunc = findGenericReturnVal(derivativeGeneric); + auto derivativeBlock = cast<IRBlock>(derivativeFunc->getParent()); + for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc; + dInst = dInst->getNextInst()) + { + migrationContext.cloneInst(&subBuilder, dInst); + } + eliminateDeadCode(origFunc->getParent()); } } } @@ -575,8 +563,8 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* else return nullptr; } - auto diffType = differentiateType(builder, originalType); - if (diffType) + + if (tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Any)) return (IRType*)getOrCreateDiffPairType(builder, originalType); return nullptr; } @@ -690,6 +678,15 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod( return InstPair(primal, diffWitness); } } + else if (as<IRTypeKind>(lookupInst->getDataType())) + { + if (auto diffType = differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)) + { + return InstPair(primal, diffType); + } + } auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); @@ -997,8 +994,15 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene if (auto innerFunc = as<IRFunc>(innerVal)) { maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc); - if (!innerFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + // Is our function differentiable? + if (!(innerFunc->findDecoration<IRForwardDifferentiableDecoration>() || + innerFunc->findDecoration<IRBackwardDifferentiableDecoration>() || + innerFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>() || + innerFunc->findDecoration<IRForwardDerivativeDecoration>())) + { return InstPair(origGeneric, nullptr); + } + differentiableTypeConformanceContext.setFunc(innerFunc); } else if (const auto funcType = as<IRFuncType>(innerVal)) @@ -1027,7 +1031,14 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene IRType* diffType = nullptr; if (primalType) { - diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType); + if (as<IRGenericKind>(primalType)) + { + diffType = builder.getGenericKind(); + } + else + { + diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType); + } } diffGeneric->setFullType(diffType); @@ -1110,7 +1121,6 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); - if (pair.primal != pair.differential && !pair.primal->findDecoration<IRAutodiffInstDecoration>() && !as<IRConstant>(pair.primal)) |
