diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 44 |
1 files changed, 27 insertions, 17 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e2af70fa9..31e7b70bf 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -143,6 +143,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase, void visitAssocTypeDecl(AssocTypeDecl* decl); + void checkDifferentiableCallableCommon(CallableDecl* decl); + void checkCallableDeclCommon(CallableDecl* decl); void visitFuncDecl(FuncDecl* funcDecl); @@ -9109,24 +9111,8 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl( } } -void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) +void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* decl) { - for (auto paramDecl : decl->getParameters()) - { - ensureDecl(paramDecl, DeclCheckState::ReadyForReference); - } - - auto errorType = decl->errorType; - if (errorType.exp) - { - errorType = CheckProperType(errorType); - } - else - { - errorType = TypeExp(m_astBuilder->getBottomType()); - } - decl->errorType = errorType; - if (auto interfaceDecl = findParentInterfaceDecl(decl)) { bool isDiffFunc = false; @@ -9248,6 +9234,27 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) } } } +} + +void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) +{ + for (auto paramDecl : decl->getParameters()) + { + ensureDecl(paramDecl, DeclCheckState::ReadyForReference); + } + + auto errorType = decl->errorType; + if (errorType.exp) + { + errorType = CheckProperType(errorType); + } + else + { + errorType = TypeExp(m_astBuilder->getBottomType()); + } + decl->errorType = errorType; + + checkDifferentiableCallableCommon(decl); // If this method is intended to be a CUDA kernel, verify that the return type is void. if (decl->findModifier<CudaKernelAttribute>()) @@ -9709,6 +9716,8 @@ void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl) // for `GetterDecl`s. // decl->returnType.type = _getAccessorStorageType(decl); + + checkDifferentiableCallableCommon(decl); } void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) @@ -9799,6 +9808,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) newValueType); } } + checkDifferentiableCallableCommon(decl); } GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl) |
