summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp46
1 files changed, 46 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index d81a33719..6c9280fca 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -195,8 +195,54 @@ public:
}
}
+ bool checkType(IRInst* type)
+ {
+ type = unwrapAttributedType(type);
+ if (as<IRTorchTensorType>(type))
+ return false;
+ else if (auto arrayType = as<IRArrayTypeBase>(type))
+ return checkType(arrayType->getElementType());
+ else if (auto structType = as<IRStructType>(type))
+ {
+ for (auto field : structType->getFields())
+ {
+ if (!checkType(field->getFieldType()))
+ return false;
+ }
+ }
+ return true;
+ }
+ void checkForInvalidHostTypeUsage(IRGlobalValueWithCode* funcInst)
+ {
+ auto outerFuncInst = maybeFindOuterGeneric(funcInst);
+
+ if (outerFuncInst->findDecoration<IRCudaHostDecoration>())
+ return;
+ if (outerFuncInst->findDecoration<IRTorchEntryPointDecoration>())
+ return;
+
+ // This is a kernel function, we don't allow using TorchTensor type here.
+ for (auto b : funcInst->getBlocks())
+ {
+ for (auto inst : b->getChildren())
+ {
+ if (!checkType(inst->getDataType()))
+ {
+ auto loc = inst->sourceLoc;
+ if (!loc.isValid())
+ loc = funcInst->sourceLoc;
+ sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc);
+ return;
+ }
+
+ }
+ }
+ }
+
void processFunc(IRGlobalValueWithCode* funcInst)
{
+ checkForInvalidHostTypeUsage(funcInst);
+
if (!_isFuncMarkedForAutoDiff(funcInst))
return;
if (!funcInst->getFirstBlock())