summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-13 10:57:28 -0700
committerGitHub <noreply@github.com>2023-03-13 10:57:28 -0700
commita911ca6e06ce41e403b80fe6054162393491c8ac (patch)
tree6c8d56a3060b1887e7fd3126fe54a1241160eddd /source
parent3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff)
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang12
-rw-r--r--source/slang/slang-check-decl.cpp42
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp160
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h4
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp123
-rw-r--r--source/slang/slang-ir-autodiff-pairs.h6
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp75
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp21
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h118
-rw-r--r--source/slang/slang-ir-autodiff.cpp51
-rw-r--r--source/slang/slang-ir-autodiff.h19
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp15
-rw-r--r--source/slang/slang-ir-inst-defs.h11
-rw-r--r--source/slang/slang-ir-insts.h44
-rw-r--r--source/slang/slang-ir-util.cpp1
-rw-r--r--source/slang/slang-ir.cpp62
-rw-r--r--source/slang/slang-ir.h12
-rw-r--r--source/slang/slang-lower-to-ir.cpp16
-rw-r--r--source/slang/slang-syntax.cpp19
19 files changed, 629 insertions, 182 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 4301eda94..ada052cd8 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -39,30 +39,30 @@ attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
-__intrinsic_type($(kIROp_DifferentialPairType))
+__intrinsic_type($(kIROp_DifferentialPairUserCodeType))
struct DifferentialPair : IDifferentiable
{
typedef DifferentialPair<T.Differential> Differential;
typedef T.Differential DifferentialElementType;
- __intrinsic_op($(kIROp_MakeDifferentialPair))
+ __intrinsic_op($(kIROp_MakeDifferentialPairUserCode))
__init(T _primal, T.Differential _differential);
property p : T
{
- __intrinsic_op($(kIROp_DifferentialPairGetPrimal))
+ __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode))
get;
}
property v : T
{
- __intrinsic_op($(kIROp_DifferentialPairGetPrimal))
+ __intrinsic_op($(kIROp_DifferentialPairGetPrimalUserCode))
get;
}
property d : T.Differential
{
- __intrinsic_op($(kIROp_DifferentialPairGetDifferential))
+ __intrinsic_op($(kIROp_DifferentialPairGetDifferentialUserCode))
get;
}
@@ -105,7 +105,7 @@ struct DifferentialPair : IDifferentiable
};
__generic<T: IDifferentiable>
-__intrinsic_op($(kIROp_MakeDifferentialPair))
+__intrinsic_op($(kIROp_MakeDifferentialPairUserCode))
DifferentialPair<T> diffPair(T primal, T.Differential diff);
__generic<T: IDifferentiable>
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5cd7fba45..ea8bec2bb 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1506,19 +1506,37 @@ namespace Slang
aggTypeDecl->members.add(diffField);
aggTypeDecl->invalidateMemberDictionary();
+ // Inject a `DerivativeMember` modifier on the differential field to point to itself.
+ {
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = diffMemberType;
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = differentialType;
+ auto baseTypeType = m_astBuilder->create<TypeType>();
+ baseTypeType->type = differentialType;
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+ fieldLookupExpr->declRef = makeDeclRef(diffField);
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(diffField, derivativeMemberModifier);
+ }
+
// Inject a `DerivativeMember` modifier on the original decl.
- auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
- auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
- fieldLookupExpr->type.type = diffMemberType;
- auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
- baseTypeExpr->base.type = differentialType;
- auto baseTypeType = m_astBuilder->create<TypeType>();
- baseTypeType->type = differentialType;
- baseTypeExpr->type.type = baseTypeType;
- fieldLookupExpr->baseExpression = baseTypeExpr;
- fieldLookupExpr->declRef = makeDeclRef(diffField);
- derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
- addModifier(member, derivativeMemberModifier);
+ {
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = diffMemberType;
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = differentialType;
+ auto baseTypeType = m_astBuilder->create<TypeType>();
+ baseTypeType->type = differentialType;
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+ fieldLookupExpr->declRef = makeDeclRef(diffField);
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ }
};
// Make the Differential type itself conform to `IDifferential` interface.
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 2090cd4dc..7057a5835 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -8,6 +8,9 @@
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-single-return.h"
+#include "slang-ir-addr-inst-elimination.h"
+#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-validate.h"
namespace Slang
{
@@ -234,7 +237,7 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig
auto primalElement = builder->emitDifferentialPairGetPrimal(load);
auto diffElement = builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
+ (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load);
return InstPair(primalElement, diffElement);
}
}
@@ -938,7 +941,10 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
if (decor)
diffAccessChain.add(decor->getDerivativeMemberStructKey());
else
- return InstPair(primalUpdateField, nullptr);
+ {
+ auto diffBase = findOrTranscribeDiffInst(builder, origBase);
+ return InstPair(primalUpdateField, diffBase);
+ }
}
else
{
@@ -947,24 +953,26 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
}
if (auto diffType = differentiateType(builder, originalInst->getDataType()))
{
- if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ auto diffBase = findOrTranscribeDiffInst(builder, origBase);
+ if (!diffBase)
{
- if (auto diffVal = findOrTranscribeDiffInst(builder, origVal))
- {
- auto primalElementType = primalVal->getDataType();
+ diffBase = getDifferentialZeroOfType(builder, origBase->getDataType());
+ }
+ if (auto diffVal = findOrTranscribeDiffInst(builder, origVal))
+ {
+ auto primalElementType = primalVal->getDataType();
- diffUpdateElement = builder->emitUpdateElement(
- diffBase, diffAccessChain, diffVal);
- builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
- }
- else
- {
- auto primalElementType = primalVal->getDataType();
- auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType);
- diffUpdateElement = builder->emitUpdateElement(
- diffBase, diffAccessChain, zeroElementDiff);
- builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
- }
+ diffUpdateElement = builder->emitUpdateElement(
+ diffBase, diffAccessChain, diffVal);
+ builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
+ }
+ else
+ {
+ auto primalElementType = primalVal->getDataType();
+ auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType);
+ diffUpdateElement = builder->emitUpdateElement(
+ diffBase, diffAccessChain, zeroElementDiff);
+ builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
}
}
return InstPair(primalUpdateField, diffUpdateElement);
@@ -1121,7 +1129,7 @@ InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse*
return InstPair(diffIfElse, diffIfElse);
}
-InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
+InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPairUserCode* origInst)
{
auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue());
SLANG_ASSERT(primalVal);
@@ -1140,9 +1148,9 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build
auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType());
auto diffPairType = findOrTranscribeDiffInst(builder, origInst->getFullType());
- auto primalPair = builder->emitMakeDifferentialPair(
+ auto primalPair = builder->emitMakeDifferentialPairUserCode(
(IRType*)primalPairType, primalVal, diffPrimalVal);
- auto diffPair = builder->emitMakeDifferentialPair(
+ auto diffPair = builder->emitMakeDifferentialPairUserCode(
(IRType*)diffPairType,
primalDiffVal,
diffDiffVal);
@@ -1152,8 +1160,8 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build
InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
{
SLANG_ASSERT(
- origInst->getOp() == kIROp_DifferentialPairGetDifferential ||
- origInst->getOp() == kIROp_DifferentialPairGetPrimal);
+ origInst->getOp() == kIROp_DifferentialPairGetDifferentialUserCode ||
+ origInst->getOp() == kIROp_DifferentialPairGetPrimalUserCode);
auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0));
SLANG_ASSERT(primalVal);
@@ -1165,10 +1173,10 @@ InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder*
auto primalResult = builder->emitIntrinsicInst((IRType*)primalType, origInst->getOp(), 1, &primalVal);
- auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType());
+ auto diffValPairType = as<IRDifferentialPairUserCodeType>(diffVal->getDataType());
IRInst* diffResultType = nullptr;
- if (origInst->getOp() == kIROp_DifferentialPairGetDifferential)
- diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType);
+ if (origInst->getOp() == kIROp_DifferentialPairGetDifferentialUserCode)
+ diffResultType = differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffValPairType);
else
diffResultType = diffValPairType->getValueType();
auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal);
@@ -1318,6 +1326,8 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
// Mark the generated derivative function itself as differentiable.
builder.addForwardDifferentiableDecoration(diffFunc);
+ if (isBackwardDifferentiableFunc(origFunc))
+ builder.addBackwardDifferentiableDecoration(diffFunc);
// Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
@@ -1349,23 +1359,105 @@ void ForwardDiffTranscriber::checkAutodiffInstDecorations(IRFunc* fwdFunc)
}
}
+void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
+{
+ IRBuilder builder(module);
+ auto firstBlock = func->getFirstBlock();
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+
+ OrderedDictionary<IRParam*, IRVar*> mapParamToTempVar;
+ List<IRParam*> params;
+ for (auto param : firstBlock->getParams())
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
+ {
+ params.add(param);
+ }
+ }
+
+ for (auto param : params)
+ {
+ auto ptrType = as<IRPtrTypeBase>(param->getDataType());
+ auto tempVar = builder.emitVar(ptrType->getValueType());
+ param->replaceUsesWith(tempVar);
+ mapParamToTempVar[param] = tempVar;
+ if (ptrType->getOp() != kIROp_OutType)
+ {
+ builder.emitStore(tempVar, builder.emitLoad(param));
+ }
+ else
+ {
+ builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType()));
+ }
+ }
+
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ builder.setInsertBefore(inst);
+ for (auto& kv : mapParamToTempVar)
+ {
+ builder.emitStore(kv.Key, builder.emitLoad(kv.Value));
+ }
+ }
+ }
+ }
+}
+
+struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
+{
+ DifferentiableTypeConformanceContext* diffTypeContext;
+
+ virtual bool shouldConvertAddrInst(IRInst*) override
+ {
+ return true;
+ }
+};
+
+SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
+{
+ insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
+
+ AutoDiffAddressConversionPolicy cvtPolicty;
+ cvtPolicty.diffTypeContext = &differentiableTypeConformanceContext;
+ auto result = eliminateAddressInsts(&cvtPolicty, func, sink);
+
+ if (SLANG_SUCCEEDED(result))
+ {
+ disableIRValidationAtInsert();
+ simplifyFunc(func);
+ enableIRValidationAtInsert();
+ }
+ return result;
+}
+
// Transcribe a function definition.
InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
{
IRBuilder builder = *inBuilder;
+ builder.setInsertBefore(primalFunc);
+
+ // Create a clone for original func and run additional transformations on the clone.
+ IRCloneEnv env;
+ auto primalFuncClone = as<IRFunc>(cloneInst(&env, &builder, primalFunc));
+ prepareFuncForForwardDiff(primalFuncClone);
+
builder.setInsertInto(diffFunc);
- differentiableTypeConformanceContext.setFunc(primalFunc);
+ differentiableTypeConformanceContext.setFunc(primalFuncClone);
mapInOutParamToWriteBackValue.Clear();
// Transcribe children from origFunc into diffFunc
- for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock())
this->transcribe(&builder, block);
// Some of the transcribed blocks can appear 'out-of-order'. Although this
// shouldn't be an issue, for consistency, we put them back in order.
- for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock())
as<IRBlock>(lookupDiffInst(block))->insertAtEnd(diffFunc);
for (auto block : diffFunc->getBlocks())
@@ -1507,11 +1599,11 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_Switch:
return transcribeSwitch(builder, as<IRSwitch>(origInst));
- case kIROp_MakeDifferentialPair:
- return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst));
+ case kIROp_MakeDifferentialPairUserCode:
+ return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPairUserCode>(origInst));
- case kIROp_DifferentialPairGetPrimal:
- case kIROp_DifferentialPairGetDifferential:
+ case kIROp_DifferentialPairGetPrimalUserCode:
+ case kIROp_DifferentialPairGetDifferentialUserCode:
return transcribeDifferentialPairGetElement(builder, origInst);
case kIROp_ExtractExistentialValue:
@@ -1612,7 +1704,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam
return InstPair(
builder->emitDifferentialPairGetPrimal(diffPairParam),
builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, pairType),
diffPairParam));
}
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 5b79a6c54..6032c2319 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -77,7 +77,7 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch);
- InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst);
+ InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPairUserCode* origInst);
InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst);
@@ -100,6 +100,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
void checkAutodiffInstDecorations(IRFunc* fwdFunc);
+ SlangResult prepareFuncForForwardDiff(IRFunc* func);
+
// Create an empty func to represent the transcribed func of `origFunc`.
virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
index 9d761764c..7b16c0213 100644
--- a/source/slang/slang-ir-autodiff-pairs.cpp
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -24,10 +24,10 @@ struct DiffPairLoweringPass : InstPassBase
IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
{
- if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
+ if (auto makePairInst = as<IRMakeDifferentialPairBase>(inst))
{
bool isTrivial = false;
- auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType());
+ auto pairType = as<IRDifferentialPairTypeBase>(makePairInst->getDataType());
if (auto loweredPairType = lowerPairType(builder, pairType))
{
builder->setInsertBefore(makePairInst);
@@ -52,7 +52,7 @@ struct DiffPairLoweringPass : InstPassBase
IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst)
{
- if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
+ if (auto getDiffInst = as<IRDifferentialPairGetDifferentialBase>(inst))
{
auto pairType = getDiffInst->getBase()->getDataType();
if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
@@ -70,7 +70,7 @@ struct DiffPairLoweringPass : InstPassBase
return diffFieldExtract;
}
}
- else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
+ else if (auto getPrimalInst = as<IRDifferentialPairGetPrimalBase>(inst))
{
auto pairType = getPrimalInst->getBase()->getDataType();
if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
@@ -106,10 +106,12 @@ struct DiffPairLoweringPass : InstPassBase
{
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
+ case kIROp_DifferentialPairGetDifferentialUserCode:
+ case kIROp_DifferentialPairGetPrimalUserCode:
lowerPairAccess(builder, inst);
break;
- case kIROp_MakeDifferentialPair:
+ case kIROp_MakeDifferentialPairUserCode:
lowerMakePair(builder, inst);
break;
@@ -119,12 +121,15 @@ struct DiffPairLoweringPass : InstPassBase
});
OrderedDictionary<IRInst*, IRInst*> pendingReplacements;
- processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
+ processAllInsts([&](IRInst* inst)
{
- if (auto loweredType = lowerPairType(builder, inst))
+ if (auto pairType = as<IRDifferentialPairTypeBase>(inst))
{
- pendingReplacements.Add(inst, loweredType);
- modified = true;
+ if (auto loweredType = lowerPairType(builder, pairType))
+ {
+ pendingReplacements.Add(pairType, loweredType);
+ modified = true;
+ }
}
});
for (auto replacement : pendingReplacements)
@@ -158,4 +163,104 @@ bool processPairTypes(AutoDiffSharedContext* context)
return pairLoweringPass.processModule();
}
+struct DifferentialPairUserCodeTranscribePass : public InstPassBase
+{
+ DifferentialPairUserCodeTranscribePass(IRModule* module)
+ :InstPassBase(module)
+ {}
+
+ IRInst* rewritePairType(IRBuilder* builder, IRType* pairType)
+ {
+ builder->setInsertBefore(pairType);
+ auto originalPairType = as<IRDifferentialPairType>(pairType);
+ return builder->getDifferentialPairUserCodeType(originalPairType->getValueType(), originalPairType->getWitness());
+ }
+
+ IRInst* rewriteMakePair(IRBuilder* builder, IRMakeDifferentialPair* inst)
+ {
+ auto pairType = as<IRDifferentialPairType>(inst->getFullType());
+ builder->setInsertBefore(inst);
+ auto newInst = builder->emitMakeDifferentialPairUserCode(
+ (IRType*)pairType, inst->getPrimalValue(), inst->getDifferentialValue());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ return newInst;
+ }
+
+ IRInst* rewritePairAccess(IRBuilder* builder, IRInst* inst)
+ {
+ if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
+ {
+ builder->setInsertBefore(inst);
+
+ auto newInst = builder->emitDifferentialPairGetDifferentialUserCode(
+ (IRType*)inst->getFullType(), getDiffInst->getBase());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ }
+ else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
+ {
+ builder->setInsertBefore(inst);
+ auto newInst = builder->emitDifferentialPairGetPrimalUserCode(getPrimalInst->getBase());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ }
+ return inst;
+ }
+
+ bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren)
+ {
+ SLANG_UNUSED(instWithChildren);
+
+ bool modified = false;
+
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_DifferentialPairGetDifferential:
+ case kIROp_DifferentialPairGetPrimal:
+ rewritePairAccess(builder, inst);
+ break;
+
+ case kIROp_MakeDifferentialPair:
+ rewriteMakePair(builder, as<IRMakeDifferentialPair>(inst));
+ break;
+
+ default:
+ break;
+ }
+ });
+
+ OrderedDictionary<IRInst*, IRInst*> pendingReplacements;
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
+ {
+ if (auto loweredType = rewritePairType(builder, inst))
+ {
+ pendingReplacements.Add(inst, loweredType);
+ modified = true;
+ }
+ });
+ for (auto replacement : pendingReplacements)
+ {
+ replacement.Key->replaceUsesWith(replacement.Value);
+ replacement.Key->removeAndDeallocate();
+ }
+
+ return modified;
+ }
+
+ bool processModule()
+ {
+ IRBuilder builder(module);
+ return processInstWithChildren(&builder, module->getModuleInst());
+ }
+};
+
+void rewriteDifferentialPairToUserCode(IRModule* module)
+{
+ DifferentialPairUserCodeTranscribePass pairRewritePass(module);
+ pairRewritePass.processModule();
+}
+
}
diff --git a/source/slang/slang-ir-autodiff-pairs.h b/source/slang/slang-ir-autodiff-pairs.h
index 44321ae9b..8f9e77145 100644
--- a/source/slang/slang-ir-autodiff-pairs.h
+++ b/source/slang/slang-ir-autodiff-pairs.h
@@ -18,4 +18,8 @@ namespace Slang
bool processPairTypes(AutoDiffSharedContext* context);
-} \ No newline at end of file
+// Rewrites all uses of `DifferentialPairType` into `DifferentialPairUserCodeType` in the original func,
+// so they are not to be confused with real mixed differential code generated by forward diff pass.
+void rewriteDifferentialPairToUserCode(IRModule* module);
+
+}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index d7cce7c53..328af4867 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -2,14 +2,12 @@
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
-#include "slang-ir-eliminate-phis.h"
#include "slang-ir-autodiff-cfg-norm.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-single-return.h"
-#include "slang-ir-addr-inst-elimination.h"
#include "slang-ir-eliminate-multilevel-break.h"
#include "slang-ir-init-local-var.h"
#include "slang-ir-redundancy-removal.h"
@@ -516,65 +514,6 @@ namespace Slang
builder.emitBranch(firstBlock);
}
- void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
- {
- IRBuilder builder(module);
- auto firstBlock = func->getFirstBlock();
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
-
- OrderedDictionary<IRParam*, IRVar*> mapParamToTempVar;
- List<IRParam*> params;
- for (auto param : firstBlock->getParams())
- {
- if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
- {
- params.add(param);
- }
- }
-
- for (auto param : params)
- {
- auto ptrType = as<IRPtrTypeBase>(param->getDataType());
- auto tempVar = builder.emitVar(ptrType->getValueType());
- param->replaceUsesWith(tempVar);
- mapParamToTempVar[param] = tempVar;
- if (ptrType->getOp() != kIROp_OutType)
- {
- builder.emitStore(tempVar, builder.emitLoad(param));
- }
- else
- {
- builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType()));
- }
- }
-
- for (auto block : func->getBlocks())
- {
- for (auto inst : block->getChildren())
- {
- if (inst->getOp() == kIROp_Return)
- {
- builder.setInsertBefore(inst);
- for (auto& kv : mapParamToTempVar)
- {
- builder.emitStore(kv.Key, builder.emitLoad(kv.Value));
- }
- }
- }
- }
- }
-
-
- struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
- {
- DifferentiableTypeConformanceContext* diffTypeContext;
-
- virtual bool shouldConvertAddrInst(IRInst*) override
- {
- return true;
- }
- };
-
SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func)
{
DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
@@ -592,19 +531,7 @@ namespace Slang
IRCFGNormalizationPass cfgPass = {this->getSink()};
normalizeCFG(autoDiffSharedContext->moduleInst->getModule(), func);
- insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
-
- AutoDiffAddressConversionPolicy cvtPolicty;
- cvtPolicty.diffTypeContext = &diffTypeContext;
- auto result = eliminateAddressInsts(&cvtPolicty, func, sink);
-
- if (SLANG_SUCCEEDED(result))
- {
- disableIRValidationAtInsert();
- simplifyFunc(func);
- enableIRValidationAtInsert();
- }
- return result;
+ return SLANG_OK;
}
// Create a copy of originalFunc's forward derivative in the same generic context (if any) of
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index ed122c862..091e7f1ab 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -304,8 +304,16 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
auto primalPairType = as<IRDifferentialPairType>(primalType);
return getOrCreateDiffPairType(
builder,
- pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
- pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType),
+ differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_DifferentialPairUserCodeType:
+ {
+ auto primalPairType = as<IRDifferentialPairUserCodeType>(primalType);
+ return builder->getDifferentialPairUserCodeType(
+ (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType),
+ differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(builder, primalPairType));
}
case kIROp_FuncType:
@@ -634,6 +642,15 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
builder->markInstAsDifferential(makeDiffPair, as<IRDifferentialPairType>(diffType)->getValueType());
return makeDiffPair;
}
+ case kIROp_DifferentialPairUserCodeType:
+ {
+ auto makeDiffPair = builder->emitMakeDifferentialPairUserCode(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairUserCodeType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairUserCodeType>(diffType)->getValueType()));
+ builder->markInstAsDifferential(makeDiffPair, as<IRDifferentialPairUserCodeType>(diffType)->getValueType());
+ return makeDiffPair;
+ }
}
if (auto arrayType = as<IRArrayType>(primalType))
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 23f57032d..1cd6a0e33 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -24,7 +24,7 @@ struct DiffTransposePass
GetElement,
GetDifferential,
FieldExtract,
-
+ DifferentialPairGetElementUserCode,
Invalid
};
@@ -1704,7 +1704,16 @@ struct DiffTransposePass
case kIROp_DifferentialPairGetDifferential:
return transposeGetDifferential(builder, as<IRDifferentialPairGetDifferential>(fwdInst), revValue);
-
+
+ case kIROp_MakeDifferentialPairUserCode:
+ return transposeMakePairUserCode(builder, as<IRMakeDifferentialPairUserCode>(fwdInst), revValue);
+
+ case kIROp_DifferentialPairGetPrimalUserCode:
+ return transposeGetPrimalUserCode(builder, as<IRDifferentialPairGetPrimalUserCode>(fwdInst), revValue);
+
+ case kIROp_DifferentialPairGetDifferentialUserCode:
+ return transposeGetDifferentialUserCode(builder, as<IRDifferentialPairGetDifferentialUserCode>(fwdInst), revValue);
+
case kIROp_MakeVector:
return transposeMakeVector(builder, fwdInst, revValue);
case kIROp_MakeVectorFromScalar:
@@ -1878,6 +1887,47 @@ struct DiffTransposePass
fwdGetDiff)));
}
+ TranspositionResult transposeMakePairUserCode(IRBuilder* builder, IRMakeDifferentialPairUserCode* fwdMakePair, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakePair->getPrimalValue(),
+ builder->emitDifferentialPairGetPrimalUserCode(revValue),
+ fwdMakePair));
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakePair->getDifferentialValue(),
+ builder->emitDifferentialPairGetDifferentialUserCode(
+ fwdMakePair->getDifferentialValue()->getFullType(), revValue),
+ fwdMakePair));
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeGetDifferentialUserCode(IRBuilder*, IRDifferentialPairGetDifferentialUserCode* fwdGetDiff, IRInst* revValue)
+ {
+ // (A = x.p) -> (dX = DiffPairUserCode(dA, 0))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::DifferentialPairGetElementUserCode,
+ fwdGetDiff->getBase(),
+ revValue,
+ fwdGetDiff)));
+ }
+
+ TranspositionResult transposeGetPrimalUserCode(IRBuilder*, IRDifferentialPairGetPrimalUserCode* fwdGetPrimal, IRInst* revValue)
+ {
+ // (A = x.p) -> (dX = DiffPairUserCode(0, dA))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::DifferentialPairGetElementUserCode,
+ fwdGetPrimal->getBase(),
+ revValue,
+ fwdGetPrimal)));
+ }
+
TranspositionResult transposeMakeVectorFromScalar(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
{
auto vectorType = as<IRVectorType>(revValue->getDataType());
@@ -2497,6 +2547,40 @@ struct DiffTransposePass
return materializeSimpleGradients(builder, aggPrimalType, simpleGradients);
}
+ RevGradient materializeDifferentialPairUserCodeGetElementGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
+ {
+ List<RevGradient> simpleGradients;
+
+ for (auto gradient : gradients)
+ {
+ // Peek at the fwd-mode get element inst to see what type we need to materialize.
+ if (auto fwdGetDiff = as<IRDifferentialPairGetDifferentialUserCode>(gradient.fwdGradInst))
+ {
+ auto baseType = as<IRDifferentialPairUserCodeType>(diffTypeContext.getDifferentialForType(
+ builder,
+ fwdGetDiff->getBase()->getDataType()));
+ simpleGradients.add(
+ RevGradient(
+ gradient.targetInst,
+ builder->emitMakeDifferentialPairUserCode(baseType, emitDZeroOfDiffInstType(builder, baseType->getValueType()), gradient.revGradInst),
+ gradient.fwdGradInst));
+ }
+ else if (auto fwdGetPrimal = as<IRDifferentialPairGetPrimalUserCode>(gradient.fwdGradInst))
+ {
+ auto baseType = as<IRDifferentialPairUserCodeType>(diffTypeContext.getDifferentialForType(
+ builder,
+ fwdGetPrimal->getBase()->getDataType()));
+ simpleGradients.add(
+ RevGradient(
+ gradient.targetInst,
+ builder->emitMakeDifferentialPairUserCode(baseType, gradient.revGradInst, emitDZeroOfDiffInstType(builder, fwdGetPrimal->getFullType())),
+ gradient.fwdGradInst));
+ }
+ }
+
+ return materializeSimpleGradients(builder, aggPrimalType, simpleGradients);
+ }
+
RevGradient materializeGradientSet(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
{
switch (gradients[0].flavor)
@@ -2513,6 +2597,9 @@ struct DiffTransposePass
case RevGradient::Flavor::GetElement:
return materializeGetElementGradients(builder, aggPrimalType, gradients);
+ case RevGradient::Flavor::DifferentialPairGetElementUserCode:
+ return materializeDifferentialPairUserCodeGetElementGradients(builder, aggPrimalType, gradients);
+
default:
SLANG_ASSERT_FAILURE("Unhandled gradient flavor for materialization");
}
@@ -2773,6 +2860,16 @@ struct DiffTransposePass
auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType());
return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero);
}
+ else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType))
+ {
+ auto primalZero = emitDZeroOfDiffInstType(builder, diffPairUserType->getValueType());
+ auto diffZero = primalZero;
+ auto diffType = primalZero->getFullType();
+ auto diffWitness = diffTypeContext.getDiffTypeWitnessFromPairType(builder, diffPairUserType);
+ auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
+ return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primalZero, diffZero);
+ }
+
auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType);
// Should exist.
@@ -2810,6 +2907,23 @@ struct DiffTransposePass
SLANG_UNIMPLEMENTED_X("dadd of dynamic array.");
}
}
+ else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType))
+ {
+ auto diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, diffPairUserType);
+ auto diffWitness = diffTypeContext.getDiffTypeWitnessFromPairType(builder, diffPairUserType);
+
+ auto primal1 = builder->emitDifferentialPairGetPrimalUserCode(op1);
+ auto primal2 = builder->emitDifferentialPairGetPrimalUserCode(op2);
+ auto primal = emitDAddOfDiffInstType(builder, diffPairUserType->getValueType(), primal1, primal2);
+
+ auto diff1 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op1);
+ auto diff2 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op2);
+ auto diff = emitDAddOfDiffInstType(builder, diffType, diff1, diff2);
+
+ auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
+ return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff);
+ }
+
auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType);
// Should exist.
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 65e880868..edea3847d 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -44,6 +44,18 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
return nullptr;
}
+static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
+}
+
+static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
+}
+
bool isNoDiffType(IRType* paramType)
{
while (auto ptrType = as<IRPtrTypeBase>(paramType))
@@ -266,25 +278,13 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I
return pairStructType;
}
-IRInst* DifferentialPairTypeBuilder::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
-{
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
-}
-
-IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
-{
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
-}
-
IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
IRBuilder* builder, IRType* originalPairType)
{
IRInst* result = nullptr;
if (pairTypeCache.TryGetValue(originalPairType, result))
return result;
- auto pairType = as<IRDifferentialPairType>(originalPairType);
+ auto pairType = as<IRDifferentialPairTypeBase>(originalPairType);
if (!pairType)
{
result = originalPairType;
@@ -297,7 +297,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
return result;
}
- auto diffType = getDiffTypeFromPairType(builder, pairType);
+ auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType);
if (!diffType)
return result;
result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
@@ -406,18 +406,28 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b
}
IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairType(
- IRBuilder* builder, IRDifferentialPairType* diffPairType)
+ IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType)
{
auto witness = diffPairType->getWitness();
SLANG_RELEASE_ASSERT(witness);
return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey);
}
+IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ return _getDiffTypeFromPairType(sharedContext, builder, type);
+}
+
+IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ return _getDiffTypeWitnessFromPairType(sharedContext, builder, type);
+}
+
void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
for (auto globalInst : sharedContext->moduleInst->getChildren())
{
- if (auto pairType = as<IRDifferentialPairType>(globalInst))
+ if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst))
{
differentiableWitnessDictionary.AddIfNotExists(pairType->getValueType(), pairType->getWitness());
}
@@ -505,6 +515,7 @@ void stripTempDecorations(IRInst* inst)
case kIROp_AutoDiffOriginalValueDecoration:
case kIROp_BackwardDerivativePrimalReturnDecoration:
case kIROp_PrimalValueStructKeyDecoration:
+ case kIROp_PrimalElementTypeDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -578,6 +589,7 @@ bool canTypeBeStored(IRInst* type)
case kIROp_TupleType:
case kIROp_ArrayType:
case kIROp_DifferentialPairType:
+ case kIROp_DifferentialPairUserCodeType:
case kIROp_InterfaceType:
case kIROp_AnyValueType:
case kIROp_ClassType:
@@ -832,6 +844,13 @@ struct AutoDiffPass : public InstPassBase
if (!changed)
break;
+
+ // We have done transcribing the functions, now it is time to demote all DifferentialPair types
+ // and their operations down to DifferentialPairUserCodeType and *UserCode operations so they
+ // can be treated just like normal types with no special semantics in future processing, and won't
+ // be confused with the semantics of a DifferentialPair type during future autodiff code gen.
+ rewriteDifferentialPairToUserCode(module);
+
hasChanges |= changed;
}
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index f757375d8..e7a841323 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -159,7 +159,11 @@ struct DifferentiableTypeConformanceContext
IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);
- IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairType* diffPairType);
+ IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType);
+
+ IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+
+ IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
// Lookup and return the 'Differential' type declared in the concrete type
// in order to conform to the IDifferentiable interface.
@@ -180,6 +184,13 @@ struct DifferentiableTypeConformanceContext
diffElementType,
as<IRArrayType>(origType)->getElementCount());
}
+ case kIROp_DifferentialPairUserCodeType:
+ {
+ auto diffPairType = as<IRDifferentialPairTypeBase>(origType);
+ auto diffType = getDiffTypeFromPairType(builder, diffPairType);
+ auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType);
+ return builder->getDifferentialPairUserCodeType((IRType*)diffType, diffWitness);
+ }
default:
return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
}
@@ -194,6 +205,8 @@ struct DifferentiableTypeConformanceContext
case kIROp_FloatType:
case kIROp_HalfType:
case kIROp_DoubleType:
+ case kIROp_DifferentialPairType:
+ case kIROp_DifferentialPairUserCodeType:
return true;
case kIROp_VectorType:
case kIROp_ArrayType:
@@ -244,10 +257,6 @@ struct DifferentialPairTypeBuilder
IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType);
- IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type);
-
- IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type);
-
IRInst* lowerDiffPairType(IRBuilder* builder, IRType* originalPairType);
struct PairStructKey
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 14f6394e2..6f97ce076 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -73,16 +73,13 @@ public:
bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
- if (level == DifferentiableLevel::Forward)
+ switch (func->getOp())
{
- switch (func->getOp())
- {
- case kIROp_ForwardDifferentiate:
- case kIROp_BackwardDifferentiate:
- return true;
- default:
- break;
- }
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ return isDifferentiableFunc(func->getOperand(0), level);
+ default:
+ break;
}
func = getResolvedInstForDecorations(func);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 7411d031c..28c682c91 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -62,6 +62,9 @@ INST(Nop, nop, 0, 0)
INST(OptionalType, Optional, 1, HOISTABLE)
INST(DifferentialPairType, DiffPair, 1, HOISTABLE)
+ INST(DifferentialPairUserCodeType, DiffPairUserCode, 1, HOISTABLE)
+ INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPairUserCodeType)
+
INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE)
/* BindExistentialsTypeBase */
@@ -278,8 +281,16 @@ INST(undefined, undefined, 0, 0)
INST(DefaultConstruct, defaultConstruct, 0, 0)
INST(MakeDifferentialPair, MakeDiffPair, 2, 0)
+INST(MakeDifferentialPairUserCode, MakeDiffPairUserCode, 2, 0)
+INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPairUserCode)
+
INST(DifferentialPairGetDifferential, GetDifferential, 1, 0)
+INST(DifferentialPairGetDifferentialUserCode, GetDifferentialUserCode, 1, 0)
+INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPairGetDifferentialUserCode)
+
INST(DifferentialPairGetPrimal, GetPrimal, 1, 0)
+INST(DifferentialPairGetPrimalUserCode, GetPrimalUserCode, 1, 0)
+INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPairGetPrimalUserCode)
INST(Specialize, specialize, 2, HOISTABLE)
INST(LookupWitness, lookupWitness, 2, HOISTABLE)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index ae31219bd..f3181ecf7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2259,24 +2259,49 @@ struct IRGetTupleElement : IRInst
// An Instruction that creates a differential pair value from a
// primal and differential.
-struct IRMakeDifferentialPair : IRInst
+
+struct IRMakeDifferentialPairBase : IRInst
{
- IR_LEAF_ISA(MakeDifferentialPair)
+ IR_PARENT_ISA(MakeDifferentialPairBase)
IRInst* getPrimalValue() { return getOperand(0); }
IRInst* getDifferentialValue() { return getOperand(1); }
};
+struct IRMakeDifferentialPair : IRMakeDifferentialPairBase
+{
+ IR_LEAF_ISA(MakeDifferentialPair)
+};
+struct IRMakeDifferentialPairUserCode : IRMakeDifferentialPairBase
+{
+ IR_LEAF_ISA(MakeDifferentialPairUserCode)
+};
-struct IRDifferentialPairGetDifferential : IRInst
+struct IRDifferentialPairGetDifferentialBase : IRInst
{
- IR_LEAF_ISA(DifferentialPairGetDifferential)
+ IR_PARENT_ISA(DifferentialPairGetDifferentialBase)
IRInst* getBase() { return getOperand(0); }
};
+struct IRDifferentialPairGetDifferential : IRDifferentialPairGetDifferentialBase
+{
+ IR_LEAF_ISA(DifferentialPairGetDifferential)
+};
+struct IRDifferentialPairGetDifferentialUserCode : IRDifferentialPairGetDifferentialBase
+{
+ IR_LEAF_ISA(DifferentialPairGetDifferentialUserCode)
+};
-struct IRDifferentialPairGetPrimal : IRInst
+struct IRDifferentialPairGetPrimalBase : IRInst
{
- IR_LEAF_ISA(DifferentialPairGetPrimal)
+ IR_PARENT_ISA(DifferentialPairGetPrimalBase)
IRInst* getBase() { return getOperand(0); }
};
+struct IRDifferentialPairGetPrimal : IRDifferentialPairGetPrimalBase
+{
+ IR_LEAF_ISA(DifferentialPairGetPrimal)
+};
+struct IRDifferentialPairGetPrimalUserCode : IRDifferentialPairGetPrimalBase
+{
+ IR_LEAF_ISA(DifferentialPairGetPrimalUserCode)
+};
struct IRDetachDerivative : IRInst
{
@@ -2717,6 +2742,10 @@ public:
IRType* valueType,
IRInst* witnessTable);
+ IRDifferentialPairUserCodeType* getDifferentialPairUserCodeType(
+ IRType* valueType,
+ IRInst* witnessTable);
+
IRBackwardDiffIntermediateContextType* getBackwardDiffIntermediateContextType(IRInst* func);
IRFuncType* getFuncType(
@@ -2832,6 +2861,7 @@ public:
IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn);
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
+ IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential);
IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target);
@@ -2966,6 +2996,8 @@ public:
IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue);
IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair);
IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair);
+ IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair);
+ IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair);
IRInst* emitMakeVector(
IRType* type,
UInt argCount,
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 13920b011..254734965 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -174,6 +174,7 @@ bool isValueType(IRInst* dataType)
case kIROp_ResultType:
case kIROp_OptionalType:
case kIROp_DifferentialPairType:
+ case kIROp_DifferentialPairUserCodeType:
case kIROp_DynamicType:
case kIROp_AnyValueType:
case kIROp_ArrayType:
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 2819a6d83..08c066f5d 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2459,6 +2459,26 @@ namespace Slang
if (found)
{
memoryArena.rewindToCursor(cursor);
+
+ // If the found inst is defined in the same parent as current insert location but
+ // is located after the insert location, we need to move it to the insert location.
+ auto foundInst = *found;
+ if (foundInst->getParent() && foundInst->getParent() == getInsertLoc().getParent() &&
+ getInsertLoc().getMode() == IRInsertLoc::Mode::Before)
+ {
+ auto insertLoc = getInsertLoc().getInst();
+ bool isAfter = false;
+ for (auto cur = insertLoc->next; cur; cur = cur->next)
+ {
+ if (cur == foundInst)
+ {
+ isAfter = true;
+ break;
+ }
+ }
+ if (isAfter)
+ foundInst->insertBefore(insertLoc);
+ }
return *found;
}
}
@@ -2779,6 +2799,17 @@ namespace Slang
operands);
}
+ IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType(
+ IRType* valueType,
+ IRInst* witnessTable)
+ {
+ IRInst* operands[] = { valueType, witnessTable };
+ return (IRDifferentialPairUserCodeType*)getType(
+ kIROp_DifferentialPairUserCodeType,
+ sizeof(operands) / sizeof(operands[0]),
+ operands);
+ }
+
IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType(
IRInst* func)
{
@@ -3162,6 +3193,18 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential)
+ {
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type));
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type)->getValueType() != nullptr);
+
+ IRInst* args[] = { primal, differential };
+ auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>(
+ this, kIROp_MakeDifferentialPairUserCode, type, 2, args);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitSpecializeInst(
IRType* type,
IRInst* genericVal,
@@ -3751,6 +3794,25 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair)
+ {
+ SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
+ return emitIntrinsicInst(
+ diffType,
+ kIROp_DifferentialPairGetDifferentialUserCode,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimalUserCode(IRInst* diffPair)
+ {
+ auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
+ return emitIntrinsicInst(
+ valueType,
+ kIROp_DifferentialPairGetPrimalUserCode,
+ 1,
+ &diffPair);
+ }
IRInst* IRBuilder::emitMakeMatrix(
IRType* type,
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index e22ea8a36..14a216fd2 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1448,14 +1448,24 @@ SIMPLE_IR_TYPE(TypeKind, Kind);
//
SIMPLE_IR_TYPE(GenericKind, Kind)
-struct IRDifferentialPairType : IRType
+struct IRDifferentialPairTypeBase : IRType
{
IRType* getValueType() { return (IRType*)getOperand(0); }
IRInst* getWitness() { return (IRInst*)getOperand(1); }
+ IR_PARENT_ISA(DifferentialPairTypeBase)
+};
+
+struct IRDifferentialPairType : IRDifferentialPairTypeBase
+{
IR_LEAF_ISA(DifferentialPairType)
};
+struct IRDifferentialPairUserCodeType : IRDifferentialPairTypeBase
+{
+ IR_LEAF_ISA(DifferentialPairUserCodeType)
+};
+
struct IRBackwardDiffIntermediateContextType : IRType
{
IRInst* getFunc() { return getOperand(0); }
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d8912cbd4..5e6213205 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7285,10 +7285,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount());
}
- void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember)
+ void lowerDerivativeMemberModifier(IRInst* inst, Decl* memberDecl, DerivativeMemberAttribute* derivativeMember)
{
- ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl);
- auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val;
+ IRInst* key = nullptr;
+ if (derivativeMember->memberDeclRef->declRef.getDecl() == memberDecl)
+ {
+ key = inst;
+ }
+ else
+ {
+ ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl);
+ key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val;
+ }
SLANG_RELEASE_ASSERT(as<IRStructKey>(key));
auto builder = getBuilder();
builder->addDecoration(inst, kIROp_DerivativeMemberDecoration, key);
@@ -7358,7 +7366,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
if (auto derivativeMemberModifier = fieldDecl->findModifier<DerivativeMemberAttribute>())
{
- lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier);
+ lowerDerivativeMemberModifier(irFieldKey, fieldDecl, derivativeMemberModifier);
}
// We allow a field to be marked as a target intrinsic,
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 6076a41ca..470f5f983 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -1232,6 +1232,25 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
}
break;
}
+
+ // Hard code implementation of T.Differential.Differential == T.Differential rule.
+ if (auto builtinReq = substDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>())
+ {
+ if (builtinReq->kind == BuiltinRequirementKind::DifferentialType)
+ {
+ // Is the concrete type a Differential associated type?
+ if (auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub))
+ {
+ if (auto innerBuiltinReq = innerDeclRefType->declRef.decl->findModifier<BuiltinRequirementModifier>())
+ {
+ if (innerBuiltinReq->kind == BuiltinRequirementKind::DifferentialType)
+ {
+ return innerDeclRefType;
+ }
+ }
+ }
+ }
+ }
}
}
}