summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp12
1 files changed, 12 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index d18c47689..9001295e0 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -249,6 +249,11 @@ public:
if (outerFuncInst->findDecoration<IRTorchEntryPointDecoration>())
return;
+ bool isSynthesizeConstructor = false;
+
+ if (auto constructor = funcInst->findDecoration<IRConstructorDecorartion>())
+ isSynthesizeConstructor = constructor->getSynthesizedStatus();
+
// This is a kernel function, we don't allow using TorchTensor type here.
for (auto b : funcInst->getBlocks())
{
@@ -256,6 +261,13 @@ public:
{
if (!checkType(inst->getDataType()))
{
+ if (isSynthesizeConstructor)
+ {
+ IRBuilder irBuilder(funcInst);
+ irBuilder.addDecoration(funcInst, kIROp_CudaHostDecoration);
+ return;
+ }
+
auto loc = inst->sourceLoc;
if (!loc.isValid())
loc = funcInst->sourceLoc;