diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 62 |
1 files changed, 39 insertions, 23 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2853c1eb9..d99114e4f 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -853,11 +853,11 @@ namespace Slang } else if (auto arrayType = as<ArrayExpressionType>(type)) { - auto baseDiffType = tryGetDifferentialType(builder, arrayType->baseType); + auto baseDiffType = tryGetDifferentialType(builder, arrayType->getElementType()); if (!baseDiffType) return nullptr; return builder->getArrayType( baseDiffType, - arrayType->arrayLength); + arrayType->getElementCount()); } if (auto declRefType = as<DeclRefType>(type)) @@ -946,8 +946,8 @@ namespace Slang if (auto arrayType = as<ArrayExpressionType>(type)) { - maybeRegisterDifferentiableType(builder, arrayType->baseType); - return; + maybeRegisterDifferentiableType(builder, arrayType->getElementType()); + // Fall through to register the array type itself. } if (auto declRefType = as<DeclRefType>(type)) @@ -990,8 +990,8 @@ namespace Slang if (auto arrayType = as<ArrayExpressionType>(type)) { - maybeRegisterDifferentiableTypeRecursive(builder, arrayType->baseType, workingSet); - return; + maybeRegisterDifferentiableTypeRecursive(builder, arrayType->getElementType(), workingSet); + // Fall through to register the array type itself. } if (auto declRefType = as<DeclRefType>(type)) @@ -1204,7 +1204,7 @@ namespace Slang IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr) { - return m_astBuilder->getOrCreate<ConstantIntVal>(expr->type.type, expr->value); + return m_astBuilder->getIntVal(expr->type.type, expr->value); } IntVal* SemanticsVisitor::tryConstantFoldExpr( @@ -1433,7 +1433,7 @@ namespace Slang } } - IntVal* result = m_astBuilder->getOrCreate<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue); + IntVal* result = m_astBuilder->getIntVal(invokeExpr.getExpr()->type.type, resultValue); return result; } @@ -1517,7 +1517,7 @@ namespace Slang { // If it's a boolean, we allow promotion to int. const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value); - return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value); + return m_astBuilder->getIntVal(m_astBuilder->getBoolType(), value); } if (auto arrayLengthExpr = expr.as<GetArrayLengthExpr>()) @@ -1527,8 +1527,11 @@ namespace Slang auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts()); if (auto arrayType = as<ArrayExpressionType>(type)) { - if (auto val = as<IntVal>(arrayType->arrayLength)) - return val; + if (!arrayType->isUnsized()) + { + if (auto val = as<IntVal>(arrayType->getElementCount())) + return val; + } } } } @@ -1734,7 +1737,7 @@ namespace Slang { return CheckSimpleSubscriptExpr( subscriptExpr, - baseArrayType->baseType); + baseArrayType->getElementType()); } else if (auto vecType = as<VectorExpressionType>(baseType)) { @@ -2146,12 +2149,14 @@ namespace Slang // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = m_astBuilder->getDifferentiableInterface(); - - auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)); + + SubtypeWitness* conformanceWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. if (conformanceWitness) + { return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + } else return primalType; } @@ -2200,15 +2205,24 @@ namespace Slang for (UInt i = 0; i < originalType->getParamCount(); i++) { - if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + if (auto outType = as<OutType>(originalType->getParamType(i))) { - // Using inout type on all the derivative parameters - if (auto outType = as<OutType>(derivType)) + auto diffElementType = + tryGetDifferentialType(m_astBuilder, outType->getValueType()); + if (diffElementType) + { + type->paramTypes.add(diffElementType); + } + else { - derivType = outType->getValueType(); + continue; } - else if (as<DifferentialPairType>(derivType)) + } + else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + { + if (as<DifferentialPairType>(derivType)) { + // Using inout type on all the derivative parameters derivType = m_astBuilder->getInOutType(derivType); } type->paramTypes.add(derivType); @@ -2216,7 +2230,9 @@ namespace Slang } // Last parameter is the initial derivative of the original return type - type->paramTypes.add(getDifferentialType(m_astBuilder, originalType->resultType, SourceLoc())); + auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->resultType); + if (dOutType) + type->paramTypes.add(dOutType); return type; } @@ -2407,7 +2423,7 @@ namespace Slang if (auto arrType = as<ArrayExpressionType>(expr->arrayExpr->type)) { expr->type = m_astBuilder->getIntType(); - if (!arrType->arrayLength) + if (arrType->isUnsized()) { getSink()->diagnose(expr, Diagnostics::invalidArraySize); } @@ -2823,7 +2839,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); + m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount))); } // A swizzle can be used as an l-value as long as there @@ -2948,7 +2964,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); + m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount))); } // A swizzle can be used as an l-value as long as there |
