diff options
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d81a33719..6c9280fca 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -195,8 +195,54 @@ public: } } + bool checkType(IRInst* type) + { + type = unwrapAttributedType(type); + if (as<IRTorchTensorType>(type)) + return false; + else if (auto arrayType = as<IRArrayTypeBase>(type)) + return checkType(arrayType->getElementType()); + else if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + { + if (!checkType(field->getFieldType())) + return false; + } + } + return true; + } + void checkForInvalidHostTypeUsage(IRGlobalValueWithCode* funcInst) + { + auto outerFuncInst = maybeFindOuterGeneric(funcInst); + + if (outerFuncInst->findDecoration<IRCudaHostDecoration>()) + return; + if (outerFuncInst->findDecoration<IRTorchEntryPointDecoration>()) + return; + + // This is a kernel function, we don't allow using TorchTensor type here. + for (auto b : funcInst->getBlocks()) + { + for (auto inst : b->getChildren()) + { + if (!checkType(inst->getDataType())) + { + auto loc = inst->sourceLoc; + if (!loc.isValid()) + loc = funcInst->sourceLoc; + sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc); + return; + } + + } + } + } + void processFunc(IRGlobalValueWithCode* funcInst) { + checkForInvalidHostTypeUsage(funcInst); + if (!_isFuncMarkedForAutoDiff(funcInst)) return; if (!funcInst->getFirstBlock()) |
