summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-fwd.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-01-10 03:16:24 +0530
committerGitHub <noreply@github.com>2025-01-09 13:46:24 -0800
commit87f00a36a123e36b415eeea82e02a8366cc5b881 (patch)
tree719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-fwd.cpp
parent6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff)
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-fwd.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp143
1 files changed, 119 insertions, 24 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 9f26f9d55..30c14f706 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -160,6 +160,40 @@ InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRIns
return InstPair(primalVal, diffVal);
}
+InstPair ForwardDiffTranscriber::transcribeDifferentiableTypeAnnotation(
+ IRBuilder* builder,
+ IRInst* origInst)
+{
+ auto primalAnnotation =
+ as<IRDifferentiableTypeAnnotation>(maybeCloneForPrimalInst(builder, origInst));
+
+ IRDifferentiableTypeAnnotation* annotation = as<IRDifferentiableTypeAnnotation>(origInst);
+
+ differentiableTypeConformanceContext.addTypeToDictionary(
+ (IRType*)primalAnnotation->getBaseType(),
+ primalAnnotation->getWitness());
+
+ auto diffType = differentiateType(builder, (IRType*)annotation->getBaseType());
+ if (!diffType)
+ return InstPair(primalAnnotation, nullptr);
+
+ auto diffTypeDiffWitness =
+ tryGetDifferentiableWitness(builder, diffType, DiffConformanceKind::Any);
+
+ IRInst* args[] = {diffType, diffTypeDiffWitness};
+
+ auto diffAnnotation = builder->emitIntrinsicInst(
+ builder->getVoidType(),
+ kIROp_DifferentiableTypeAnnotation,
+ 2,
+ args);
+
+ builder->markInstAsPrimal(diffAnnotation);
+ builder->markInstAsPrimal(primalAnnotation);
+
+ return InstPair(primalAnnotation, diffAnnotation);
+}
+
InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
{
if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
@@ -752,9 +786,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
auto pairValType = as<IRDifferentialPairTypeBase>(
pairPtrType ? pairPtrType->getValueType() : pairType);
- auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(
- &argBuilder,
- pairValType);
+ auto diffType = differentiateType(&argBuilder, primalType);
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
@@ -795,7 +827,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (diffArg)
{
auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential(
- (IRType*)diffType,
+ (IRType*)as<IRPtrTypeBase>(diffType)->getValueType(),
newVal);
markDiffTypeInst(
&afterBuilder,
@@ -827,17 +859,72 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
}
}
+
+ {
+ // --WORKAROUND--
+ // This is a temporary workaround for a very specific case..
+ //
+ // If all the following are true:
+ // 1. the parameter type expects a differential pair,
+ // 2. the argument is derived from a no_diff type, and
+ // 3. the argument type is a run-time type (i.e. extract_existential_type),
+ // then we need to generate a differential 0, but the IR has no
+ // information on the diff witness.
+ //
+ // We will bypass the conformance system & brute-force the lookup for the interface
+ // keys, but the proper fix is to lower this key mapping during `no_diff` lowering.
+ //
+
+ // Condition 1
+ if (differentiableTypeConformanceContext.isDifferentiableType((originalParamType)))
+ {
+ // Condition 3
+ if (auto extractExistentialType = as<IRExtractExistentialType>(primalType))
+ {
+ // Condition 2
+ if (isNoDiffType(extractExistentialType->getOperand(0)->getDataType()))
+ {
+ // Force-differentiate the type (this will perform a search for the witness
+ // without going through the diff-type annotation list)
+ //
+ IRInst* witnessTable = nullptr;
+ auto diffType = differentiateExtractExistentialType(
+ &argBuilder,
+ extractExistentialType,
+ witnessTable);
+
+ auto pairType =
+ getOrCreateDiffPairType(&argBuilder, primalType, witnessTable);
+ auto zeroMethod = argBuilder.emitLookupInterfaceMethodInst(
+ differentiableTypeConformanceContext.sharedContext->zeroMethodType,
+ witnessTable,
+ differentiableTypeConformanceContext.sharedContext
+ ->zeroMethodStructKey);
+ auto diffZero = argBuilder.emitCallInst(diffType, zeroMethod, 0, nullptr);
+ auto diffPair =
+ argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffZero);
+
+ args.add(diffPair);
+ continue;
+ }
+ }
+ }
+ }
+
// Argument is not differentiable.
// Add original/primal argument.
args.add(primalArg);
}
IRType* diffReturnType = nullptr;
- diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType());
+ auto primalReturnType =
+ (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());
+
+ diffReturnType = tryGetDiffPairType(&argBuilder, primalReturnType);
if (!diffReturnType)
{
- diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());
+ diffReturnType = primalReturnType;
}
auto callInst = argBuilder.emitCallInst(diffReturnType, diffCallee, args);
@@ -1035,6 +1122,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(
IRInst* diffBase = nullptr;
if (instMapD.tryGetValue(origSpecialize->getBase(), diffBase))
{
+ auto diffType = differentiateType(builder, origSpecialize->getFullType());
if (diffBase)
{
List<IRInst*> args;
@@ -1042,11 +1130,8 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(
{
args.add(primalSpecialize->getArg(i));
}
- auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(),
- diffBase,
- args.getCount(),
- args.getBuffer());
+ auto diffSpecialize =
+ builder->emitSpecializeInst(diffType, diffBase, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
else
@@ -1572,7 +1657,24 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc());
}
- auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ IRFunc* diffFunc = nullptr;
+
+ // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the
+ // insert location unchanged). If we're transcribing it as a declaration, we should
+ // insert into the module.
+ //
+ auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc));
+ if (!origOuterGen || findInnerMostGenericReturnVal(origOuterGen) != origFunc)
+ {
+ // Dealing with a declaration.. insert into module scope.
+ IRBuilder subBuilder = *inBuilder;
+ subBuilder.setInsertInto(inBuilder->getModule());
+ diffFunc = transcribeFuncHeaderImpl(&subBuilder, origFunc);
+ }
+ else
+ {
+ diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ }
if (auto outerGen = findOuterGeneric(diffFunc))
{
@@ -1605,7 +1707,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
IRBuilder builder = *inBuilder;
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
-
differentiableTypeConformanceContext.setFunc(origFunc);
auto diffFunc = builder.createFunc();
@@ -1632,12 +1733,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
// Transfer checkpoint hint decorations
copyCheckpointHints(&builder, origFunc, diffFunc);
-
- // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
- if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule());
- }
return diffFunc;
}
@@ -2012,6 +2107,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_Reinterpret:
return transcribeReinterpret(builder, origInst);
+ case kIROp_DifferentiableTypeAnnotation:
+ return transcribeDifferentiableTypeAnnotation(builder, origInst);
+
// Differentiable insts that should have been lowered in a previous pass.
case kIROp_SwizzledStore:
{
@@ -2138,13 +2236,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(
if (as<IRDifferentialPairType>(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType))
{
+ auto diffType = differentiateType(builder, (IRType*)origParam->getFullType());
return InstPair(
builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(
- builder,
- as<IRDifferentialPairTypeBase>(diffPairType)),
- diffPairParam));
+ builder->emitDifferentialPairGetDifferential(diffType, diffPairParam));
}
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
{