summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-05-09 18:00:48 -0700
committerGitHub <noreply@github.com>2023-05-09 18:00:48 -0700
commitddebd60853b3f34bfd8e89de804fd15808abf75d (patch)
treed5d686843bc2c67e493693376a0170857998c077 /source/slang/slang-ir-check-differentiability.cpp
parent38ed03a7203baacf36fca62539ac74fd45ed42d2 (diff)
Various fixes for autodiff and slangpy. (#2876)
* Various fixes for autodiff and slangpy. * Fix cuda code gen for `select`. * Fix getBuildTagString(). * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp46
1 files changed, 46 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index d81a33719..6c9280fca 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -195,8 +195,54 @@ public:
}
}
+ bool checkType(IRInst* type)
+ {
+ type = unwrapAttributedType(type);
+ if (as<IRTorchTensorType>(type))
+ return false;
+ else if (auto arrayType = as<IRArrayTypeBase>(type))
+ return checkType(arrayType->getElementType());
+ else if (auto structType = as<IRStructType>(type))
+ {
+ for (auto field : structType->getFields())
+ {
+ if (!checkType(field->getFieldType()))
+ return false;
+ }
+ }
+ return true;
+ }
+ void checkForInvalidHostTypeUsage(IRGlobalValueWithCode* funcInst)
+ {
+ auto outerFuncInst = maybeFindOuterGeneric(funcInst);
+
+ if (outerFuncInst->findDecoration<IRCudaHostDecoration>())
+ return;
+ if (outerFuncInst->findDecoration<IRTorchEntryPointDecoration>())
+ return;
+
+ // This is a kernel function, we don't allow using TorchTensor type here.
+ for (auto b : funcInst->getBlocks())
+ {
+ for (auto inst : b->getChildren())
+ {
+ if (!checkType(inst->getDataType()))
+ {
+ auto loc = inst->sourceLoc;
+ if (!loc.isValid())
+ loc = funcInst->sourceLoc;
+ sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc);
+ return;
+ }
+
+ }
+ }
+ }
+
void processFunc(IRGlobalValueWithCode* funcInst)
{
+ checkForInvalidHostTypeUsage(funcInst);
+
if (!_isFuncMarkedForAutoDiff(funcInst))
return;
if (!funcInst->getFirstBlock())