diff options
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 72 |
1 files changed, 63 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 8f21e8c62..31a3072c0 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -221,7 +221,19 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI auto primalType = lookupPrimalInst(builder, originalType, nullptr); SLANG_RELEASE_ASSERT(primalType); - IRInst* witness = tryGetDifferentiableWitness(builder, originalType); + IRInst* witness = nullptr; + if (auto lookup = as<IRLookupWitnessMethod>(primalType)) + { + if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey) + { + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + autoDiffSharedContext->differentialAssocTypeWitnessStructKey); + } + } + if (!witness) + witness = tryGetDifferentiableWitness(builder, originalType); SLANG_RELEASE_ASSERT(witness); return builder->getDifferentialPairType( @@ -239,6 +251,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType); return (IRType*)findOrTranscribePrimalInst(builder, diffType); } + else if (origType->getOp() == kIROp_LookupWitness) + { + return (IRType*)findOrTranscribePrimalInst(builder, (IRInst*)primalType); + } return (IRType*)transcribe(builder, origType); } @@ -539,6 +555,39 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui { return InstPair(primal, nullptr); } + if (interfaceType == autoDiffSharedContext->differentiableInterfaceType) + { + if (primalKey == autoDiffSharedContext->differentialAssocTypeStructKey) + { + return InstPair(primal, primal); + } + else if (primalKey == autoDiffSharedContext->differentialAssocTypeWitnessStructKey) + { + return InstPair(primal, primal); + } + else + { + // We can't really differentiate a call to a IDifferentiable method here. + // They need to be specialized first. + return InstPair(primal, nullptr); + } + } + else if (auto returnWitnessType = as<IRWitnessTableTypeBase>(lookupInst->getDataType())) + { + // T.Diff_Is_IDifferential ==> T.Diff_Is_IDifferential.Diff_Is_IDifferential + if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiableInterfaceType) + { + auto primalDiffType = builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + primal, + autoDiffSharedContext->differentialAssocTypeStructKey); + auto diffWitness = builder->emitLookupInterfaceMethodInst( + (IRType*)primalDiffType, + primal, + autoDiffSharedContext->differentialAssocTypeWitnessStructKey); + return InstPair(primal, diffWitness); + } + } auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); @@ -563,6 +612,8 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui // IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) { + primalType = (IRType*)unwrapAttributedType(primalType); + if (auto diffType = differentiateType(builder, primalType)) { switch (diffType->getOp()) @@ -593,17 +644,18 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I // Since primalType has a corresponding differential type, we can lookup the // definition for zero(). - auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - if (!zeroMethod) + IRInst* zeroMethod = nullptr; + if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType)) { // if the differential type itself comes from a witness lookup, we can just lookup the // zero method from the same witness table. - if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType)) - { - auto wt = lookupInterface->getWitnessTable(); - zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); - builder->markInstAsDifferential(zeroMethod); - } + auto wt = lookupInterface->getWitnessTable(); + zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); + builder->markInstAsDifferential(zeroMethod); + } + else + { + zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); } SLANG_RELEASE_ASSERT(zeroMethod); @@ -747,6 +799,8 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_ForwardDerivativeDecoration: case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDerivativeIntermediateTypeDecoration: + case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePropagateDecoration: break; |
