From 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 9 Nov 2022 19:19:17 -0800 Subject: Add `[ForwardDerivativeOf]` attribute. (#2501) * Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 95 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 333e9d973..b33c33e7a 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -256,6 +256,11 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context); + + void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl); + + void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr); + }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -4582,11 +4587,101 @@ namespace Slang } } + void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl) + { + auto attr = funcDecl->findModifier(); + if (!attr) + return; + + List imaginaryArguments; + for (auto param : funcDecl->getParameters()) + { + auto arg = m_astBuilder->create(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier() ? true : false; + arg->type.type = param->getType(); + arg->loc = attr->loc; + if (auto pairType = as(param->getType())) + { + arg->type.type = pairType->getPrimalType(); + } + imaginaryArguments.add(arg); + } + auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto resolved = ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) + { + if (auto existingModifier = calleeDeclRef->declRef.getDecl()->findModifier()) + { + // The primal function already has a `[ForwardDerivative]` attribute, this is invalid. + getSink()->diagnose(attr, Diagnostics::declAlreadyHasAttribute, calleeDeclRef->declRef, "[ForwardDerivative]"); + getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl()); + } + attr->funcExpr = calleeDeclRef; + auto fwdDerivativeAttr = m_astBuilder->create(); + fwdDerivativeAttr->loc = attr->loc; + auto outterGeneric = GetOuterGeneric(funcDecl); + auto declRef = + DeclRef((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = ConstructDeclRefExpr(declRef, nullptr, attr->loc, nullptr); + declRefExpr->type.type = nullptr; + fwdDerivativeAttr->args.add(declRefExpr); + fwdDerivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(as(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr); + attr->backDeclRef = fwdDerivativeAttr->funcExpr; + fwdDerivativeAttr->funcExpr = nullptr; + return; + } + } + getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + + void SemanticsDeclBodyVisitor::checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List imaginaryArguments; + for (auto param : funcDecl->getParameters()) + { + auto arg = m_astBuilder->create(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier() ? true : false; + arg->type.type = param->getType(); + arg->loc = attr->loc; + if (auto pairType = getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } + imaginaryArguments.add(arg); + } + auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto resolved = ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); _maybeRegisterDifferentialBottomTypeConformance(newContext); + // Run checking on attributes that can't be fully checked in header checking stage. + checkDerivativeOfAttribute(decl); + if (auto derivativeAttr = decl->findModifier()) + checkDerivativeAttribute(decl, derivativeAttr); + if (auto body = decl->body) { checkStmt(decl->body, newContext); -- cgit v1.2.3