summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff-transcriber-base.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-01-10 03:16:24 +0530
committerGitHub <noreply@github.com>2025-01-09 13:46:24 -0800
commit87f00a36a123e36b415eeea82e02a8366cc5b881 (patch)
tree719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-transcriber-base.cpp
parent6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff)
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp112
1 files changed, 61 insertions, 51 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 1b3825a7d..38a7a18bb 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -256,7 +256,7 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
return nullptr;
// Special-case for differentiable existential types.
- if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType))
+ if (as<IRInterfaceType>(origType))
{
if (differentiableTypeConformanceContext.lookUpConformanceForType(
origType,
@@ -269,6 +269,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
else
return nullptr;
}
+ else if (as<IRAssociatedType>(origType))
+ {
+ SLANG_UNEXPECTED("unexpected associated type during auto-diff");
+ }
auto primalType = lookupPrimalInst(builder, origType, origType);
if (primalType->getOp() == kIROp_Param && primalType->getParent() &&
@@ -324,9 +328,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
auto primalPairType = as<IRDifferentialPairTypeBase>(primalType);
return getOrCreateDiffPairType(
builder,
- differentiableTypeConformanceContext.getDiffTypeFromPairType(
- builder,
- primalPairType),
+ differentiateType(builder, primalPairType->getValueType()),
differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(
builder,
primalPairType));
@@ -336,9 +338,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
{
auto primalPairType = as<IRDifferentialPairUserCodeType>(primalType);
return builder->getDifferentialPairUserCodeType(
- (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(
- builder,
- primalPairType),
+ differentiateType(builder, primalPairType->getValueType()),
differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(
builder,
primalPairType));
@@ -406,6 +406,7 @@ bool AutoDiffTranscriberBase::isExistentialType(IRType* type)
case kIROp_ExtractExistentialType:
case kIROp_InterfaceType:
case kIROp_AssociatedType:
+ case kIROp_LookupWitness:
return true;
default:
return false;
@@ -460,47 +461,34 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative
IRBuilder* builder,
IRInst* origFunc)
{
- auto decor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- if (decor)
- return;
- // A differentiable func must have `IRDifferentiableTypeDictionaryDecoration`, except it has a
- // `IRUserDefinedBackwardDerivativeDecoration`.
- auto udfDecor = origFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>();
- SLANG_RELEASE_ASSERT(udfDecor);
- // We need to migrate the dictionary from the backward derivative func so we can properly
- // differentiate the function header.
- IRBuilder subBuilder = *builder;
- subBuilder.setInsertBefore(origFunc);
-
- auto derivative = udfDecor->getBackwardDerivativeFunc();
- if (auto specialize = as<IRSpecialize>(derivative))
- {
- auto derivativeGeneric = cast<IRGeneric>(specialize->getBase());
- GenericChildrenMigrationContext migrationContext;
- migrationContext.init(
- derivativeGeneric,
- cast<IRGeneric>(findOuterGeneric(origFunc)),
- origFunc);
- auto derivativeFunc = findGenericReturnVal(derivativeGeneric);
- auto derivativeBlock = cast<IRBlock>(derivativeFunc->getParent());
- for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc;
- dInst = dInst->getNextInst())
- {
- migrationContext.cloneInst(&subBuilder, dInst);
- }
- auto udfDictDecor =
- derivativeFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- SLANG_RELEASE_ASSERT(udfDictDecor);
- subBuilder.setInsertBefore(origFunc->getFirstDecorationOrChild());
- migrationContext.cloneInst(&subBuilder, udfDictDecor);
- eliminateDeadCode(origFunc->getParent());
- }
- else
+ // There's one corner case where our function may not have the differentiable type annotations.
+ // If the function is not declared differentiable, but has a custom derivative, we need to copy
+ // over any IRDifferentiableTypeAnnotation insts
+ if (auto udfDecor = origFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
{
- auto udfDictDecor = derivative->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- if (udfDictDecor)
+ // We need to migrate the dictionary from the backward derivative func so we can properly
+ // differentiate the function header.
+ IRBuilder subBuilder = *builder;
+ subBuilder.setInsertBefore(origFunc);
+
+ auto derivative = udfDecor->getBackwardDerivativeFunc();
+ if (auto specialize = as<IRSpecialize>(derivative))
{
- cloneDecoration(udfDictDecor, origFunc);
+ auto derivativeGeneric = cast<IRGeneric>(specialize->getBase());
+
+ GenericChildrenMigrationContext migrationContext;
+ migrationContext.init(
+ derivativeGeneric,
+ cast<IRGeneric>(findOuterGeneric(origFunc)),
+ origFunc);
+ auto derivativeFunc = findGenericReturnVal(derivativeGeneric);
+ auto derivativeBlock = cast<IRBlock>(derivativeFunc->getParent());
+ for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc;
+ dInst = dInst->getNextInst())
+ {
+ migrationContext.cloneInst(&subBuilder, dInst);
+ }
+ eliminateDeadCode(origFunc->getParent());
}
}
}
@@ -575,8 +563,8 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
else
return nullptr;
}
- auto diffType = differentiateType(builder, originalType);
- if (diffType)
+
+ if (tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Any))
return (IRType*)getOrCreateDiffPairType(builder, originalType);
return nullptr;
}
@@ -690,6 +678,15 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(
return InstPair(primal, diffWitness);
}
}
+ else if (as<IRTypeKind>(lookupInst->getDataType()))
+ {
+ if (auto diffType = differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType))
+ {
+ return InstPair(primal, diffType);
+ }
+ }
auto decor = lookupInst->getRequirementKey()->findDecorationImpl(
getInterfaceRequirementDerivativeDecorationOp());
@@ -997,8 +994,15 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
if (auto innerFunc = as<IRFunc>(innerVal))
{
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc);
- if (!innerFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ // Is our function differentiable?
+ if (!(innerFunc->findDecoration<IRForwardDifferentiableDecoration>() ||
+ innerFunc->findDecoration<IRBackwardDifferentiableDecoration>() ||
+ innerFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>() ||
+ innerFunc->findDecoration<IRForwardDerivativeDecoration>()))
+ {
return InstPair(origGeneric, nullptr);
+ }
+
differentiableTypeConformanceContext.setFunc(innerFunc);
}
else if (const auto funcType = as<IRFuncType>(innerVal))
@@ -1027,7 +1031,14 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
IRType* diffType = nullptr;
if (primalType)
{
- diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType);
+ if (as<IRGenericKind>(primalType))
+ {
+ diffType = builder.getGenericKind();
+ }
+ else
+ {
+ diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType);
+ }
}
diffGeneric->setFullType(diffType);
@@ -1110,7 +1121,6 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
mapPrimalInst(origInst, pair.primal);
mapDifferentialInst(origInst, pair.differential);
-
if (pair.primal != pair.differential &&
!pair.primal->findDecoration<IRAutodiffInstDecoration>() &&
!as<IRConstant>(pair.primal))