summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp128
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-syntax.h2
3 files changed, 115 insertions, 17 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4e2f146a1..eeb54fc4e 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6918,12 +6918,18 @@ namespace Slang
return val;
}
+ struct ArgsWithDirectionInfo
+ {
+ List<Expr*> args;
+ List<ParameterDirection> directions;
+ };
template<typename TDerivativeAttr>
void checkDerivativeAttributeImpl(
SemanticsVisitor* visitor,
TDerivativeAttr* attr,
- const List<Expr*>& imaginaryArguments)
+ const List<Expr*>& imaginaryArguments,
+ const List<ParameterDirection>& expectedParamDirections)
{
SemanticsContext::ExprLocalScope scope;
auto ctx = visitor->withExprLocalScope(&scope);
@@ -6949,13 +6955,51 @@ namespace Slang
visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction);
return;
}
-
+
+ // If left value is true, then convert the
+ // inner type to an InOutType.
+ //
+ auto qualTypeToString = [&](QualType qualType) -> String
+ {
+ Type* type = qualType.type;
+ if (qualType.isLeftValue)
+ {
+ type = ctx.getASTBuilder()->getInOutType(type);
+ }
+ return type->toString();
+ };
+
auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
auto resolved = subVisitor.ResolveInvoke(invokeExpr);
+
if (auto resolvedInvoke = as<InvokeExpr>(resolved))
{
if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
{
+ // There are two ways to make it to this point.. a proper resolution, and a
+ // resolution that has failed due to type mismatch.
+ // Further, a proper resolution can still be invalid due to incorrect parameter
+ // directionality.
+ // We'll detect both these incorrect cases here and issue an appropriate diagnostic.
+ //
+ auto funcType = as<FuncType>(calleeDeclRef->type);
+ for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii)
+ {
+ // Check if the resolved invoke argument type is an error type.
+ // If so, then we have a type mismatch.
+ //
+ if (resolvedInvoke->arguments[ii]->type.type->equals(ctx.getASTBuilder()->getErrorType()) ||
+ funcType->getParamDirection(ii) != expectedParamDirections[ii])
+ {
+ visitor->getSink()->diagnose(
+ attr,
+ Diagnostics::customDerivativeSignatureMismatchAtPosition,
+ ii,
+ qualTypeToString(imaginaryArguments[ii]->type),
+ funcType->getParamType(ii)->toString());
+ }
+ }
+
attr->funcExpr = calleeDeclRef;
if (attr->args.getCount())
attr->args[0] = attr->funcExpr;
@@ -6963,7 +7007,23 @@ namespace Slang
}
}
- visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ // Build the expected signature from imaginary args to diagnose
+ // when no matching function is found (this excludes the case handled above)
+ //
+ StringBuilder builder;
+ builder << "(";
+ for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii)
+ {
+ if (ii != 0)
+ builder << ", ";
+ if (imaginaryArguments[ii]->type)
+ builder << qualTypeToString(imaginaryArguments[ii]->type);
+ else
+ builder << "<error>";
+ }
+ builder << ")";
+
+ visitor->getSink()->diagnose(attr, Diagnostics::customDerivativeSignatureMismatch, builder.produceString());
}
template<typename TDerivativeAttr>
@@ -6985,9 +7045,10 @@ namespace Slang
return "PrimalSubstitute";
}
- List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc)
+ ArgsWithDirectionInfo getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
+ List<ParameterDirection> directions;
for (auto param : func->getParameters())
{
auto arg = astBuilder->create<VarExpr>();
@@ -6996,11 +7057,12 @@ namespace Slang
arg->type.type = param->getType();
arg->loc = loc;
imaginaryArguments.add(arg);
+ directions.add(getParameterDirection(param));
}
- return imaginaryArguments;
+ return { imaginaryArguments, directions };
}
- List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ ArgsWithDirectionInfo getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
for (auto param : originalFuncDecl->getParameters())
@@ -7019,12 +7081,22 @@ namespace Slang
}
imaginaryArguments.add(arg);
}
- return imaginaryArguments;
+
+ // Copy parameter directions as is.
+ List<ParameterDirection> expectedParamDirections;
+ for (auto param : originalFuncDecl->getParameters())
+ {
+ expectedParamDirections.add(getParameterDirection(param));
+ }
+
+ return { imaginaryArguments, expectedParamDirections };
}
- List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ ArgsWithDirectionInfo getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
+ List<ParameterDirection> expectedParamDirections;
+
auto isOutParam = [&](ParamDecl* param)
{
return param->findModifier<OutModifier>() != nullptr
@@ -7038,18 +7110,31 @@ namespace Slang
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
+
+ ParameterDirection direction = getParameterDirection(param);
+
bool isDiffParam = (!param->findModifier<NoDiffModifier>());
if (isDiffParam)
{
if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
{
arg->type.type = pairType;
+ arg->type.isLeftValue = true;
+
if (isOutParam(param))
{
- // out T -> in T.Differential
+ // out T : IDifferentiable -> in T.Differential
arg->type.isLeftValue = false;
arg->type.type = visitor->tryGetDifferentialType(
visitor->getASTBuilder(), pairType->getPrimalType());
+
+ direction = ParameterDirection::kParameterDirection_In;
+ }
+ else
+ {
+ // in T : IDifferentiable -> inout DifferentialPair<T>
+ // inout T : IDifferentiable -> inout DifferentialPair<T>
+ direction = ParameterDirection::kParameterDirection_InOut;
}
}
else
@@ -7064,8 +7149,15 @@ namespace Slang
// Skip non-differentiable out params.
continue;
}
+
+ // no_diff inout T -> in T
+ // no_diff in T -> in T
+ //
+ direction = ParameterDirection::kParameterDirection_In;
}
+
imaginaryArguments.add(arg);
+ expectedParamDirections.add(direction);
}
if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type))
{
@@ -7074,8 +7166,10 @@ namespace Slang
arg->type.type = diffReturnType;
arg->loc = loc;
imaginaryArguments.add(arg);
+ expectedParamDirections.add(ParameterDirection::kParameterDirection_In);
}
- return imaginaryArguments;
+
+ return {imaginaryArguments, expectedParamDirections};
}
// This helper function is needed to workaround a gcc bug.
@@ -7105,7 +7199,7 @@ namespace Slang
visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
return;
}
- List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc);
+ List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc).args;
auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs);
SemanticsContext::ExprLocalScope scope;
auto ctx = visitor->withExprLocalScope(&scope);
@@ -7171,8 +7265,8 @@ namespace Slang
if (attr->funcExpr->type.type)
return;
- List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
@@ -7182,8 +7276,8 @@ namespace Slang
if (attr->funcExpr->type.type)
return;
- List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr)
@@ -7193,8 +7287,8 @@ namespace Slang
if (attr->funcExpr->type.type)
return;
- List<Expr*> imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
template<typename TDerivativeAttr, typename TDerivativeOfAttr>
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index c01a6ddf1..d38cde4e6 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -362,6 +362,8 @@ DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative att
DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.")
DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.")
DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function")
+DIAGNOSTIC(31149, Error, customDerivativeSignatureMismatchAtPosition, "invalid custom derivative. parameter type mismatch at position $0. expected '$1', got '$2'")
+DIAGNOSTIC(31150, Error, customDerivativeSignatureMismatch, "invalid custom derivative. could not resolve function with expected signature '$0'")
DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1")
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index 52df3a0e0..65ad121a0 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -319,6 +319,8 @@ namespace Slang
const Substitutions* substs,
InterfaceDecl* interfaceDecl);
+ ParameterDirection getParameterDirection(VarDeclBase* varDecl);
+
enum class UserDefinedAttributeTargets
{
None = 0,