summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-09-27 22:35:58 -0700
committerGitHub <noreply@github.com>2023-09-27 22:35:58 -0700
commit12e8ce5c548f4658ef2989f368ec9d93e50d9b08 (patch)
treeabd8f1a1a897fdba349f30bf3d20046f9cdcb0ef /source/slang
parent9833ff9a3d121b974cdaa21708eedb50e9d560cc (diff)
Fix regression on no_diff type transcription. (#3245)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp9
-rw-r--r--source/slang/slang-ir-autodiff.cpp3
2 files changed, 12 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index cf2310fc8..f4a34c5aa 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -319,6 +319,9 @@ IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRI
IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType)
{
+ if (isNoDiffType((IRType*)originalType))
+ return nullptr;
+
IRInst* witness =
differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType);
if (witness)
@@ -384,6 +387,9 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI
IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType)
{
+ if (isNoDiffType(origType))
+ return nullptr;
+
// Special-case for differentiable existential types.
if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType))
{
@@ -406,6 +412,9 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRType* origType)
{
+ if (isNoDiffType(origType))
+ return nullptr;
+
if (auto ptrType = as<IRPtrTypeBase>(origType))
return builder->getPtrType(
origType->getOp(),
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index a6437e3e4..e4f3f3f94 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -813,6 +813,9 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType)
{
+ if (isNoDiffType((IRType*)primalType))
+ return nullptr;
+
IRInst* witness = lookUpConformanceForType((IRType*)primalType);
if (witness)
{