diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-09 19:19:17 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-09 19:19:17 -0800 |
| commit | 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch) | |
| tree | cbc942746bab043da0eb5298993d95f9665dfddf /source/slang/slang-check-decl.cpp | |
| parent | cedd93690c63188cf98e452c9d104cf51aad6c4e (diff) | |
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute.
* Fix handling around phi nodes.
* Fixes.
* Remove IR opcode for ForwardDerivativeOfDecoration.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 333e9d973..b33c33e7a 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -256,6 +256,11 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context); + + void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl); + + void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr); + }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -4582,11 +4587,101 @@ namespace Slang } } + void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl) + { + auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>(); + if (!attr) + return; + + List<Expr*> imaginaryArguments; + for (auto param : funcDecl->getParameters()) + { + auto arg = m_astBuilder->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = attr->loc; + if (auto pairType = as<DifferentialPairType>(param->getType())) + { + arg->type.type = pairType->getPrimalType(); + } + imaginaryArguments.add(arg); + } + auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto resolved = ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + { + if (auto existingModifier = calleeDeclRef->declRef.getDecl()->findModifier<ForwardDerivativeAttribute>()) + { + // The primal function already has a `[ForwardDerivative]` attribute, this is invalid. + getSink()->diagnose(attr, Diagnostics::declAlreadyHasAttribute, calleeDeclRef->declRef, "[ForwardDerivative]"); + getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl()); + } + attr->funcExpr = calleeDeclRef; + auto fwdDerivativeAttr = m_astBuilder->create<ForwardDerivativeAttribute>(); + fwdDerivativeAttr->loc = attr->loc; + auto outterGeneric = GetOuterGeneric(funcDecl); + auto declRef = + DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = ConstructDeclRefExpr(declRef, nullptr, attr->loc, nullptr); + declRefExpr->type.type = nullptr; + fwdDerivativeAttr->args.add(declRefExpr); + fwdDerivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr); + attr->backDeclRef = fwdDerivativeAttr->funcExpr; + fwdDerivativeAttr->funcExpr = nullptr; + return; + } + } + getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + + void SemanticsDeclBodyVisitor::checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments; + for (auto param : funcDecl->getParameters()) + { + auto arg = m_astBuilder->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = attr->loc; + if (auto pairType = getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } + imaginaryArguments.add(arg); + } + auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto resolved = ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); _maybeRegisterDifferentialBottomTypeConformance(newContext); + // Run checking on attributes that can't be fully checked in header checking stage. + checkDerivativeOfAttribute(decl); + if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>()) + checkDerivativeAttribute(decl, derivativeAttr); + if (auto body = decl->body) { checkStmt(decl->body, newContext); |
