diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-10-03 16:02:16 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-03 16:02:16 -0400 |
| commit | 9f246a43667b4893040669873400e2e3813328ff (patch) | |
| tree | f1fafe8c266b1db6f5f2cb76ab4fb7332cc2be54 /source | |
| parent | aa64c853142076b17bd020f1386ea5fc6fcd5e3e (diff) | |
Support custom derivatives of member functions of differentiable types (#5124)
* Initial work to support custom derivatives for member methods of differentiable types
* Support custom derivatives of member functions of differentiable types
- Also adds support for declaring custom derivatives via extensions.
* Fix
* move defs
* Update slang-check-decl.cpp
* Create diff-member-func-custom-derivative.slang.expected.txt
* Update slang-check-decl.cpp
* Fix for static custom derivatives
* Fix diagnostics for [PreferRecompute]
* Add backward custom derivative tests
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 269 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 5 |
3 files changed, 238 insertions, 40 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4d742d9f4..8c3429c9a 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7733,13 +7733,13 @@ namespace Slang // Two such constraints are equivalent if their `sub` // and `sup` types are pairwise equivalent. // - auto leftSub = leftConstraint->sub; - auto rightSub = getSub(m_astBuilder, rightConstraint); + auto leftSub = leftConstraint->sub.type; + auto rightSub = substInnerRightToLeft.substitute(m_astBuilder, rightConstraint.getDecl()->sub.type); if(!leftSub->equals(rightSub)) return false; - auto leftSup = leftConstraint->sup; - auto rightSup = getSup(m_astBuilder, rightConstraint); + auto leftSup = leftConstraint->sup.type; + auto rightSup = substInnerRightToLeft.substitute(m_astBuilder, rightConstraint.getDecl()->sup.type); if(!leftSup->equals(rightSup)) return false; } @@ -10339,10 +10339,67 @@ namespace Slang return result; } + bool areTypesCompatibile(SemanticsVisitor* visitor, Type* fst, Type* snd) + { + if (fst->equals(snd)) + return true; + + if (auto declRefType = as<DeclRefType>(fst)) + { + auto decl = declRefType->getDeclRef().getDecl(); + if (auto extGenericDecl = visitor->GetOuterGeneric(decl)) + { + SemanticsVisitor::ConstraintSystem constraints; + constraints.loc = decl->loc; + constraints.genericDecl = extGenericDecl; + + if (!visitor->TryUnifyTypes(constraints, SemanticsVisitor::ValUnificationContext(), fst, snd)) + return false; + + ConversionCost baseCost; + if (!visitor->trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>(), baseCost)) + return false; + + // If we reach here, it means we have a valid unification. + return true; + } + } + return false; + } + + Type* getTypeForThisExpr(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl) + { + ThisExpr* expr = visitor->getASTBuilder()->create<ThisExpr>(); + expr->scope = funcDecl->ownedScope; + expr->loc = funcDecl->loc; + + DiagnosticSink dummySink; + auto tempVisitor = SemanticsVisitor(visitor->withSink(&dummySink)); + + auto checkedExpr = tempVisitor.CheckTerm(expr); + + return !(as<ErrorType>(checkedExpr->type.type)) ? (checkedExpr->type.type) : nullptr; + } + + Type* getTypeForThisExpr(SemanticsVisitor* visitor, DeclRef<FunctionDeclBase> funcDeclRef) + { + auto type = getTypeForThisExpr(visitor, funcDeclRef.getDecl()); + if (type) + return substituteType( + SubstitutionSet(funcDeclRef.declRefBase), + visitor->getASTBuilder(), + type); + return nullptr; + } + + struct ArgsWithDirectionInfo { List<Expr*> args; List<ParameterDirection> directions; + + Expr* thisArg; + ParameterDirection thisArgDirection; }; template<typename TDerivativeAttr> @@ -10351,7 +10408,9 @@ namespace Slang Decl* funcDecl, TDerivativeAttr* attr, const List<Expr*>& imaginaryArguments, - const List<ParameterDirection>& expectedParamDirections) + const List<ParameterDirection>& expectedParamDirections, + Expr* expectedThisArg, + ParameterDirection expectedThisArgDirection) { if (isInterfaceRequirement(funcDecl)) { @@ -10402,7 +10461,18 @@ namespace Slang return type->toString(); }; - auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); + List<Expr*> argList = imaginaryArguments; + List<ParameterDirection> paramDirections = expectedParamDirections; + bool expectStaticFunc = false; + + if (expectedThisArg) + { + argList.insert(0, expectedThisArg); + paramDirections.insert(0, expectedThisArgDirection); + expectStaticFunc = true; + } + + auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, argList); auto resolved = subVisitor.ResolveInvoke(invokeExpr); if (auto resolvedInvoke = as<InvokeExpr>(resolved)) @@ -10430,61 +10500,104 @@ namespace Slang visitor->getSink()->diagnose(attr, Diagnostics::cannotUseInterfaceRequirementAsDerivative); return; } - if (funcType->getParamCount() != imaginaryArguments.getCount()) + if (funcType->getParamCount() != argList.getCount()) { goto error; } - for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii) + for (Index ii = 0; ii < argList.getCount(); ++ii) { // Check if the resolved invoke argument type is an error type. // If so, then we have a type mismatch. // if (resolvedInvoke->arguments[ii]->type.type->equals(ctx.getASTBuilder()->getErrorType()) || - funcType->getParamDirection(ii) != expectedParamDirections[ii]) + funcType->getParamDirection(ii) != paramDirections[ii]) { visitor->getSink()->diagnose( attr, Diagnostics::customDerivativeSignatureMismatchAtPosition, ii, - qualTypeToString(imaginaryArguments[ii]->type), + qualTypeToString(argList[ii]->type), funcType->getParamType(ii)->toString()); } } // The `imaginaryArguments` list does not include the `this` parameter. // So we need to check that `this` type matches. bool funcIsStatic = isEffectivelyStatic(funcDecl); + if (funcIsStatic) + expectStaticFunc = true; + bool derivativeFuncIsStatic = isEffectivelyStatic(calleeDeclRef->declRef.getDecl()); - if (funcIsStatic != derivativeFuncIsStatic) + + if (expectStaticFunc && !derivativeFuncIsStatic) { visitor->getSink()->diagnose( attr, - Diagnostics::customDerivativeSignatureThisParamMismatch); + Diagnostics::customDerivativeExpectedStatic); return; } - if (!funcIsStatic) + + if (!derivativeFuncIsStatic) { auto defaultFuncDeclRef = createDefaultSubstitutionsIfNeeded( visitor->getASTBuilder(), visitor, makeDeclRef(funcDecl)); - auto funcThisType = visitor->calcThisType(defaultFuncDeclRef); - auto derivativeFuncThisType = visitor->calcThisType(calleeDeclRef->declRef); - if (!funcThisType->equals(derivativeFuncThisType)) + + DeclRef<FunctionDeclBase> funcDeclRef = defaultFuncDeclRef.as<FunctionDeclBase>(); + auto funcThisType = getTypeForThisExpr(visitor, funcDeclRef); + DeclRef<FunctionDeclBase> calleeFuncDeclRef = calleeDeclRef->declRef.template as<FunctionDeclBase>(); + auto derivativeFuncThisType = getTypeForThisExpr(visitor, calleeFuncDeclRef); + + // If the function is a member function, we need to check that the + // `this` type matches the expected type. This will ensure that after lowering to IR, + // the two functions are compatible. + // + if (!areTypesCompatibile(visitor, funcThisType, derivativeFuncThisType)) { visitor->getSink()->diagnose( attr, Diagnostics::customDerivativeSignatureThisParamMismatch); return; } - if (visitor->isTypeDifferentiable(funcThisType)) + } + + // If the two decls are under different generic contexts, we'll need to check that + // they agree and specialize the attribute's decl-ref accordingly. + // + + auto originalNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(funcDecl)); + auto derivativeNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(calleeDeclRef->declRef.getDecl())); + + if ((!originalNextGeneric) != (!derivativeNextGeneric)) + { + // Diagnostic for when one is generic and the other is not. + visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + return; + } + + if (originalNextGeneric != derivativeNextGeneric) + { + // If the two generic containers are not the same, but are compatible, we can + // unify them. + // + + DeclRef<Decl> specializedDecl; + if (!visitor->doGenericSignaturesMatch(originalNextGeneric, derivativeNextGeneric, &specializedDecl)) { - visitor->getSink()->diagnose( - attr, - Diagnostics::customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType); + visitor->getSink()->diagnose(attr, Diagnostics::customDerivativeSignatureMismatch); return; } - } + calleeDeclRef->declRef = substituteDeclRef( + SubstitutionSet(specializedDecl), + visitor->getASTBuilder(), + calleeDeclRef->declRef); + calleeDeclRef->type = substituteType( + SubstitutionSet(specializedDecl), + visitor->getASTBuilder(), + calleeDeclRef->type); + } + attr->funcExpr = calleeDeclRef; if (attr->args.getCount()) attr->args[0] = attr->funcExpr; @@ -10497,12 +10610,12 @@ namespace Slang // StringBuilder builder; builder << "("; - for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii) + for (Index ii = 0; ii < argList.getCount(); ++ii) { if (ii != 0) builder << ", "; - if (imaginaryArguments[ii]->type) - builder << qualTypeToString(imaginaryArguments[ii]->type); + if (argList[ii]->type) + builder << qualTypeToString(argList[ii]->type); else builder << "<error>"; } @@ -10544,11 +10657,36 @@ namespace Slang imaginaryArguments.add(arg); directions.add(getParameterDirection(param)); } - return { imaginaryArguments, directions }; + return { imaginaryArguments, directions, nullptr, ParameterDirection::kParameterDirection_In }; } ArgsWithDirectionInfo getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { + Expr* thisArgExpr = nullptr; + if (auto thisType = getTypeForThisExpr(visitor, originalFuncDecl)) + { + thisArgExpr = visitor->getASTBuilder()->create<VarExpr>(); + thisArgExpr->type = thisType; + thisArgExpr->loc = loc; + + if (visitor->isTypeDifferentiable(thisType) && + !originalFuncDecl->findModifier<NoDiffThisAttribute>() && + !isEffectivelyStatic(originalFuncDecl)) + { + auto pairType = visitor->getDifferentialPairType(thisType); + thisArgExpr->type.type = pairType; + } + else + { + thisArgExpr = nullptr; + } + } + + ParameterDirection thisTypeDirection = + (thisArgExpr && !thisArgExpr->type.isLeftValue) ? + ParameterDirection::kParameterDirection_In : + ParameterDirection::kParameterDirection_InOut; + List<Expr*> imaginaryArguments; for (auto param : originalFuncDecl->getParameters()) { @@ -10574,11 +10712,40 @@ namespace Slang expectedParamDirections.add(getParameterDirection(param)); } - return { imaginaryArguments, expectedParamDirections }; + return { imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection }; } ArgsWithDirectionInfo getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { + Expr* thisArgExpr = nullptr; + if (auto thisType = getTypeForThisExpr(visitor, originalFuncDecl)) + { + thisArgExpr = visitor->getASTBuilder()->create<VarExpr>(); + thisArgExpr->type = thisType; + thisArgExpr->loc = loc; + + if (visitor->isTypeDifferentiable(thisType) && + !originalFuncDecl->findModifier<NoDiffThisAttribute>() && + !isEffectivelyStatic(originalFuncDecl)) + { + auto pairType = visitor->getDifferentialPairType(thisType); + thisArgExpr->type.type = pairType; + + // TODO: for ptr pair types, no need to set isLeftValue to true. + if (as<DifferentialPairType>(thisArgExpr->type.type)) + thisArgExpr->type.isLeftValue = true; + } + else + { + thisArgExpr = nullptr; + } + } + + ParameterDirection thisTypeDirection = + (thisArgExpr && !thisArgExpr->type.isLeftValue) ? + ParameterDirection::kParameterDirection_In : + ParameterDirection::kParameterDirection_InOut; + List<Expr*> imaginaryArguments; List<ParameterDirection> expectedParamDirections; @@ -10660,7 +10827,7 @@ namespace Slang expectedParamDirections.add(ParameterDirection::kParameterDirection_In); } - return {imaginaryArguments, expectedParamDirections}; + return {imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection}; } // This helper function is needed to workaround a gcc bug. @@ -10685,7 +10852,11 @@ namespace Slang higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; if (derivativeOfAttr->args.getCount() > 0) higherOrderFuncExpr->loc = derivativeOfAttr->args[0]->loc; - Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, visitor->allowStaticReferenceToNonStaticMember()); + + Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr( + higherOrderFuncExpr, + visitor->allowStaticReferenceToNonStaticMember()); + if (!checkedHigherOrderFuncExpr) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); @@ -10701,7 +10872,15 @@ namespace Slang { auto resolvedFuncExpr = as<HigherOrderInvokeExpr>(resolvedInvoke->functionExpr); if (resolvedFuncExpr) + { calleeDeclRefExpr = as<DeclRefExpr>(resolvedFuncExpr->baseFunction); + if (!calleeDeclRef && as<OverloadedExpr>(resolvedFuncExpr->baseFunction)) + { + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::overloadedFuncUsedWithDerivativeOfAttributes); + } + } } if (!calleeDeclRefExpr) @@ -10729,13 +10908,6 @@ namespace Slang // We may relax this restriction in the future by solving the "inverse" generic arguments // from the `calleeDeclRef`, and use them to create a declRef to funcDecl from the original // func. - auto originalNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(calleeFunc)); - auto derivativeNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(funcDecl)); - if (originalNextGeneric != derivativeNextGeneric) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); - return; - } if (isInterfaceRequirement(calleeFunc)) { @@ -10787,7 +10959,14 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl( + visitor, + funcDecl, + attr, + imaginaryArguments.args, + imaginaryArguments.directions, + imaginaryArguments.thisArg, + imaginaryArguments.thisArgDirection); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) @@ -10798,7 +10977,14 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl( + visitor, + funcDecl, + attr, + imaginaryArguments.args, + imaginaryArguments.directions, + imaginaryArguments.thisArg, + imaginaryArguments.thisArgDirection); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) @@ -10809,7 +10995,14 @@ namespace Slang return; ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); + checkDerivativeAttributeImpl( + visitor, + funcDecl, + attr, + imaginaryArguments.args, + imaginaryArguments.directions, + imaginaryArguments.thisArg, + imaginaryArguments.thisArgDirection); } static void checkCudaKernelAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, CudaKernelAttribute*) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index e9c257750..842ffb527 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -4754,7 +4754,9 @@ namespace Slang scope = scope->parent; } - getSink()->diagnose(expr, Diagnostics::thisExpressionOutsideOfTypeDecl); + if (auto sink = getSink()) + sink->diagnose(expr, Diagnostics::thisExpressionOutsideOfTypeDecl); + return CreateErrorExpr(expr); } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 9a14b71e4..23eba5f03 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -463,8 +463,11 @@ DIAGNOSTIC(31151, Error, cannotResolveGenericArgumentForDerivativeFunction, "[BackwardDerivativeOf], and [PrimalSubstituteOf] attributes are not supported when the generic arguments to the derivatives cannot be automatically deduced.") DIAGNOSTIC(31152, Error, cannotAssociateInterfaceRequirementWithDerivative, "cannot associate an interface requirement with a derivative.") DIAGNOSTIC(31153, Error, cannotUseInterfaceRequirementAsDerivative, "cannot use an interface requirement as a derivative.") -DIAGNOSTIC(31154, Error, customDerivativeSignatureThisParamMismatch, "custom derivative does not match expected signature on `this`. Either both the original and the derivative function are static, or they must have the same `this` type.") +DIAGNOSTIC(31154, Error, customDerivativeSignatureThisParamMismatch, "custom derivative does not match expected signature on `this`. Both original and derivative function must have the same `this` type.") DIAGNOSTIC(31155, Error, customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType, "custom derivative is not allowed for non-static member functions of a differentiable type.") +DIAGNOSTIC(31156, Error, customDerivativeExpectedStatic, "expected a static definition for the custom derivative.") +DIAGNOSTIC(31157, Error, overloadedFuncUsedWithDerivativeOfAttributes, "cannot resolve overloaded functions for derivative-of attributes.") + DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.") DIAGNOSTIC(31202, Error, duplicateModifier, "modifier '$0' is redundant or conflicting with existing modifier '$1'") |
