From ddebd60853b3f34bfd8e89de804fd15808abf75d Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 9 May 2023 18:00:48 -0700 Subject: Various fixes for autodiff and slangpy. (#2876) * Various fixes for autodiff and slangpy. * Fix cuda code gen for `select`. * Fix getBuildTagString(). * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-ir-check-differentiability.cpp | 46 +++++++++++++++++++++++ 1 file changed, 46 insertions(+) (limited to 'source/slang/slang-ir-check-differentiability.cpp') diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d81a33719..6c9280fca 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -195,8 +195,54 @@ public: } } + bool checkType(IRInst* type) + { + type = unwrapAttributedType(type); + if (as(type)) + return false; + else if (auto arrayType = as(type)) + return checkType(arrayType->getElementType()); + else if (auto structType = as(type)) + { + for (auto field : structType->getFields()) + { + if (!checkType(field->getFieldType())) + return false; + } + } + return true; + } + void checkForInvalidHostTypeUsage(IRGlobalValueWithCode* funcInst) + { + auto outerFuncInst = maybeFindOuterGeneric(funcInst); + + if (outerFuncInst->findDecoration()) + return; + if (outerFuncInst->findDecoration()) + return; + + // This is a kernel function, we don't allow using TorchTensor type here. + for (auto b : funcInst->getBlocks()) + { + for (auto inst : b->getChildren()) + { + if (!checkType(inst->getDataType())) + { + auto loc = inst->sourceLoc; + if (!loc.isValid()) + loc = funcInst->sourceLoc; + sink->diagnose(loc, Diagnostics::invalidUseOfTorchTensorTypeInDeviceFunc); + return; + } + + } + } + } + void processFunc(IRGlobalValueWithCode* funcInst) { + checkForInvalidHostTypeUsage(funcInst); + if (!_isFuncMarkedForAutoDiff(funcInst)) return; if (!funcInst->getFirstBlock()) -- cgit v1.2.3