summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-30 19:24:09 -0800
committerGitHub <noreply@github.com>2023-01-30 19:24:09 -0800
commit499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (patch)
tree4c570a36d305c8909d633183694e0d1225f044c2 /source
parent134dd7eb26fc7988ae13559d276cbf337b4b9d27 (diff)
Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615)
* Allow array parameters in forward diff. * Use type canonicalization instead of coersion. * Reimplement array type. * Fix. * Update test case. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang6
-rw-r--r--source/slang/diff.meta.slang35
-rw-r--r--source/slang/slang-ast-builder.cpp15
-rw-r--r--source/slang/slang-ast-builder.h5
-rw-r--r--source/slang/slang-ast-type.cpp61
-rw-r--r--source/slang/slang-ast-type.h21
-rw-r--r--source/slang/slang-ast-val.cpp42
-rw-r--r--source/slang/slang-ast-val.h18
-rw-r--r--source/slang/slang-check-conformance.cpp1
-rw-r--r--source/slang/slang-check-conversion.cpp9
-rw-r--r--source/slang/slang-check-decl.cpp23
-rw-r--r--source/slang/slang-check-expr.cpp62
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-shader.cpp4
-rw-r--r--source/slang/slang-check-stmt.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp35
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--source/slang/slang-ir-autodiff.cpp26
-rw-r--r--source/slang/slang-ir-autodiff.h24
-rw-r--r--source/slang/slang-lower-to-ir.cpp27
-rw-r--r--source/slang/slang-mangle.cpp4
-rw-r--r--source/slang/slang-parameter-binding.cpp8
-rw-r--r--source/slang/slang-reflection-api.cpp16
-rw-r--r--source/slang/slang-syntax.cpp16
-rw-r--r--source/slang/slang-syntax.h2
-rw-r--r--source/slang/slang-type-layout.cpp4
26 files changed, 231 insertions, 240 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 97b9a227d..31dd5ed29 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -220,6 +220,12 @@ interface __FlagsEnumType : __EnumType
{
};
+__generic<T, let N:int>
+__magic_type(ArrayType)
+struct Array
+{
+}
+
// The "comma operator" is effectively just a generic function that returns its second
// argument. The left-to-right evaluation order guaranteed by Slang then ensures that
// `left` is evaluated before `right`.
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index c732d1a5e..adbf8ae48 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -127,6 +127,41 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T
p = DifferentialPair<T>(newPrimal, newDiff);
}
+__generic<T, let N:int>
+__intrinsic_op($(kIROp_MakeArrayFromElement))
+Array<T,N> makeArrayFromElement(T element);
+
+
+__generic<T:IDifferentiable, let N:int>
+extension Array<T, N> : IDifferentiable
+{
+ typedef Array<T.Differential, N> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return makeArrayFromElement<T.Differential, N>(T.dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ Array<T.Differential, N> result;
+ for (int i = 0; i < N; i++)
+ result[i] = T.dadd(a[i], b[i]);
+ return result;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ Array<T.Differential, N> result;
+ for (int i = 0; i < N; i++)
+ result[i] = T.dmul(a[i], b[i]);
+ return result;
+ }
+}
+
// vector-matrix
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index ab161065d..03725901e 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -290,13 +290,18 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount)
{
- ArrayExpressionType* arrayType = getOrCreateWithDefaultCtor<ArrayExpressionType>(elementType, elementCount);
- if (!arrayType->baseType)
+ if (!elementCount)
+ elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength);
+
+ auto result = getOrCreate<ArrayExpressionType>(elementType, elementCount);
+ if (!result->declRef.decl)
{
- arrayType->baseType = elementType;
- arrayType->arrayLength = elementCount;
+ auto arrayGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ArrayType"));
+ auto arrayTypeDecl = arrayGenericDecl->inner;
+ auto substitutions = getOrCreate<GenericSubstitution>(arrayGenericDecl, elementType, elementCount);
+ result->declRef = DeclRef<Decl>(arrayTypeDecl, substitutions);
}
- return arrayType;
+ return result;
}
VectorExpressionType* ASTBuilder::getVectorType(
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index bc95ed63d..d44a15813 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -230,6 +230,11 @@ public:
});
}
+ ConstantIntVal* getIntVal(Type* type, IntegerLiteralValue value)
+ {
+ return getOrCreate<ConstantIntVal>(type, value);
+ }
+
DeclRefType* getOrCreateDeclRefType(Decl* decl, Substitutions* outer)
{
NodeDesc desc;
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 3801b99b0..362503a64 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -314,64 +314,37 @@ Type* MatrixExpressionType::getRowType()
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArrayExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool ArrayExpressionType::_equalsImplOverride(Type* type)
+Type* ArrayExpressionType::getElementType()
{
- auto arrType = as<ArrayExpressionType>(type);
- if (!arrType)
- return false;
- return (areValsEqual(arrayLength, arrType->arrayLength) && baseType->equals(arrType->baseType));
-}
-
-Val* ArrayExpressionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
- auto elementType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff));
- IntVal* newArrayLength = nullptr;
- if (arrayLength)
- {
- newArrayLength = as<IntVal>(arrayLength->substituteImpl(astBuilder, subst, &diff));
- SLANG_ASSERT(newArrayLength);
- }
- if (diff)
- {
- *ioDiff = 1;
- auto rsType = getArrayType(
- astBuilder,
- elementType,
- newArrayLength);
- return rsType;
- }
- return this;
-}
-
-Type* ArrayExpressionType::_createCanonicalTypeOverride()
-{
- auto canonicalElementType = baseType->getCanonicalType();
- auto canonicalArrayType = getASTBuilder()->getArrayType(
- canonicalElementType,
- arrayLength);
- return canonicalArrayType;
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]);
}
-HashCode ArrayExpressionType::_getHashCodeOverride()
+IntVal* ArrayExpressionType::getElementCount()
{
- if (arrayLength)
- return (baseType->getHashCode() * 16777619) ^ arrayLength->getHashCode();
- else
- return baseType->getHashCode();
+ return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]);
}
void ArrayExpressionType::_toTextOverride(StringBuilder& out)
{
- out << baseType;
+ out << getElementType();
out.appendChar('[');
- if (arrayLength)
+ if (!isUnsized())
{
- out << arrayLength;
+ out << getElementCount();
}
out.appendChar(']');
}
+bool ArrayExpressionType::isUnsized()
+{
+ if (auto constSize = as<ConstantIntVal>(getElementCount()))
+ {
+ if (constSize->value == kUnsizedArrayMagicLength)
+ return true;
+ }
+ return false;
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void TypeType::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 0e7614dd6..47608405a 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -432,19 +432,18 @@ class ParameterBlockType : public UniformParameterGroupType
SLANG_AST_CLASS(ParameterBlockType)
};
-class ArrayExpressionType : public Type
+class ArrayExpressionType : public DeclRefType
{
SLANG_AST_CLASS(ArrayExpressionType)
-
- Type* baseType = nullptr;
- IntVal* arrayLength = nullptr;
-
- // Overrides should be public so base classes can access
+ ArrayExpressionType(Type* inElementType, IntVal* inElementCount)
+ {
+ SLANG_UNUSED(inElementType);
+ SLANG_UNUSED(inElementCount);
+ }
+ bool isUnsized();
void _toTextOverride(StringBuilder& out);
- Type* _createCanonicalTypeOverride();
- bool _equalsImplOverride(Type* type);
- Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- HashCode _getHashCodeOverride();
+ Type* getElementType();
+ IntVal* getElementCount();
};
// The "type" of an expression that resolves to a type.
@@ -500,7 +499,7 @@ class VectorExpressionType : public ArithmeticExpressionType
BasicExpressionType* _getScalarTypeOverride();
VectorExpressionType(Type* inElementType, IntVal* inElementCount)
- : elementType(inElementType), elementCount(inElementCount)
+ : elementType(inElementType), elementCount(inElementCount)
{}
};
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index e60c963a8..fde31c730 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -544,7 +544,6 @@ Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBu
return substValue;
}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
bool TaggedUnionSubtypeWitness::_equalsValOverride(Val* val)
@@ -618,41 +617,6 @@ Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder,
return substWitness;
}
-bool DifferentialBottomSubtypeWitness::_equalsValOverride(Val* val)
-{
- auto otherDiffBottomWitness = as<DifferentialBottomSubtypeWitness>(val);
- if (!otherDiffBottomWitness)
- return false;
-
- return otherDiffBottomWitness->sub && otherDiffBottomWitness->sub->equals(sub);
-}
-
-void DifferentialBottomSubtypeWitness::_toTextOverride(StringBuilder& out)
-{
- out << "DifferentialBottomSubtypeWitness(" << sub << ")";
-}
-
-HashCode DifferentialBottomSubtypeWitness::_getHashCodeOverride()
-{
- return combineHash(3892, sub->getHashCode());
-}
-
-Val* DifferentialBottomSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
-
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
- if (!diff)
- return this;
-
- *ioDiff += diff;
-
- DifferentialBottomSubtypeWitness* substWitness =
- astBuilder->create<DifferentialBottomSubtypeWitness>(substSub, substSup);
- return substWitness;
-}
-
bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val)
{
if (auto other = as<ConjunctionSubtypeWitness>(val))
@@ -940,7 +904,7 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
*ioDiff += diff;
if (evaluatedTerms.getCount() == 0)
- return astBuilder->getOrCreate<ConstantIntVal>(type, evaluatedConstantTerm);
+ return astBuilder->getIntVal(type, evaluatedConstantTerm);
if (diff != 0)
{
auto newPolynomial = astBuilder->create<PolynomialIntVal>(type);
@@ -1253,7 +1217,7 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
return terms[0]->paramFactors[0]->param;
}
if (terms.getCount() == 0)
- return builder->getOrCreate<ConstantIntVal>(type, constantTerm);
+ return builder->getIntVal(type, constantTerm);
return this;
}
@@ -1425,7 +1389,7 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
{
SLANG_UNREACHABLE("constant folding of FuncCallIntVal");
}
- return astBuilder->getOrCreate<ConstantIntVal>(resultType, resultValue);
+ return astBuilder->getIntVal(resultType, resultValue);
}
return nullptr;
}
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 8e5192536..49aec8c5e 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -399,24 +399,6 @@ class DynamicSubtypeWitness : public SubtypeWitness
SLANG_AST_CLASS(DynamicSubtypeWitness)
};
- /// A witness of the fact that any type can be viewed as a subtype of DifferentialBottom.
-class DifferentialBottomSubtypeWitness : public SubtypeWitness
-{
- SLANG_AST_CLASS(DifferentialBottomSubtypeWitness)
-
- DifferentialBottomSubtypeWitness(Type* inSub, Type* inSup)
- {
- sub = inSub;
- sup = inSup;
- }
-
- // Overrides should be public so base classes can access
- bool _equalsValOverride(Val* val);
- void _toTextOverride(StringBuilder& out);
- HashCode _getHashCodeOverride();
- Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
-};
-
/// A witness that `T : L & R` because `T : L` and `T : R`
class ConjunctionSubtypeWitness : public SubtypeWitness
{
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index 3a50897de..9bf343876 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -176,7 +176,6 @@ namespace Slang
*link = extractWitness;
link = (SubtypeWitness**) &extractWitness->conjunctionWitness;
}
-
// Move on with the list.
bb = bb->prev;
}
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index c6daf5e86..6decce625 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -275,10 +275,11 @@ namespace Slang
// TODO(tfoley): If we can compute the size of the array statically,
// then we want to check that there aren't too many initializers present
- auto toElementType = toArrayType->baseType;
-
- if(auto toElementCount = toArrayType->arrayLength)
+ auto toElementType = toArrayType->getElementType();
+ if(!toArrayType->isUnsized())
{
+ auto toElementCount = toArrayType->getElementCount();
+
// In the case of a sized array, we need to check that the number
// of elements being initialized matches what was declared.
//
@@ -349,7 +350,7 @@ namespace Slang
// We have a new type for the conversion, based on what
// we learned.
toType = m_astBuilder->getArrayType(toElementType,
- m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount));
+ m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount));
}
}
else if(auto toMatrixType = as<MatrixExpressionType>(toType))
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5add89312..9bda6c3e7 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -955,17 +955,14 @@ namespace Slang
}
}
- static bool isUnsizedArrayType(Type* type)
+ bool isUnsizedArrayType(Type* type)
{
// Not an array?
auto arrayType = as<ArrayExpressionType>(type);
if (!arrayType) return false;
// Explicit element count given?
- auto elementCount = arrayType->arrayLength;
- if (elementCount) return true;
-
- return true;
+ return arrayType->isUnsized();
}
bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state)
@@ -3304,7 +3301,7 @@ namespace Slang
{
arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar));
}
- auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->baseType, innerLeft, _Move(args), nestingLevel + 1);
+ auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->getElementType(), innerLeft, _Move(args), nestingLevel + 1);
synth.popScope();
if (!assignStmt)
return nullptr;
@@ -5986,17 +5983,19 @@ namespace Slang
if (!arrayType) return;
// Explicit element count given?
- auto elementCount = arrayType->arrayLength;
- if (elementCount) return;
+ if (!isUnsizedArrayType(arrayType))
+ return;
// No initializer?
auto initExpr = varDecl->initExpr;
if(!initExpr) return;
+ IntVal* elementCount = nullptr;
+
// Is the type of the initializer an array type?
if(auto arrayInitType = as<ArrayExpressionType>(initExpr->type))
{
- elementCount = arrayInitType->arrayLength;
+ elementCount = arrayInitType->getElementCount();
}
else
{
@@ -6008,7 +6007,7 @@ namespace Slang
// and install it into our type.
varDecl->type.type = getArrayType(
m_astBuilder,
- arrayType->baseType,
+ arrayType->getElementType(),
elementCount);
}
@@ -6017,8 +6016,7 @@ namespace Slang
auto arrayType = as<ArrayExpressionType>(varDecl->type);
if (!arrayType) return;
- auto elementCount = arrayType->arrayLength;
- if (!elementCount)
+ if (arrayType->isUnsized())
{
// Note(tfoley): For now we allow arrays of unspecified size
// everywhere, because some source languages (e.g., GLSL)
@@ -6030,6 +6028,7 @@ namespace Slang
}
// TODO(tfoley): How to handle the case where bound isn't known?
+ auto elementCount = arrayType->getElementCount();
if (GetMinBound(elementCount) <= 0)
{
getSink()->diagnose(varDecl, Diagnostics::invalidArraySize);
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 2853c1eb9..d99114e4f 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -853,11 +853,11 @@ namespace Slang
}
else if (auto arrayType = as<ArrayExpressionType>(type))
{
- auto baseDiffType = tryGetDifferentialType(builder, arrayType->baseType);
+ auto baseDiffType = tryGetDifferentialType(builder, arrayType->getElementType());
if (!baseDiffType) return nullptr;
return builder->getArrayType(
baseDiffType,
- arrayType->arrayLength);
+ arrayType->getElementCount());
}
if (auto declRefType = as<DeclRefType>(type))
@@ -946,8 +946,8 @@ namespace Slang
if (auto arrayType = as<ArrayExpressionType>(type))
{
- maybeRegisterDifferentiableType(builder, arrayType->baseType);
- return;
+ maybeRegisterDifferentiableType(builder, arrayType->getElementType());
+ // Fall through to register the array type itself.
}
if (auto declRefType = as<DeclRefType>(type))
@@ -990,8 +990,8 @@ namespace Slang
if (auto arrayType = as<ArrayExpressionType>(type))
{
- maybeRegisterDifferentiableTypeRecursive(builder, arrayType->baseType, workingSet);
- return;
+ maybeRegisterDifferentiableTypeRecursive(builder, arrayType->getElementType(), workingSet);
+ // Fall through to register the array type itself.
}
if (auto declRefType = as<DeclRefType>(type))
@@ -1204,7 +1204,7 @@ namespace Slang
IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr)
{
- return m_astBuilder->getOrCreate<ConstantIntVal>(expr->type.type, expr->value);
+ return m_astBuilder->getIntVal(expr->type.type, expr->value);
}
IntVal* SemanticsVisitor::tryConstantFoldExpr(
@@ -1433,7 +1433,7 @@ namespace Slang
}
}
- IntVal* result = m_astBuilder->getOrCreate<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue);
+ IntVal* result = m_astBuilder->getIntVal(invokeExpr.getExpr()->type.type, resultValue);
return result;
}
@@ -1517,7 +1517,7 @@ namespace Slang
{
// If it's a boolean, we allow promotion to int.
const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value);
- return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value);
+ return m_astBuilder->getIntVal(m_astBuilder->getBoolType(), value);
}
if (auto arrayLengthExpr = expr.as<GetArrayLengthExpr>())
@@ -1527,8 +1527,11 @@ namespace Slang
auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts());
if (auto arrayType = as<ArrayExpressionType>(type))
{
- if (auto val = as<IntVal>(arrayType->arrayLength))
- return val;
+ if (!arrayType->isUnsized())
+ {
+ if (auto val = as<IntVal>(arrayType->getElementCount()))
+ return val;
+ }
}
}
}
@@ -1734,7 +1737,7 @@ namespace Slang
{
return CheckSimpleSubscriptExpr(
subscriptExpr,
- baseArrayType->baseType);
+ baseArrayType->getElementType());
}
else if (auto vecType = as<VectorExpressionType>(baseType))
{
@@ -2146,12 +2149,14 @@ namespace Slang
// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = m_astBuilder->getDifferentiableInterface();
-
- auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface));
+
+ SubtypeWitness* conformanceWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface));
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
if (conformanceWitness)
+ {
return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
+ }
else
return primalType;
}
@@ -2200,15 +2205,24 @@ namespace Slang
for (UInt i = 0; i < originalType->getParamCount(); i++)
{
- if (auto derivType = _toDifferentialParamType(originalType->getParamType(i)))
+ if (auto outType = as<OutType>(originalType->getParamType(i)))
{
- // Using inout type on all the derivative parameters
- if (auto outType = as<OutType>(derivType))
+ auto diffElementType =
+ tryGetDifferentialType(m_astBuilder, outType->getValueType());
+ if (diffElementType)
+ {
+ type->paramTypes.add(diffElementType);
+ }
+ else
{
- derivType = outType->getValueType();
+ continue;
}
- else if (as<DifferentialPairType>(derivType))
+ }
+ else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i)))
+ {
+ if (as<DifferentialPairType>(derivType))
{
+ // Using inout type on all the derivative parameters
derivType = m_astBuilder->getInOutType(derivType);
}
type->paramTypes.add(derivType);
@@ -2216,7 +2230,9 @@ namespace Slang
}
// Last parameter is the initial derivative of the original return type
- type->paramTypes.add(getDifferentialType(m_astBuilder, originalType->resultType, SourceLoc()));
+ auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->resultType);
+ if (dOutType)
+ type->paramTypes.add(dOutType);
return type;
}
@@ -2407,7 +2423,7 @@ namespace Slang
if (auto arrType = as<ArrayExpressionType>(expr->arrayExpr->type))
{
expr->type = m_astBuilder->getIntType();
- if (!arrType->arrayLength)
+ if (arrType->isUnsized())
{
getSink()->diagnose(expr, Diagnostics::invalidArraySize);
}
@@ -2823,7 +2839,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
+ m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount)));
}
// A swizzle can be used as an l-value as long as there
@@ -2948,7 +2964,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
+ m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount)));
}
// A swizzle can be used as an l-value as long as there
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 1b59094e2..ac5fc8392 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2038,4 +2038,6 @@ namespace Slang
void checkModule(ModuleDecl* programNode);
};
+
+ bool isUnsizedArrayType(Type* type);
}
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 3a64f3c8f..46e39e4c0 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -62,7 +62,7 @@ namespace Slang
//
while( auto arrayType = as<ArrayExpressionType>(type) )
{
- type = arrayType->baseType;
+ type = arrayType->getElementType();
}
if( auto parameterGroupType = as<ParameterGroupType>(type) )
@@ -1125,7 +1125,7 @@ namespace Slang
if(!intVal)
{
sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl);
- intVal = getLinkage()->getASTBuilder()->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), 0);
+ intVal = getLinkage()->getASTBuilder()->getIntVal(m_astBuilder->getIntType(), 0);
}
ModuleSpecializationInfo::GenericArgInfo expandedArg;
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 0f450340f..8049c1230 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -206,8 +206,7 @@ namespace Slang
}
else
{
- ConstantIntVal* rangeBeginConst = m_astBuilder->getOrCreate<ConstantIntVal>();
- rangeBeginConst->value = 0;
+ ConstantIntVal* rangeBeginConst = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 0);
rangeBeginVal = rangeBeginConst;
}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 05a5f8f56..91374e006 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -194,27 +194,19 @@ IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* b
return table;
}
-IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
+IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType)
{
- return builder->getDifferentialPairType(
- (IRType*)primalType,
- witness);
-}
-
-IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType)
-{
- auto primalType = lookupPrimalInst(builder, originalType, nullptr);
- SLANG_RELEASE_ASSERT(primalType);
-
- IRInst* witness =
+ IRInst* witness =
differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType);
if (witness)
{
witness = lookupPrimalInst(builder, witness, nullptr);
- SLANG_RELEASE_ASSERT(witness);
+ SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(originalType));
}
if (!witness)
{
+ auto primalType = lookupPrimalInst(builder, originalType, nullptr);
+ SLANG_RELEASE_ASSERT(primalType);
if (auto primalPairType = as<IRDifferentialPairType>(primalType))
{
witness = getDifferentialPairWitness(builder, originalType, primalPairType);
@@ -224,6 +216,23 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI
differentiateExtractExistentialType(builder, extractExistential, witness);
}
}
+ return witness;
+}
+
+IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
+{
+ return builder->getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+}
+
+IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType)
+{
+ auto primalType = lookupPrimalInst(builder, originalType, nullptr);
+ SLANG_RELEASE_ASSERT(primalType);
+
+ IRInst* witness = tryGetDifferentiableWitness(builder, originalType);
+ SLANG_RELEASE_ASSERT(witness);
return builder->getDifferentialPairType(
(IRType*)primalType,
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index a870dc815..e6a525dee 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -101,6 +101,8 @@ struct AutoDiffTranscriberBase
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
IRWitnessTable* getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType);
+ IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType);
+
IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness);
IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index ae8359251..8952f9756 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -534,31 +534,7 @@ void stripNoDiffTypeAttribute(IRModule* module)
bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
{
- HashSet<IRInst*> processedSet;
- for (;typeInst;)
- {
- if (as<IRArrayTypeBase>(typeInst) || as<IRPtrTypeBase>(typeInst))
- {
- typeInst = typeInst->getOperand(0);
- if (!processedSet.Add(typeInst))
- return false;
- }
- else
- {
- break;
- }
- }
- if (!typeInst)
- return false;
- switch (typeInst->getOp())
- {
- case kIROp_FloatType:
- case kIROp_DifferentialPairType:
- return true;
- default:
- break;
- }
- if (context.lookUpConformanceForType(typeInst))
+ if (context.isDifferentiableType((IRType*)typeInst))
return true;
// Look for equivalent types.
for (auto type : context.differentiableWitnessDictionary)
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 6da4ea6a6..2258ff753 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -190,6 +190,30 @@ struct DifferentiableTypeConformanceContext
}
}
+ bool isDifferentiableType(IRType* origType)
+ {
+ for (; origType;)
+ {
+ switch (origType->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ return true;
+ case kIROp_VectorType:
+ case kIROp_ArrayType:
+ case kIROp_PtrType:
+ case kIROp_OutType:
+ case kIROp_InOutType:
+ origType = (IRType*)origType->getOperand(0);
+ continue;
+ default:
+ return lookUpConformanceForType(origType) != nullptr;
+ }
+ }
+ return false;
+ }
+
IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
{
auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 605ac62db..149f5f6b9 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1443,11 +1443,6 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(diff);
}
- LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*)
- {
- return LoweredValInfo();
- }
-
LoweredValInfo visitTaggedUnionSubtypeWitness(
TaggedUnionSubtypeWitness* val)
{
@@ -1861,10 +1856,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
IRType* visitArrayExpressionType(ArrayExpressionType* type)
{
- auto elementType = lowerType(context, type->baseType);
- if (type->arrayLength)
+ auto elementType = lowerType(context, type->getElementType());
+ if (!type->isUnsized())
{
- auto elementCount = lowerSimpleVal(context, type->arrayLength);
+ auto elementCount = lowerSimpleVal(context, type->getElementCount());
return getBuilder()->getArrayType(
elementType,
elementCount);
@@ -3390,18 +3385,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
else if (auto arrayType = as<ArrayExpressionType>(type))
{
- UInt elementCount = (UInt) getIntVal(arrayType->arrayLength);
-
- auto irDefaultElement = getSimpleVal(context, getDefaultVal(arrayType->baseType));
-
- List<IRInst*> args;
- for(UInt ee = 0; ee < elementCount; ++ee)
- {
- args.add(irDefaultElement);
- }
+ auto irDefaultElement = getSimpleVal(context, getDefaultVal(arrayType->getElementType()));
return LoweredValInfo::simple(
- getBuilder()->emitMakeArray(irType, args.getCount(), args.getBuffer()));
+ getBuilder()->emitMakeArrayFromElement(irType, irDefaultElement));
}
else if (auto declRefType = as<DeclRefType>(type))
{
@@ -3470,7 +3457,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// fill in the appropriate field of the result
if (auto arrayType = as<ArrayExpressionType>(type))
{
- UInt elementCount = (UInt) getIntVal(arrayType->arrayLength);
+ UInt elementCount = (UInt) getIntVal(arrayType->getElementCount());
for (UInt ee = 0; ee < argCount; ++ee)
{
@@ -3480,7 +3467,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
if(elementCount > argCount)
{
- auto irDefaultValue = getSimpleVal(context, getDefaultVal(arrayType->baseType));
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(arrayType->getElementType()));
for(UInt ee = argCount; ee < elementCount; ++ee)
{
args.add(irDefaultValue);
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index 58d6aaae3..da5099934 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -217,8 +217,8 @@ namespace Slang
else if (auto arrType = dynamicCast<ArrayExpressionType>(type))
{
emitRaw(context, "a");
- emitSimpleIntVal(context, arrType->arrayLength);
- emitType(context, arrType->baseType);
+ emitSimpleIntVal(context, arrType->getElementCount());
+ emitType(context, arrType->getElementType());
}
else if( auto taggedUnionType = dynamicCast<TaggedUnionType>(type) )
{
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index db323ff6e..f3eabf613 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -1889,16 +1889,18 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
// Note: Bad Things will happen if we have an array input
// without a semantic already being enforced.
- auto elementCount = (UInt) getIntVal(arrayType->arrayLength);
+ auto elementCount = (UInt) getIntVal(arrayType->getElementCount());
+ if (arrayType->isUnsized())
+ elementCount = 0;
// We use the first element to derive the layout for the element type
- auto elementTypeLayout = processEntryPointVaryingParameter(context, arrayType->baseType, state, varLayout);
+ auto elementTypeLayout = processEntryPointVaryingParameter(context, arrayType->getElementType(), state, varLayout);
// We still walk over subsequent elements to make sure they consume resources
// as needed
for( UInt ii = 1; ii < elementCount; ++ii )
{
- processEntryPointVaryingParameter(context, arrayType->baseType, state, nullptr);
+ processEntryPointVaryingParameter(context, arrayType->getElementType(), state, nullptr);
}
RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout();
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 5c5773fec..9c1d48a28 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -149,7 +149,7 @@ static SlangParameterCategory maybeRemapParameterCategory(
// of this variable?
Type* type = typeLayout->getType();
while (auto arrayType = as<ArrayExpressionType>(type))
- type = arrayType->baseType;
+ type = arrayType->getElementType();
switch (spReflectionType_GetKind(convert(type)))
{
case SLANG_TYPE_KIND_CONSTANT_BUFFER:
@@ -462,7 +462,7 @@ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType)
if(auto arrayType = as<ArrayExpressionType>(type))
{
- return arrayType->arrayLength ? (size_t) getIntVal(arrayType->arrayLength) : 0;
+ return !arrayType->isUnsized() ? (size_t)getIntVal(arrayType->getElementCount()) : 0;
}
else if( auto vectorType = as<VectorExpressionType>(type))
{
@@ -479,7 +479,7 @@ SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionTy
if(auto arrayType = as<ArrayExpressionType>(type))
{
- return (SlangReflectionType*) arrayType->baseType;
+ return (SlangReflectionType*) arrayType->getElementType();
}
else if( auto parameterGroupType = as<ParameterGroupType>(type))
{
@@ -631,7 +631,7 @@ SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionTy
while(auto arrayType = as<ArrayExpressionType>(type))
{
- type = arrayType->baseType;
+ type = arrayType->getElementType();
}
if(auto textureType = as<TextureTypeBase>(type))
@@ -667,7 +667,7 @@ SLANG_API SlangResourceAccess spReflectionType_GetResourceAccess(SlangReflection
while(auto arrayType = as<ArrayExpressionType>(type))
{
- type = arrayType->baseType;
+ type = arrayType->getElementType();
}
if(auto textureType = as<TextureTypeBase>(type))
@@ -763,7 +763,7 @@ SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangRefle
while(auto arrayType = as<ArrayExpressionType>(type))
{
- type = arrayType->baseType;
+ type = arrayType->getElementType();
}
if (auto textureType = as<TextureTypeBase>(type))
@@ -1492,9 +1492,9 @@ namespace Slang
LayoutSize elementCount = LayoutSize::infinite();
if( auto arrayType = as<ArrayExpressionType>(arrayTypeLayout->type) )
{
- if( auto elementCountVal = arrayType->arrayLength )
+ if( !arrayType->isUnsized())
{
- elementCount = LayoutSize::RawValue(getIntVal(elementCountVal));
+ elementCount = LayoutSize::RawValue(getIntVal(arrayType->getElementCount()));
}
}
addRangesRec(elementTypeLayout, path, multiplier * elementCount);
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 4e5db17c0..6076a41ca 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -513,6 +513,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
vecType->elementCount = ExtractGenericArgInteger(subst->getArgs()[1]);
return vecType;
}
+ else if (magicMod->magicName == "ArrayType")
+ {
+ SLANG_ASSERT(subst && subst->getArgs().getCount() == 2);
+ auto vecType = astBuilder->getOrCreate<ArrayExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1]));
+ vecType->declRef = declRef;
+ return vecType;
+ }
else if (magicMod->magicName == "Matrix")
{
SLANG_ASSERT(subst && subst->getArgs().getCount() == 3);
@@ -1097,19 +1104,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
Type* elementType,
IntVal* elementCount)
{
- auto arrayType = astBuilder->create<ArrayExpressionType>();
- arrayType->baseType = elementType;
- arrayType->arrayLength = elementCount;
- return arrayType;
+ return astBuilder->getArrayType(elementType, elementCount);
}
ArrayExpressionType* getArrayType(
ASTBuilder* astBuilder,
Type* elementType)
{
- auto arrayType = astBuilder->create<ArrayExpressionType>();
- arrayType->baseType = elementType;
- return arrayType;
+ return astBuilder->getArrayType(elementType, nullptr);
}
NamedExpressionType* getNamedType(
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index e36ee944c..dd119cf3a 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -326,6 +326,8 @@ namespace Slang
All = 7
};
+ const int kUnsizedArrayMagicLength = 0x7FFFFFFF;
+
/// Get the module dclaration that a declaration is associated with, if any.
ModuleDecl* getModuleDecl(Decl* decl);
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 45bacaa44..3da0cc95a 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -1467,6 +1467,8 @@ static LayoutSize GetElementCount(IntVal* val)
if (auto constantVal = as<ConstantIntVal>(val))
{
+ if (constantVal->value == kUnsizedArrayMagicLength)
+ return LayoutSize::infinite();
return LayoutSize(LayoutSize::RawValue(constantVal->value));
}
else if( auto varRefVal = as<GenericParamIntVal>(val) )
@@ -3669,7 +3671,7 @@ static TypeLayoutResult _createTypeLayout(
}
else if (auto arrayType = as<ArrayExpressionType>(type))
{
- return createArrayLikeTypeLayout(context, arrayType, arrayType->baseType, arrayType->arrayLength);
+ return createArrayLikeTypeLayout(context, arrayType, arrayType->getElementType(), arrayType->getElementCount());
}
else if (auto declRefType = as<DeclRefType>(type))
{