diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-09 18:00:48 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-09 18:00:48 -0700 |
| commit | ddebd60853b3f34bfd8e89de804fd15808abf75d (patch) | |
| tree | d5d686843bc2c67e493693376a0170857998c077 /source/slang/slang-ir-check-differentiability.cpp | |
| parent | 38ed03a7203baacf36fca62539ac74fd45ed42d2 (diff) | |
Various fixes for autodiff and slangpy. (#2876)
* Various fixes for autodiff and slangpy.
* Fix cuda code gen for `select`.
* Fix getBuildTagString().
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d81a33719..6c9280fca 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -195,8 +195,54 @@ public: } } + bool checkType(IRInst* type) + { + type = unwrapAttributedType(type); + if (as<IRTorchTensorType>(type)) + return false; + else if (auto arrayType = as<IRArrayTypeBase>(type)) + return checkType(arrayType->getElementType()); + else if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + { + if (!checkType(field->getFieldType())) + return false; + } + } + return true; + } + void checkForInvalidHostTypeUsage(IRGlobalValueWithCode* funcInst) + { + auto outerFuncInst = maybeFindOuterGeneric(funcInst); + + if (outerFuncInst->findDecoration<IRCudaHostDecoration>()) + return; + if (outerFuncInst->findDecoration<IRTorchEntryPointDecoration>()) + return; + + // This is a kernel function, we don't allow using TorchTensor type here. + for (auto b : funcInst->getBlocks()) + { + for (auto inst : b->getChildren()) + { + if (!checkType(inst->getDataType())) + { + auto loc = inst->sourceLoc; + if (!loc.isValid()) + loc = funcInst->sourceLoc; + sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc); + return; + } + + } + } + } + void processFunc(IRGlobalValueWithCode* funcInst) { + checkForInvalidHostTypeUsage(funcInst); + if (!_isFuncMarkedForAutoDiff(funcInst)) return; if (!funcInst->getFirstBlock()) |
