diff options
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d18c47689..9001295e0 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -249,6 +249,11 @@ public: if (outerFuncInst->findDecoration<IRTorchEntryPointDecoration>()) return; + bool isSynthesizeConstructor = false; + + if (auto constructor = funcInst->findDecoration<IRConstructorDecorartion>()) + isSynthesizeConstructor = constructor->getSynthesizedStatus(); + // This is a kernel function, we don't allow using TorchTensor type here. for (auto b : funcInst->getBlocks()) { @@ -256,6 +261,13 @@ public: { if (!checkType(inst->getDataType())) { + if (isSynthesizeConstructor) + { + IRBuilder irBuilder(funcInst); + irBuilder.addDecoration(funcInst, kIROp_CudaHostDecoration); + return; + } + auto loc = inst->sourceLoc; if (!loc.isValid()) loc = funcInst->sourceLoc; |
