diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-01 08:46:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-01 08:46:57 -0700 |
| commit | cbc1eff56057f199183bb7c17d8a360326512367 (patch) | |
| tree | 487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/slang-check-expr.cpp | |
| parent | b707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff) | |
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 114 |
1 files changed, 90 insertions, 24 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index fe37f5099..251849ede 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1410,6 +1410,19 @@ namespace Slang return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value); } + if (auto arrayLengthExpr = expr.as<GetArrayLengthExpr>()) + { + if (arrayLengthExpr.getExpr()->arrayExpr && arrayLengthExpr.getExpr()->arrayExpr->type) + { + auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts()); + if (auto arrayType = as<ArrayExpressionType>(type)) + { + if (auto val = as<IntVal>(arrayType->arrayLength)) + return val; + } + } + } + // it is possible that we are referring to a generic value param if (auto declRefExpr = expr.as<DeclRefExpr>()) { @@ -1871,14 +1884,42 @@ namespace Slang arg = CheckTerm(arg); } - return CheckInvokeExprWithCheckedOperands(expr); + // If we are in a differentiable function, register differential witness tables involved in + // this call. + if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>()) + { + for (auto& arg : expr->arguments) + { + maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); + } + } + + auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr); + + if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>()) + { + if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr)) + { + // Register types for final resolved invoke arguments again. + for (auto& arg : expr->arguments) + { + maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); + } + } + maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type); + } + return checkedExpr; } Expr* SemanticsExprVisitor::visitVarExpr(VarExpr *expr) { // If we've already resolved this expression, don't try again. if (expr->declRef) + { + if (!expr->type) + expr->type = GetTypeForDeclRef(expr->declRef, expr->loc); return expr; + } expr->type = QualType(m_astBuilder->getErrorType()); auto lookupResult = lookUp( @@ -1908,63 +1949,56 @@ namespace Slang return expr; } - Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) + Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) { // Check for type modifiers like 'out' and 'inout'. We need to differentiate the // nested type. // if (auto primalOutType = as<OutType>(primalType)) { - return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType())); + return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType())); } else if (auto primalInOutType = as<InOutType>(primalType)) { - return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType())); + return m_astBuilder->getInOutType(_toDifferentialParamType(primalInOutType->getValueType())); } + return getDifferentialPairType(primalType); + } + Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) + { // Get a reference to the builtin 'IDifferentiable' interface - auto differentiableInterface = builder->getDifferentiableInterface(); + auto differentiableInterface = m_astBuilder->getDifferentiableInterface(); + auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. - if (auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface))) - return builder->getDifferentialPairType(primalType, conformanceWitness); + if (conformanceWitness) + return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); else return primalType; - } - Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType) - { - if (auto conformanceWitness = - as<Witness>(tryGetInterfaceConformanceWitness( - primalType, - builder->getDifferentiableInterface()))) - return builder->getDifferentialPairType(primalType, conformanceWitness); - else - return primalType; - } - - Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType) + Type* SemanticsVisitor::processJVPFuncType(FuncType* originalType) { // Resolve JVP type here. // Note that this type checking needs to be in sync with // the auto-generation logic in slang-ir-jvp-diff.cpp - FuncType* jvpType = builder->create<FuncType>(); + FuncType* jvpType = m_astBuilder->create<FuncType>(); // The JVP return type is float if primal return type is float // void otherwise. // - jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType()); + jvpType->resultType = getDifferentialPairType(originalType->getResultType()); // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType())); + SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); jvpType->errorType = originalType->errorType; for (UInt i = 0; i < originalType->getParamCount(); i++) { - if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i))) + if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) jvpType->paramTypes.add(jvpParamType); } @@ -1978,6 +2012,15 @@ namespace Slang // Check/Resolve inner function declaration. expr->baseFunction = CheckTerm(expr->baseFunction); + // Register parameter types. + if (auto funcType = as<FuncType>(expr->baseFunction->type.type)) + { + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + maybeRegisterDifferentiableType(m_astBuilder, funcType->getParamType(i)); + } + } + // For now we only support using higher order expr as callee in an invoke expr. // The actual type of the higher order function will be derived during resolve invoke. expr->type = m_astBuilder->getBottomType(); @@ -1985,6 +2028,29 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) + { + expr->arrayExpr = CheckTerm(expr->arrayExpr); + if (auto arrType = as<ArrayExpressionType>(expr->arrayExpr->type)) + { + expr->type = m_astBuilder->getIntType(); + if (!arrType->arrayLength) + { + getSink()->diagnose(expr, Diagnostics::invalidArraySize); + } + } + else + { + if (!as<ErrorType>(expr->arrayExpr->type)) + { + getSink()->diagnose( + expr, Diagnostics::typeMismatch, "array", expr->arrayExpr->type); + } + expr->type = m_astBuilder->getErrorType(); + } + return expr; + } + Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr) { // Check the term we are applying first |
