From 12e8ce5c548f4658ef2989f368ec9d93e50d9b08 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 27 Sep 2023 22:35:58 -0700 Subject: Fix regression on no_diff type transcription. (#3245) Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-transcriber-base.cpp | 9 +++++++++ source/slang/slang-ir-autodiff.cpp | 3 +++ 2 files changed, 12 insertions(+) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index cf2310fc8..f4a34c5aa 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -319,6 +319,9 @@ IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRI IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) { + if (isNoDiffType((IRType*)originalType)) + return nullptr; + IRInst* witness = differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType); if (witness) @@ -384,6 +387,9 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) { + if (isNoDiffType(origType)) + return nullptr; + // Special-case for differentiable existential types. if (as(origType) || as(origType)) { @@ -406,6 +412,9 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRType* origType) { + if (isNoDiffType(origType)) + return nullptr; + if (auto ptrType = as(origType)) return builder->getPtrType( origType->getOp(), diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index a6437e3e4..e4f3f3f94 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -813,6 +813,9 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType) { + if (isNoDiffType((IRType*)primalType)) + return nullptr; + IRInst* witness = lookUpConformanceForType((IRType*)primalType); if (witness) { -- cgit v1.2.3