summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-transcriber-base.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp72
1 files changed, 63 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 8f21e8c62..31a3072c0 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -221,7 +221,19 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI
auto primalType = lookupPrimalInst(builder, originalType, nullptr);
SLANG_RELEASE_ASSERT(primalType);
- IRInst* witness = tryGetDifferentiableWitness(builder, originalType);
+ IRInst* witness = nullptr;
+ if (auto lookup = as<IRLookupWitnessMethod>(primalType))
+ {
+ if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey)
+ {
+ witness = builder->emitLookupInterfaceMethodInst(
+ lookup->getWitnessTable()->getDataType(),
+ lookup->getWitnessTable(),
+ autoDiffSharedContext->differentialAssocTypeWitnessStructKey);
+ }
+ }
+ if (!witness)
+ witness = tryGetDifferentiableWitness(builder, originalType);
SLANG_RELEASE_ASSERT(witness);
return builder->getDifferentialPairType(
@@ -239,6 +251,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType);
return (IRType*)findOrTranscribePrimalInst(builder, diffType);
}
+ else if (origType->getOp() == kIROp_LookupWitness)
+ {
+ return (IRType*)findOrTranscribePrimalInst(builder, (IRInst*)primalType);
+ }
return (IRType*)transcribe(builder, origType);
}
@@ -539,6 +555,39 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
{
return InstPair(primal, nullptr);
}
+ if (interfaceType == autoDiffSharedContext->differentiableInterfaceType)
+ {
+ if (primalKey == autoDiffSharedContext->differentialAssocTypeStructKey)
+ {
+ return InstPair(primal, primal);
+ }
+ else if (primalKey == autoDiffSharedContext->differentialAssocTypeWitnessStructKey)
+ {
+ return InstPair(primal, primal);
+ }
+ else
+ {
+ // We can't really differentiate a call to a IDifferentiable method here.
+ // They need to be specialized first.
+ return InstPair(primal, nullptr);
+ }
+ }
+ else if (auto returnWitnessType = as<IRWitnessTableTypeBase>(lookupInst->getDataType()))
+ {
+ // T.Diff_Is_IDifferential ==> T.Diff_Is_IDifferential.Diff_Is_IDifferential
+ if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiableInterfaceType)
+ {
+ auto primalDiffType = builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ primal,
+ autoDiffSharedContext->differentialAssocTypeStructKey);
+ auto diffWitness = builder->emitLookupInterfaceMethodInst(
+ (IRType*)primalDiffType,
+ primal,
+ autoDiffSharedContext->differentialAssocTypeWitnessStructKey);
+ return InstPair(primal, diffWitness);
+ }
+ }
auto decor =
lookupInst->getRequirementKey()->findDecorationImpl(
getInterfaceRequirementDerivativeDecorationOp());
@@ -563,6 +612,8 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
//
IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
{
+ primalType = (IRType*)unwrapAttributedType(primalType);
+
if (auto diffType = differentiateType(builder, primalType))
{
switch (diffType->getOp())
@@ -593,17 +644,18 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
// Since primalType has a corresponding differential type, we can lookup the
// definition for zero().
- auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
- if (!zeroMethod)
+ IRInst* zeroMethod = nullptr;
+ if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
{
// if the differential type itself comes from a witness lookup, we can just lookup the
// zero method from the same witness table.
- if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
- {
- auto wt = lookupInterface->getWitnessTable();
- zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
- builder->markInstAsDifferential(zeroMethod);
- }
+ auto wt = lookupInterface->getWitnessTable();
+ zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
+ builder->markInstAsDifferential(zeroMethod);
+ }
+ else
+ {
+ zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
}
SLANG_RELEASE_ASSERT(zeroMethod);
@@ -747,6 +799,8 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS
case kIROp_UserDefinedBackwardDerivativeDecoration:
case kIROp_ForwardDerivativeDecoration:
case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeIntermediateTypeDecoration:
+ case kIROp_BackwardDerivativePrimalContextDecoration:
case kIROp_BackwardDerivativePrimalDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
break;