From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- .../slang/slang-ir-autodiff-transcriber-base.cpp | 112 +++++++++++---------- 1 file changed, 61 insertions(+), 51 deletions(-) (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp') diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 1b3825a7d..38a7a18bb 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -256,7 +256,7 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o return nullptr; // Special-case for differentiable existential types. - if (as(origType) || as(origType)) + if (as(origType)) { if (differentiableTypeConformanceContext.lookUpConformanceForType( origType, @@ -269,6 +269,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o else return nullptr; } + else if (as(origType)) + { + SLANG_UNEXPECTED("unexpected associated type during auto-diff"); + } auto primalType = lookupPrimalInst(builder, origType, origType); if (primalType->getOp() == kIROp_Param && primalType->getParent() && @@ -324,9 +328,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy auto primalPairType = as(primalType); return getOrCreateDiffPairType( builder, - differentiableTypeConformanceContext.getDiffTypeFromPairType( - builder, - primalPairType), + differentiateType(builder, primalPairType->getValueType()), differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType( builder, primalPairType)); @@ -336,9 +338,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy { auto primalPairType = as(primalType); return builder->getDifferentialPairUserCodeType( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( - builder, - primalPairType), + differentiateType(builder, primalPairType->getValueType()), differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType( builder, primalPairType)); @@ -406,6 +406,7 @@ bool AutoDiffTranscriberBase::isExistentialType(IRType* type) case kIROp_ExtractExistentialType: case kIROp_InterfaceType: case kIROp_AssociatedType: + case kIROp_LookupWitness: return true; default: return false; @@ -460,47 +461,34 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative IRBuilder* builder, IRInst* origFunc) { - auto decor = origFunc->findDecoration(); - if (decor) - return; - // A differentiable func must have `IRDifferentiableTypeDictionaryDecoration`, except it has a - // `IRUserDefinedBackwardDerivativeDecoration`. - auto udfDecor = origFunc->findDecoration(); - SLANG_RELEASE_ASSERT(udfDecor); - // We need to migrate the dictionary from the backward derivative func so we can properly - // differentiate the function header. - IRBuilder subBuilder = *builder; - subBuilder.setInsertBefore(origFunc); - - auto derivative = udfDecor->getBackwardDerivativeFunc(); - if (auto specialize = as(derivative)) - { - auto derivativeGeneric = cast(specialize->getBase()); - GenericChildrenMigrationContext migrationContext; - migrationContext.init( - derivativeGeneric, - cast(findOuterGeneric(origFunc)), - origFunc); - auto derivativeFunc = findGenericReturnVal(derivativeGeneric); - auto derivativeBlock = cast(derivativeFunc->getParent()); - for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc; - dInst = dInst->getNextInst()) - { - migrationContext.cloneInst(&subBuilder, dInst); - } - auto udfDictDecor = - derivativeFunc->findDecoration(); - SLANG_RELEASE_ASSERT(udfDictDecor); - subBuilder.setInsertBefore(origFunc->getFirstDecorationOrChild()); - migrationContext.cloneInst(&subBuilder, udfDictDecor); - eliminateDeadCode(origFunc->getParent()); - } - else + // There's one corner case where our function may not have the differentiable type annotations. + // If the function is not declared differentiable, but has a custom derivative, we need to copy + // over any IRDifferentiableTypeAnnotation insts + if (auto udfDecor = origFunc->findDecoration()) { - auto udfDictDecor = derivative->findDecoration(); - if (udfDictDecor) + // We need to migrate the dictionary from the backward derivative func so we can properly + // differentiate the function header. + IRBuilder subBuilder = *builder; + subBuilder.setInsertBefore(origFunc); + + auto derivative = udfDecor->getBackwardDerivativeFunc(); + if (auto specialize = as(derivative)) { - cloneDecoration(udfDictDecor, origFunc); + auto derivativeGeneric = cast(specialize->getBase()); + + GenericChildrenMigrationContext migrationContext; + migrationContext.init( + derivativeGeneric, + cast(findOuterGeneric(origFunc)), + origFunc); + auto derivativeFunc = findGenericReturnVal(derivativeGeneric); + auto derivativeBlock = cast(derivativeFunc->getParent()); + for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc; + dInst = dInst->getNextInst()) + { + migrationContext.cloneInst(&subBuilder, dInst); + } + eliminateDeadCode(origFunc->getParent()); } } } @@ -575,8 +563,8 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* else return nullptr; } - auto diffType = differentiateType(builder, originalType); - if (diffType) + + if (tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Any)) return (IRType*)getOrCreateDiffPairType(builder, originalType); return nullptr; } @@ -690,6 +678,15 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod( return InstPair(primal, diffWitness); } } + else if (as(lookupInst->getDataType())) + { + if (auto diffType = differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)) + { + return InstPair(primal, diffType); + } + } auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); @@ -997,8 +994,15 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene if (auto innerFunc = as(innerVal)) { maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc); - if (!innerFunc->findDecoration()) + // Is our function differentiable? + if (!(innerFunc->findDecoration() || + innerFunc->findDecoration() || + innerFunc->findDecoration() || + innerFunc->findDecoration())) + { return InstPair(origGeneric, nullptr); + } + differentiableTypeConformanceContext.setFunc(innerFunc); } else if (const auto funcType = as(innerVal)) @@ -1027,7 +1031,14 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene IRType* diffType = nullptr; if (primalType) { - diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType); + if (as(primalType)) + { + diffType = builder.getGenericKind(); + } + else + { + diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType); + } } diffGeneric->setFullType(diffType); @@ -1110,7 +1121,6 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); - if (pair.primal != pair.differential && !pair.primal->findDecoration() && !as(pair.primal)) -- cgit v1.2.3