diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 3 |
2 files changed, 12 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index cf2310fc8..f4a34c5aa 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -319,6 +319,9 @@ IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRI IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) { + if (isNoDiffType((IRType*)originalType)) + return nullptr; + IRInst* witness = differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType); if (witness) @@ -384,6 +387,9 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) { + if (isNoDiffType(origType)) + return nullptr; + // Special-case for differentiable existential types. if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType)) { @@ -406,6 +412,9 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRType* origType) { + if (isNoDiffType(origType)) + return nullptr; + if (auto ptrType = as<IRPtrTypeBase>(origType)) return builder->getPtrType( origType->getOp(), diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index a6437e3e4..e4f3f3f94 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -813,6 +813,9 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType) { + if (isNoDiffType((IRType*)primalType)) + return nullptr; + IRInst* witness = lookUpConformanceForType((IRType*)primalType); if (witness) { |
