summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-01 08:46:57 -0700
committerGitHub <noreply@github.com>2022-11-01 08:46:57 -0700
commitcbc1eff56057f199183bb7c17d8a360326512367 (patch)
tree487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/slang-check-expr.cpp
parentb707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff)
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp114
1 files changed, 90 insertions, 24 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index fe37f5099..251849ede 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1410,6 +1410,19 @@ namespace Slang
return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value);
}
+ if (auto arrayLengthExpr = expr.as<GetArrayLengthExpr>())
+ {
+ if (arrayLengthExpr.getExpr()->arrayExpr && arrayLengthExpr.getExpr()->arrayExpr->type)
+ {
+ auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts());
+ if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ if (auto val = as<IntVal>(arrayType->arrayLength))
+ return val;
+ }
+ }
+ }
+
// it is possible that we are referring to a generic value param
if (auto declRefExpr = expr.as<DeclRefExpr>())
{
@@ -1871,14 +1884,42 @@ namespace Slang
arg = CheckTerm(arg);
}
- return CheckInvokeExprWithCheckedOperands(expr);
+ // If we are in a differentiable function, register differential witness tables involved in
+ // this call.
+ if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>())
+ {
+ for (auto& arg : expr->arguments)
+ {
+ maybeRegisterDifferentiableType(m_astBuilder, arg->type.type);
+ }
+ }
+
+ auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr);
+
+ if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>())
+ {
+ if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr))
+ {
+ // Register types for final resolved invoke arguments again.
+ for (auto& arg : expr->arguments)
+ {
+ maybeRegisterDifferentiableType(m_astBuilder, arg->type.type);
+ }
+ }
+ maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type);
+ }
+ return checkedExpr;
}
Expr* SemanticsExprVisitor::visitVarExpr(VarExpr *expr)
{
// If we've already resolved this expression, don't try again.
if (expr->declRef)
+ {
+ if (!expr->type)
+ expr->type = GetTypeForDeclRef(expr->declRef, expr->loc);
return expr;
+ }
expr->type = QualType(m_astBuilder->getErrorType());
auto lookupResult = lookUp(
@@ -1908,63 +1949,56 @@ namespace Slang
return expr;
}
- Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
+ Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType)
{
// Check for type modifiers like 'out' and 'inout'. We need to differentiate the
// nested type.
//
if (auto primalOutType = as<OutType>(primalType))
{
- return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType()));
+ return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType()));
}
else if (auto primalInOutType = as<InOutType>(primalType))
{
- return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType()));
+ return m_astBuilder->getInOutType(_toDifferentialParamType(primalInOutType->getValueType()));
}
+ return getDifferentialPairType(primalType);
+ }
+ Type* SemanticsVisitor::getDifferentialPairType(Type* primalType)
+ {
// Get a reference to the builtin 'IDifferentiable' interface
- auto differentiableInterface = builder->getDifferentiableInterface();
+ auto differentiableInterface = m_astBuilder->getDifferentiableInterface();
+ auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface));
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
- if (auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)))
- return builder->getDifferentialPairType(primalType, conformanceWitness);
+ if (conformanceWitness)
+ return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
else
return primalType;
-
}
- Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType)
- {
- if (auto conformanceWitness =
- as<Witness>(tryGetInterfaceConformanceWitness(
- primalType,
- builder->getDifferentiableInterface())))
- return builder->getDifferentialPairType(primalType, conformanceWitness);
- else
- return primalType;
- }
-
- Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType)
+ Type* SemanticsVisitor::processJVPFuncType(FuncType* originalType)
{
// Resolve JVP type here.
// Note that this type checking needs to be in sync with
// the auto-generation logic in slang-ir-jvp-diff.cpp
- FuncType* jvpType = builder->create<FuncType>();
+ FuncType* jvpType = m_astBuilder->create<FuncType>();
// The JVP return type is float if primal return type is float
// void otherwise.
//
- jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType());
+ jvpType->resultType = getDifferentialPairType(originalType->getResultType());
// No support for differentiating function that throw errors, for now.
- SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType()));
+ SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType()));
jvpType->errorType = originalType->errorType;
for (UInt i = 0; i < originalType->getParamCount(); i++)
{
- if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i)))
+ if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i)))
jvpType->paramTypes.add(jvpParamType);
}
@@ -1978,6 +2012,15 @@ namespace Slang
// Check/Resolve inner function declaration.
expr->baseFunction = CheckTerm(expr->baseFunction);
+ // Register parameter types.
+ if (auto funcType = as<FuncType>(expr->baseFunction->type.type))
+ {
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ maybeRegisterDifferentiableType(m_astBuilder, funcType->getParamType(i));
+ }
+ }
+
// For now we only support using higher order expr as callee in an invoke expr.
// The actual type of the higher order function will be derived during resolve invoke.
expr->type = m_astBuilder->getBottomType();
@@ -1985,6 +2028,29 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
+ {
+ expr->arrayExpr = CheckTerm(expr->arrayExpr);
+ if (auto arrType = as<ArrayExpressionType>(expr->arrayExpr->type))
+ {
+ expr->type = m_astBuilder->getIntType();
+ if (!arrType->arrayLength)
+ {
+ getSink()->diagnose(expr, Diagnostics::invalidArraySize);
+ }
+ }
+ else
+ {
+ if (!as<ErrorType>(expr->arrayExpr->type))
+ {
+ getSink()->diagnose(
+ expr, Diagnostics::typeMismatch, "array", expr->arrayExpr->type);
+ }
+ expr->type = m_astBuilder->getErrorType();
+ }
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr)
{
// Check the term we are applying first