summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-builder.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-10-29 14:49:26 +0800
committerGitHub <noreply@github.com>2024-10-29 14:49:26 +0800
commitf65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
treeea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ast-builder.cpp
parenta729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format
* format * Minor test fixes * enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
-rw-r--r--source/slang/slang-ast-builder.cpp202
1 files changed, 113 insertions, 89 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 7edbe750a..575c7268b 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -1,16 +1,16 @@
// slang-ast-builder.cpp
#include "slang-ast-builder.h"
-#include <assert.h>
#include "slang-compiler.h"
-namespace Slang {
+#include <assert.h>
+
+namespace Slang
+{
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SharedASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-SharedASTBuilder::SharedASTBuilder()
-{
-}
+SharedASTBuilder::SharedASTBuilder() {}
void SharedASTBuilder::init(Session* session)
{
@@ -90,7 +90,8 @@ Type* SharedASTBuilder::getNativeStringType()
if (!m_nativeStringType)
{
auto nativeStringTypeDecl = findMagicDecl("NativeStringType");
- m_nativeStringType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(nativeStringTypeDecl));
+ m_nativeStringType =
+ DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(nativeStringTypeDecl));
}
return m_nativeStringType;
}
@@ -190,7 +191,9 @@ void SharedASTBuilder::registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modi
m_builtinTypes[Index(modifier->tag)] = type;
}
-void SharedASTBuilder::registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier)
+void SharedASTBuilder::registerBuiltinRequirementDecl(
+ Decl* decl,
+ BuiltinRequirementModifier* modifier)
{
m_builtinRequirementDecls[modifier->kind] = decl;
}
@@ -221,11 +224,11 @@ Decl* SharedASTBuilder::tryFindMagicDecl(const String& name)
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name):
- m_sharedASTBuilder(sharedASTBuilder),
- m_name(name),
- m_id(sharedASTBuilder->m_id++),
- m_arena(2097152)
+ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name)
+ : m_sharedASTBuilder(sharedASTBuilder)
+ , m_name(name)
+ , m_id(sharedASTBuilder->m_id++)
+ , m_arena(2097152)
{
SLANG_ASSERT(sharedASTBuilder);
// Copy Val deduplication map over so we don't create duplicate Vals that are already
@@ -233,10 +236,8 @@ ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name):
m_cachedNodes = sharedASTBuilder->getInnerASTBuilder()->m_cachedNodes;
}
-ASTBuilder::ASTBuilder():
- m_sharedASTBuilder(nullptr),
- m_id(-1),
- m_arena(2097152)
+ASTBuilder::ASTBuilder()
+ : m_sharedASTBuilder(nullptr), m_id(-1), m_arena(2097152)
{
m_name = "SharedASTBuilder::m_astBuilder";
}
@@ -265,7 +266,7 @@ void ASTBuilder::incrementEpoch()
NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType)
{
const ReflectClassInfo* info = ASTClassInfo::getInfo(nodeType);
-
+
auto createFunc = info->m_createFunc;
SLANG_ASSERT(createFunc);
if (!createFunc)
@@ -327,9 +328,12 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
return as<PtrTypeBase>(getSpecializedBuiltinType(valueType, ptrTypeName));
}
-PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName)
+PtrTypeBase* ASTBuilder::getPtrType(
+ Type* valueType,
+ AddressSpace addrSpace,
+ char const* ptrTypeName)
{
- Val* args[] = { valueType, getIntVal(getUInt64Type(), (IntegerLiteralValue)addrSpace) };
+ Val* args[] = {valueType, getIntVal(getUInt64Type(), (IntegerLiteralValue)addrSpace)};
return as<PtrTypeBase>(getSpecializedBuiltinType(makeArrayView(args), ptrTypeName));
}
@@ -350,7 +354,8 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element
}
}
Val* args[] = {elementType, elementCount};
- return as<ArrayExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType"));
+ return as<ArrayExpressionType>(
+ getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType"));
}
ConstantBufferType* ASTBuilder::getConstantBufferType(Type* elementType)
@@ -365,12 +370,14 @@ ParameterBlockType* ASTBuilder::getParameterBlockType(Type* elementType)
HLSLStructuredBufferType* ASTBuilder::getStructuredBufferType(Type* elementType)
{
- return as<HLSLStructuredBufferType>(getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType"));
+ return as<HLSLStructuredBufferType>(
+ getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType"));
}
HLSLRWStructuredBufferType* ASTBuilder::getRWStructuredBufferType(Type* elementType)
{
- return as<HLSLRWStructuredBufferType>(getSpecializedBuiltinType(elementType, "HLSLRWStructuredBufferType"));
+ return as<HLSLRWStructuredBufferType>(
+ getSpecializedBuiltinType(elementType, "HLSLRWStructuredBufferType"));
}
SamplerStateType* ASTBuilder::getSamplerStateType()
@@ -378,20 +385,23 @@ SamplerStateType* ASTBuilder::getSamplerStateType()
return as<SamplerStateType>(getSpecializedBuiltinType(nullptr, "HLSLStructuredBufferType"));
}
-VectorExpressionType* ASTBuilder::getVectorType(
- Type* elementType,
- IntVal* elementCount)
+VectorExpressionType* ASTBuilder::getVectorType(Type* elementType, IntVal* elementCount)
{
// Canonicalize constant elementCount to int.
if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount))
{
elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue());
}
- Val* args[] = { elementType, elementCount };
- return as<VectorExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType"));
+ Val* args[] = {elementType, elementCount};
+ return as<VectorExpressionType>(
+ getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType"));
}
-MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout)
+MatrixExpressionType* ASTBuilder::getMatrixType(
+ Type* elementType,
+ IntVal* rowCount,
+ IntVal* colCount,
+ IntVal* layout)
{
// Canonicalize constant size arguments to int.
if (auto rowCountConstantInt = as<ConstantIntVal>(rowCount))
@@ -402,35 +412,38 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo
{
colCount = getIntVal(getIntType(), colCountConstantInt->getValue());
}
- Val* args[] = { elementType, rowCount, colCount, layout };
- return as<MatrixExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "MatrixExpressionType"));
+ Val* args[] = {elementType, rowCount, colCount, layout};
+ return as<MatrixExpressionType>(
+ getSpecializedBuiltinType(makeArrayView(args), "MatrixExpressionType"));
}
-DifferentialPairType* ASTBuilder::getDifferentialPairType(
- Type* valueType,
- Witness* diffTypeWitness)
+DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* diffTypeWitness)
{
- Val* args[] = { valueType, diffTypeWitness };
- return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType"));
+ Val* args[] = {valueType, diffTypeWitness};
+ return as<DifferentialPairType>(
+ getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType"));
}
DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType(
Type* valueType,
Witness* diffRefTypeWitness)
{
- Val* args[] = { valueType, diffRefTypeWitness };
- return as<DifferentialPtrPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType"));
+ Val* args[] = {valueType, diffRefTypeWitness};
+ return as<DifferentialPtrPairType>(
+ getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType"));
}
DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl()
{
- DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr));
+ DeclRef<InterfaceDecl> declRef =
+ DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr));
return declRef;
}
DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableRefInterfaceDecl()
{
- DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiablePtrType", nullptr));
+ DeclRef<InterfaceDecl> declRef =
+ DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiablePtrType", nullptr));
return declRef;
}
@@ -441,12 +454,15 @@ bool ASTBuilder::isDifferentiableInterfaceAvailable()
DeclRef<InterfaceDecl> ASTBuilder::getDefaultInitializableTypeInterfaceDecl()
{
- DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DefaultInitializableType", nullptr));
+ DeclRef<InterfaceDecl> declRef =
+ DeclRef<InterfaceDecl>(getBuiltinDeclRef("DefaultInitializableType", nullptr));
return declRef;
}
Type* ASTBuilder::getDefaultInitializableType()
{
- return DeclRefType::create(m_sharedASTBuilder->m_astBuilder, getDefaultInitializableTypeInterfaceDecl());
+ return DeclRefType::create(
+ m_sharedASTBuilder->m_astBuilder,
+ getDefaultInitializableTypeInterfaceDecl());
}
MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier(
@@ -458,13 +474,13 @@ MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier(
SLANG_ASSERT(elementType);
SLANG_ASSERT(maxElementCount);
- const char* declName
- = as<HLSLVerticesModifier>(modifier) ? "VerticesType"
- : as<HLSLIndicesModifier>(modifier) ? "IndicesType"
- : as<HLSLPrimitivesModifier>(modifier) ? "PrimitivesType"
- : (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr);
+ const char* declName = as<HLSLVerticesModifier>(modifier) ? "VerticesType"
+ : as<HLSLIndicesModifier>(modifier) ? "IndicesType"
+ : as<HLSLPrimitivesModifier>(modifier)
+ ? "PrimitivesType"
+ : (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr);
- Val* args[] = { elementType, maxElementCount };
+ Val* args[] = {elementType, maxElementCount};
return as<MeshOutputType>(getSpecializedBuiltinType(makeArrayView(args), declName));
}
@@ -483,7 +499,8 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va
auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
if (auto genericDecl = as<GenericDecl>(decl))
{
- auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg));
+ auto declRef =
+ getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg));
return declRef;
}
else
@@ -493,7 +510,9 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va
return makeDeclRef(decl);
}
-DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView<Val*> genericArgs)
+DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(
+ const char* builtinMagicTypeName,
+ ArrayView<Val*> genericArgs)
{
auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
if (auto genericDecl = as<GenericDecl>(decl))
@@ -544,8 +563,9 @@ FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Typ
TupleType* ASTBuilder::getTupleType(ArrayView<Type*> types)
{
- // The canonical form of a tuple type is always a DeclRefType(GenAppDeclRef(TupleDecl, ConcreteTypePack(types...))).
- // If `types` is already a single ConcreteTypePack, then we can use that directly.
+ // The canonical form of a tuple type is always a DeclRefType(GenAppDeclRef(TupleDecl,
+ // ConcreteTypePack(types...))). If `types` is already a single ConcreteTypePack, then we can
+ // use that directly.
if (types.getCount() == 1)
{
if (isTypePack(types[0]))
@@ -572,7 +592,8 @@ Type* ASTBuilder::getEachType(Type* baseType)
return expandType->getPatternType();
}
- // each Tuple<X> ==> each X, because we know that Tuple type must be in the form of Tuple<ConcreteTypePack<...>>.
+ // each Tuple<X> ==> each X, because we know that Tuple type must be in the form of
+ // Tuple<ConcreteTypePack<...>>.
if (auto tupleType = as<TupleType>(baseType))
{
return getEachType(tupleType->getTypePack());
@@ -613,20 +634,23 @@ ConcreteTypePack* ASTBuilder::getTypePack(ArrayView<Type*> types)
return getOrCreate<ConcreteTypePack>(flattenedTypes.getArrayView().arrayView);
}
-TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness(
- Type* type)
+TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness(Type* type)
{
return getOrCreate<TypeEqualityWitness>(type, type);
}
TypePackSubtypeWitness* ASTBuilder::getSubtypeWitnessPack(
- Type* subType, Type* superType, ArrayView<SubtypeWitness*> witnesses)
+ Type* subType,
+ Type* superType,
+ ArrayView<SubtypeWitness*> witnesses)
{
return getOrCreate<TypePackSubtypeWitness>(subType, superType, witnesses);
}
SubtypeWitness* ASTBuilder::getExpandSubtypeWitness(
- Type* subType, Type* superType, SubtypeWitness* patternWitness)
+ Type* subType,
+ Type* superType,
+ SubtypeWitness* patternWitness)
{
if (auto eachWitness = as<EachSubtypeWitness>(patternWitness))
return eachWitness->getPatternTypeWitness();
@@ -634,7 +658,9 @@ SubtypeWitness* ASTBuilder::getExpandSubtypeWitness(
}
SubtypeWitness* ASTBuilder::getEachSubtypeWitness(
- Type* subType, Type* superType, SubtypeWitness* patternWitness)
+ Type* subType,
+ Type* superType,
+ SubtypeWitness* patternWitness)
{
if (auto expandWitness = as<ExpandSubtypeWitness>(patternWitness))
return expandWitness->getPatternTypeWitness();
@@ -642,12 +668,11 @@ SubtypeWitness* ASTBuilder::getEachSubtypeWitness(
}
DeclaredSubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness(
- Type* subType,
- Type* superType,
- DeclRef<Decl> const& declRef)
+ Type* subType,
+ Type* superType,
+ DeclRef<Decl> const& declRef)
{
- auto witness = getOrCreate<DeclaredSubtypeWitness>(
- subType, superType, declRef.declRefBase);
+ auto witness = getOrCreate<DeclaredSubtypeWitness>(subType, superType, declRef.declRefBase);
return witness;
}
@@ -666,7 +691,7 @@ top:
//
// If `a == b`, then the `b <: c` witness is also a witness of `a <: c`.
//
- if(as<TypeEqualityWitness>(aIsSubtypeOfBWitness))
+ if (as<TypeEqualityWitness>(aIsSubtypeOfBWitness))
{
return bIsSubtypeOfCWitness;
}
@@ -694,9 +719,8 @@ top:
// We can recursively call this operation to produce a witness that
// `a <: x`, based on the witnesses we already have for `a <: b` and `b <: x`:
//
- auto aIsSubtypeOfXWitness = getTransitiveSubtypeWitness(
- aIsSubtypeOfBWitness,
- bIsSubtypeOfXWitness);
+ auto aIsSubtypeOfXWitness =
+ getTransitiveSubtypeWitness(aIsSubtypeOfBWitness, bIsSubtypeOfXWitness);
// Now we can perform a "tail recursive" call to this function (via `goto`
// to combine the `a <: x` witness with our `x <: c` witness:
@@ -714,7 +738,7 @@ top:
// and we'd rather form a conjunction witness for `A <: C`
// that is of the form `(A <: X)&(A <: Y)`.
//
- if(auto bIsSubtypeOfXAndY = as<ConjunctionSubtypeWitness>(bIsSubtypeOfCWitness))
+ if (auto bIsSubtypeOfXAndY = as<ConjunctionSubtypeWitness>(bIsSubtypeOfCWitness))
{
auto bIsSubtypeOfXWitness = bIsSubtypeOfXAndY->getLeftWitness();
auto bIsSubtypeOfYWitness = bIsSubtypeOfXAndY->getRightWitness();
@@ -730,7 +754,8 @@ top:
// `W` is a witness that `B <: X&Y&...` for some conjunction, where `C`
// is one component of that conjunction.
//
- if(auto bIsSubtypeViaExtraction = as<ExtractFromConjunctionSubtypeWitness>(bIsSubtypeOfCWitness))
+ if (auto bIsSubtypeViaExtraction =
+ as<ExtractFromConjunctionSubtypeWitness>(bIsSubtypeOfCWitness))
{
// We decompose the witness `extract(i, W)` to get both
// the witness `W` that `B <: X&Y&...` as well as the index
@@ -761,12 +786,10 @@ top:
List<SubtypeWitness*> newWitnesses;
for (Index i = 0; i < witnessPack->getCount(); i++)
{
- newWitnesses.add(getTransitiveSubtypeWitness(witnessPack->getWitness(i), bIsSubtypeOfCWitness));
+ newWitnesses.add(
+ getTransitiveSubtypeWitness(witnessPack->getWitness(i), bIsSubtypeOfCWitness));
}
- return getSubtypeWitnessPack(
- aType,
- cType,
- newWitnesses.getArrayView());
+ return getSubtypeWitnessPack(aType, cType, newWitnesses.getArrayView());
}
// If left hand is a ExpandSubtypeWitness, then we want to perform the transitive lookup
@@ -774,7 +797,9 @@ top:
//
if (auto expandWitness = as<ExpandSubtypeWitness>(aIsSubtypeOfBWitness))
{
- auto innerTransitiveWitness = getTransitiveSubtypeWitness(expandWitness->getPatternTypeWitness(), bIsSubtypeOfCWitness);
+ auto innerTransitiveWitness = getTransitiveSubtypeWitness(
+ expandWitness->getPatternTypeWitness(),
+ bIsSubtypeOfCWitness);
return getExpandSubtypeWitness(expandWitness->getSub(), cType, innerTransitiveWitness);
}
@@ -787,8 +812,12 @@ top:
{
if (declRefType->getDeclRef().as<GenericTypePackParamDecl>())
{
- auto newLeftHandWitness = getEachSubtypeWitness(getEachType(declaredWitness->getSub()), declaredWitness->getSup(), declaredWitness);
- auto transitiveWitness = getTransitiveSubtypeWitness(newLeftHandWitness, bIsSubtypeOfCWitness);
+ auto newLeftHandWitness = getEachSubtypeWitness(
+ getEachType(declaredWitness->getSub()),
+ declaredWitness->getSup(),
+ declaredWitness);
+ auto transitiveWitness =
+ getTransitiveSubtypeWitness(newLeftHandWitness, bIsSubtypeOfCWitness);
return getExpandSubtypeWitness(aType, cType, transitiveWitness);
}
}
@@ -813,10 +842,10 @@ top:
}
SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness(
- Type* subType,
- Type* superType,
+ Type* subType,
+ Type* superType,
SubtypeWitness* conjunctionWitness,
- int indexOfSuperTypeInConjunction)
+ int indexOfSuperTypeInConjunction)
{
// We are taking a witness `W` for `S <: L&R` and
// using it to produce a witness for `S <: L`
@@ -845,8 +874,8 @@ SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness(
}
SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
- Type* sub,
- Type* lAndR,
+ Type* sub,
+ Type* lAndR,
SubtypeWitness* subIsLWitness,
SubtypeWitness* subIsRWitness)
{
@@ -858,10 +887,9 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
//
auto lExtract = as<ExtractFromConjunctionSubtypeWitness>(subIsLWitness);
auto rExtract = as<ExtractFromConjunctionSubtypeWitness>(subIsRWitness);
- if(lExtract && rExtract)
+ if (lExtract && rExtract)
{
- if (lExtract->getIndexInConjunction() == 0
- && rExtract->getIndexInConjunction() == 1)
+ if (lExtract->getIndexInConjunction() == 0 && rExtract->getIndexInConjunction() == 1)
{
auto lInner = lExtract->getConjunctionWitness();
auto rInner = rExtract->getConjunctionWitness();
@@ -883,11 +911,7 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
// witness) deeper, so that we have more chances to expose a
// conjunction witness at higher levels.
- auto witness = getOrCreate<ConjunctionSubtypeWitness>(
- sub,
- lAndR,
- subIsLWitness,
- subIsRWitness);
+ auto witness = getOrCreate<ConjunctionSubtypeWitness>(sub, lAndR, subIsLWitness, subIsRWitness);
return witness;
}