diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 118 |
1 files changed, 98 insertions, 20 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 8994eb783..235b57ca6 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -223,12 +223,26 @@ Expr* SemanticsVisitor::maybeOpenRef(Expr* expr) { auto exprType = expr->type.type; - if (auto refType = as<RefTypeBase>(exprType)) + if (auto refType = as<ExplicitRefType>(exprType)) { auto openRef = m_astBuilder->create<OpenRefExpr>(); openRef->innerExpr = expr; - openRef->type.isLeftValue = (as<RefType>(exprType) != nullptr); + + // TODO(tfoley): The `QualType` constructor has its own + // logic to determine the value category (e.g., whether + // or not something is an l-value) when it is passed + // a `Ref` type. It is unclear whether both this code + // *and* that code are required, or if we can consolidate + // the two. + // + // Note that here we change the actual `Type*` stored in + // the `QualType` to be the underlying value type of the + // reference, whereas the `QualType` constructor does not + // perform such unwrapping. + // + openRef->type = QualType(refType); openRef->type.type = refType->getValueType(); + openRef->checked = true; openRef->loc = expr->loc; return openRef; @@ -538,10 +552,26 @@ Expr* SemanticsVisitor::constructDerefExpr(Expr* base, QualType elementType, Sou derefExpr->type = QualType(elementType); derefExpr->checked = true; - if (as<PtrType>(base->type) || as<RefType>(base->type)) + if (as<PtrType>(base->type)) { + // TODO(tfoley): It is not clear why this is being unconditionally + // set to `true` when the `Ptr` types in the core module has an + // `AccessQualifier` parameter that can be used to form a read-only pointer. + // derefExpr->type.isLeftValue = true; } + else if (as<ExplicitRefType>(base->type)) + { + // TODO(tfoley): The code here is exploiting the ability of the + // `QualType` constructor to compute the correct value category + // for a reference type, so that we don't have to repeat that logic + // here. That might not be the right place for that logic to live, + // however, and so the code here might need updating sooner or + // later. + // + bool baseIsLVal = QualType(base->type.type).isLeftValue; + derefExpr->type.isLeftValue = baseIsLVal; + } else if (isImmutableBufferType(base->type)) { derefExpr->type.isLeftValue = false; @@ -2925,7 +2955,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) Index paramCount = funcType->getParamCount(); for (Index pp = 0; pp < paramCount; ++pp) { - auto paramType = funcType->getParamType(pp); + auto paramType = funcType->getParamTypeWithDirectionWrapper(pp); Expr* argExpr = nullptr; ParamDecl* paramDecl = nullptr; if (pp < expr->arguments.getCount()) @@ -2936,7 +2966,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) } compareMemoryQualifierOfParamToArgument(paramDecl, argExpr); - if (as<OutTypeBase>(paramType) || as<RefType>(paramType)) + if (as<OutTypeBase>(paramType) || as<RefParamType>(paramType)) { // `out`, `inout`, and `ref` parameters currently require // an *exact* match on the type of the argument. @@ -3047,7 +3077,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) const DiagnosticInfo* diagnostic = nullptr; // Try and determine reason for failure - if (as<RefType>(paramType)) + if (as<RefParamType>(paramType)) { // Ref types are not allowed to use this mechanism because // it breaks atomics @@ -3537,21 +3567,66 @@ Expr* SemanticsExprVisitor::maybeRegisterLambdaCapture(Expr* exprIn) return resultMemberExpr; } -Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) +Type* SemanticsVisitor::_toDifferentialParamType(Type* primalParamType) { - // Check for type modifiers like 'out' and 'inout'. We need to differentiate the - // nested type. + // This function is invoked on parameter types that could + // still be wrapped to represent a parameter-passing mode + // like `ref`, `out`, etc. // - if (auto primalOutType = as<OutType>(primalType)) - { - return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType())); - } - else if (auto primalInOutType = as<InOutType>(primalType)) + // We need to intercept these cases here, and ensure that + // the wrapper is not exposed to other parts of the front-end + // code, because they only exist to encode the parameter-passing + // mode, and are not a proper part of the Slang type system + // (at least not at this time). + // + if (auto primalParamWrapperType = as<ParamDirectionType>(primalParamType)) { - return m_astBuilder->getInOutType( - _toDifferentialParamType(primalInOutType->getValueType())); + // Some parameter-passing modes do not naturally lend themselves + // to being differentiated - most notably, `ref` parameters. + // We will detect those cases here, and handle them as a parameter + // of a non-differentiable type would be handled. + // + // TODO(tfoley): With the introduction of `IDifferentiablePtrType`, + // it is possible that something like a `ref` parameter could also + // support autodiff, but it is not clear what a correct + // one-size-fits-all behavior should be in that case. + // + if (as<RefParamType>(primalParamType)) + return primalParamWrapperType; + + // Given a primal type that is a wrapper like `Out<T>`, we can + // extract the underlying primal value type `T`, and determine + // what the differential type value type corresponding to `T` + // should be. + // + auto primalValueType = primalParamWrapperType->getValueType(); + auto diffValueType = _toDifferentialParamType(primalValueType); + + // Once we have created the appropriate differential value type, + // we will form the differential parameter type by wrapping + // the differential value type in the same wrapper that had + // been used for the primal type. + // + if (as<OutType>(primalParamWrapperType)) + { + return m_astBuilder->getOutType(diffValueType); + } + else if (as<InOutType>(primalParamWrapperType)) + { + return m_astBuilder->getInOutType(diffValueType); + } + else if (as<ConstRefParamType>(primalParamWrapperType)) + { + return m_astBuilder->getConstRefParamType(diffValueType); + } + else + { + SLANG_UNEXPECTED("unhandled parameter-passing mode"); + UNREACHABLE_RETURN(diffValueType); + } } - return getDifferentialPairType(primalType); + + return getDifferentialPairType(primalParamType); } Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) @@ -3632,7 +3707,8 @@ Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) for (Index i = 0; i < originalType->getParamCount(); i++) { - if (auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) + if (auto jvpParamType = + _toDifferentialParamType(originalType->getParamTypeWithDirectionWrapper(i))) paramTypes.add(jvpParamType); } FuncType* jvpType = @@ -3658,7 +3734,9 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) for (Index i = 0; i < originalType->getParamCount(); i++) { - if (auto outType = as<OutType>(originalType->getParamType(i))) + auto originalParamType = originalType->getParamTypeWithDirectionWrapper(i); + + if (auto outType = as<OutType>(originalParamType)) { auto diffElementType = tryGetDifferentialType(m_astBuilder, outType->getValueType()); if (diffElementType) @@ -3670,7 +3748,7 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) continue; } } - else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + else if (auto derivType = _toDifferentialParamType(originalParamType)) { if (as<DifferentialPairType>(derivType)) { |
