summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp118
1 files changed, 98 insertions, 20 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 8994eb783..235b57ca6 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -223,12 +223,26 @@ Expr* SemanticsVisitor::maybeOpenRef(Expr* expr)
{
auto exprType = expr->type.type;
- if (auto refType = as<RefTypeBase>(exprType))
+ if (auto refType = as<ExplicitRefType>(exprType))
{
auto openRef = m_astBuilder->create<OpenRefExpr>();
openRef->innerExpr = expr;
- openRef->type.isLeftValue = (as<RefType>(exprType) != nullptr);
+
+ // TODO(tfoley): The `QualType` constructor has its own
+ // logic to determine the value category (e.g., whether
+ // or not something is an l-value) when it is passed
+ // a `Ref` type. It is unclear whether both this code
+ // *and* that code are required, or if we can consolidate
+ // the two.
+ //
+ // Note that here we change the actual `Type*` stored in
+ // the `QualType` to be the underlying value type of the
+ // reference, whereas the `QualType` constructor does not
+ // perform such unwrapping.
+ //
+ openRef->type = QualType(refType);
openRef->type.type = refType->getValueType();
+
openRef->checked = true;
openRef->loc = expr->loc;
return openRef;
@@ -538,10 +552,26 @@ Expr* SemanticsVisitor::constructDerefExpr(Expr* base, QualType elementType, Sou
derefExpr->type = QualType(elementType);
derefExpr->checked = true;
- if (as<PtrType>(base->type) || as<RefType>(base->type))
+ if (as<PtrType>(base->type))
{
+ // TODO(tfoley): It is not clear why this is being unconditionally
+ // set to `true` when the `Ptr` types in the core module has an
+ // `AccessQualifier` parameter that can be used to form a read-only pointer.
+ //
derefExpr->type.isLeftValue = true;
}
+ else if (as<ExplicitRefType>(base->type))
+ {
+ // TODO(tfoley): The code here is exploiting the ability of the
+ // `QualType` constructor to compute the correct value category
+ // for a reference type, so that we don't have to repeat that logic
+ // here. That might not be the right place for that logic to live,
+ // however, and so the code here might need updating sooner or
+ // later.
+ //
+ bool baseIsLVal = QualType(base->type.type).isLeftValue;
+ derefExpr->type.isLeftValue = baseIsLVal;
+ }
else if (isImmutableBufferType(base->type))
{
derefExpr->type.isLeftValue = false;
@@ -2925,7 +2955,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr)
Index paramCount = funcType->getParamCount();
for (Index pp = 0; pp < paramCount; ++pp)
{
- auto paramType = funcType->getParamType(pp);
+ auto paramType = funcType->getParamTypeWithDirectionWrapper(pp);
Expr* argExpr = nullptr;
ParamDecl* paramDecl = nullptr;
if (pp < expr->arguments.getCount())
@@ -2936,7 +2966,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr)
}
compareMemoryQualifierOfParamToArgument(paramDecl, argExpr);
- if (as<OutTypeBase>(paramType) || as<RefType>(paramType))
+ if (as<OutTypeBase>(paramType) || as<RefParamType>(paramType))
{
// `out`, `inout`, and `ref` parameters currently require
// an *exact* match on the type of the argument.
@@ -3047,7 +3077,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr)
const DiagnosticInfo* diagnostic = nullptr;
// Try and determine reason for failure
- if (as<RefType>(paramType))
+ if (as<RefParamType>(paramType))
{
// Ref types are not allowed to use this mechanism because
// it breaks atomics
@@ -3537,21 +3567,66 @@ Expr* SemanticsExprVisitor::maybeRegisterLambdaCapture(Expr* exprIn)
return resultMemberExpr;
}
-Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType)
+Type* SemanticsVisitor::_toDifferentialParamType(Type* primalParamType)
{
- // Check for type modifiers like 'out' and 'inout'. We need to differentiate the
- // nested type.
+ // This function is invoked on parameter types that could
+ // still be wrapped to represent a parameter-passing mode
+ // like `ref`, `out`, etc.
//
- if (auto primalOutType = as<OutType>(primalType))
- {
- return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType()));
- }
- else if (auto primalInOutType = as<InOutType>(primalType))
+ // We need to intercept these cases here, and ensure that
+ // the wrapper is not exposed to other parts of the front-end
+ // code, because they only exist to encode the parameter-passing
+ // mode, and are not a proper part of the Slang type system
+ // (at least not at this time).
+ //
+ if (auto primalParamWrapperType = as<ParamDirectionType>(primalParamType))
{
- return m_astBuilder->getInOutType(
- _toDifferentialParamType(primalInOutType->getValueType()));
+ // Some parameter-passing modes do not naturally lend themselves
+ // to being differentiated - most notably, `ref` parameters.
+ // We will detect those cases here, and handle them as a parameter
+ // of a non-differentiable type would be handled.
+ //
+ // TODO(tfoley): With the introduction of `IDifferentiablePtrType`,
+ // it is possible that something like a `ref` parameter could also
+ // support autodiff, but it is not clear what a correct
+ // one-size-fits-all behavior should be in that case.
+ //
+ if (as<RefParamType>(primalParamType))
+ return primalParamWrapperType;
+
+ // Given a primal type that is a wrapper like `Out<T>`, we can
+ // extract the underlying primal value type `T`, and determine
+ // what the differential type value type corresponding to `T`
+ // should be.
+ //
+ auto primalValueType = primalParamWrapperType->getValueType();
+ auto diffValueType = _toDifferentialParamType(primalValueType);
+
+ // Once we have created the appropriate differential value type,
+ // we will form the differential parameter type by wrapping
+ // the differential value type in the same wrapper that had
+ // been used for the primal type.
+ //
+ if (as<OutType>(primalParamWrapperType))
+ {
+ return m_astBuilder->getOutType(diffValueType);
+ }
+ else if (as<InOutType>(primalParamWrapperType))
+ {
+ return m_astBuilder->getInOutType(diffValueType);
+ }
+ else if (as<ConstRefParamType>(primalParamWrapperType))
+ {
+ return m_astBuilder->getConstRefParamType(diffValueType);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unhandled parameter-passing mode");
+ UNREACHABLE_RETURN(diffValueType);
+ }
}
- return getDifferentialPairType(primalType);
+
+ return getDifferentialPairType(primalParamType);
}
Type* SemanticsVisitor::getDifferentialPairType(Type* primalType)
@@ -3632,7 +3707,8 @@ Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType)
for (Index i = 0; i < originalType->getParamCount(); i++)
{
- if (auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i)))
+ if (auto jvpParamType =
+ _toDifferentialParamType(originalType->getParamTypeWithDirectionWrapper(i)))
paramTypes.add(jvpParamType);
}
FuncType* jvpType =
@@ -3658,7 +3734,9 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType)
for (Index i = 0; i < originalType->getParamCount(); i++)
{
- if (auto outType = as<OutType>(originalType->getParamType(i)))
+ auto originalParamType = originalType->getParamTypeWithDirectionWrapper(i);
+
+ if (auto outType = as<OutType>(originalParamType))
{
auto diffElementType = tryGetDifferentialType(m_astBuilder, outType->getValueType());
if (diffElementType)
@@ -3670,7 +3748,7 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType)
continue;
}
}
- else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i)))
+ else if (auto derivType = _toDifferentialParamType(originalParamType))
{
if (as<DifferentialPairType>(derivType))
{