diff options
| author | Yong He <yonghe@outlook.com> | 2023-09-27 22:35:58 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-27 22:35:58 -0700 |
| commit | 12e8ce5c548f4658ef2989f368ec9d93e50d9b08 (patch) | |
| tree | abd8f1a1a897fdba349f30bf3d20046f9cdcb0ef /source/slang | |
| parent | 9833ff9a3d121b974cdaa21708eedb50e9d560cc (diff) | |
Fix regression on no_diff type transcription. (#3245)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 3 |
2 files changed, 12 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index cf2310fc8..f4a34c5aa 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -319,6 +319,9 @@ IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRI IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) { + if (isNoDiffType((IRType*)originalType)) + return nullptr; + IRInst* witness = differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType); if (witness) @@ -384,6 +387,9 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) { + if (isNoDiffType(origType)) + return nullptr; + // Special-case for differentiable existential types. if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType)) { @@ -406,6 +412,9 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRType* origType) { + if (isNoDiffType(origType)) + return nullptr; + if (auto ptrType = as<IRPtrTypeBase>(origType)) return builder->getPtrType( origType->getOp(), diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index a6437e3e4..e4f3f3f94 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -813,6 +813,9 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType) { + if (isNoDiffType((IRType*)primalType)) + return nullptr; + IRInst* witness = lookUpConformanceForType((IRType*)primalType); if (witness) { |
