summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp62
1 files changed, 39 insertions, 23 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 2853c1eb9..d99114e4f 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -853,11 +853,11 @@ namespace Slang
}
else if (auto arrayType = as<ArrayExpressionType>(type))
{
- auto baseDiffType = tryGetDifferentialType(builder, arrayType->baseType);
+ auto baseDiffType = tryGetDifferentialType(builder, arrayType->getElementType());
if (!baseDiffType) return nullptr;
return builder->getArrayType(
baseDiffType,
- arrayType->arrayLength);
+ arrayType->getElementCount());
}
if (auto declRefType = as<DeclRefType>(type))
@@ -946,8 +946,8 @@ namespace Slang
if (auto arrayType = as<ArrayExpressionType>(type))
{
- maybeRegisterDifferentiableType(builder, arrayType->baseType);
- return;
+ maybeRegisterDifferentiableType(builder, arrayType->getElementType());
+ // Fall through to register the array type itself.
}
if (auto declRefType = as<DeclRefType>(type))
@@ -990,8 +990,8 @@ namespace Slang
if (auto arrayType = as<ArrayExpressionType>(type))
{
- maybeRegisterDifferentiableTypeRecursive(builder, arrayType->baseType, workingSet);
- return;
+ maybeRegisterDifferentiableTypeRecursive(builder, arrayType->getElementType(), workingSet);
+ // Fall through to register the array type itself.
}
if (auto declRefType = as<DeclRefType>(type))
@@ -1204,7 +1204,7 @@ namespace Slang
IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr)
{
- return m_astBuilder->getOrCreate<ConstantIntVal>(expr->type.type, expr->value);
+ return m_astBuilder->getIntVal(expr->type.type, expr->value);
}
IntVal* SemanticsVisitor::tryConstantFoldExpr(
@@ -1433,7 +1433,7 @@ namespace Slang
}
}
- IntVal* result = m_astBuilder->getOrCreate<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue);
+ IntVal* result = m_astBuilder->getIntVal(invokeExpr.getExpr()->type.type, resultValue);
return result;
}
@@ -1517,7 +1517,7 @@ namespace Slang
{
// If it's a boolean, we allow promotion to int.
const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value);
- return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value);
+ return m_astBuilder->getIntVal(m_astBuilder->getBoolType(), value);
}
if (auto arrayLengthExpr = expr.as<GetArrayLengthExpr>())
@@ -1527,8 +1527,11 @@ namespace Slang
auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts());
if (auto arrayType = as<ArrayExpressionType>(type))
{
- if (auto val = as<IntVal>(arrayType->arrayLength))
- return val;
+ if (!arrayType->isUnsized())
+ {
+ if (auto val = as<IntVal>(arrayType->getElementCount()))
+ return val;
+ }
}
}
}
@@ -1734,7 +1737,7 @@ namespace Slang
{
return CheckSimpleSubscriptExpr(
subscriptExpr,
- baseArrayType->baseType);
+ baseArrayType->getElementType());
}
else if (auto vecType = as<VectorExpressionType>(baseType))
{
@@ -2146,12 +2149,14 @@ namespace Slang
// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = m_astBuilder->getDifferentiableInterface();
-
- auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface));
+
+ SubtypeWitness* conformanceWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface));
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
if (conformanceWitness)
+ {
return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
+ }
else
return primalType;
}
@@ -2200,15 +2205,24 @@ namespace Slang
for (UInt i = 0; i < originalType->getParamCount(); i++)
{
- if (auto derivType = _toDifferentialParamType(originalType->getParamType(i)))
+ if (auto outType = as<OutType>(originalType->getParamType(i)))
{
- // Using inout type on all the derivative parameters
- if (auto outType = as<OutType>(derivType))
+ auto diffElementType =
+ tryGetDifferentialType(m_astBuilder, outType->getValueType());
+ if (diffElementType)
+ {
+ type->paramTypes.add(diffElementType);
+ }
+ else
{
- derivType = outType->getValueType();
+ continue;
}
- else if (as<DifferentialPairType>(derivType))
+ }
+ else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i)))
+ {
+ if (as<DifferentialPairType>(derivType))
{
+ // Using inout type on all the derivative parameters
derivType = m_astBuilder->getInOutType(derivType);
}
type->paramTypes.add(derivType);
@@ -2216,7 +2230,9 @@ namespace Slang
}
// Last parameter is the initial derivative of the original return type
- type->paramTypes.add(getDifferentialType(m_astBuilder, originalType->resultType, SourceLoc()));
+ auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->resultType);
+ if (dOutType)
+ type->paramTypes.add(dOutType);
return type;
}
@@ -2407,7 +2423,7 @@ namespace Slang
if (auto arrType = as<ArrayExpressionType>(expr->arrayExpr->type))
{
expr->type = m_astBuilder->getIntType();
- if (!arrType->arrayLength)
+ if (arrType->isUnsized())
{
getSink()->diagnose(expr, Diagnostics::invalidArraySize);
}
@@ -2823,7 +2839,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
+ m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount)));
}
// A swizzle can be used as an l-value as long as there
@@ -2948,7 +2964,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
+ m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount)));
}
// A swizzle can be used as an l-value as long as there