// slang-ast-type.cpp #include "slang-ast-type.h" #include "slang-ast-builder.h" #include "slang-ast-dispatch.h" #include "slang-ast-modifier.h" #include "slang-syntax.h" #include #include namespace Slang { bool isAbstractTypePack(Type* type) { if (as(type)) return true; if (isDeclRefTypeOf(type)) return true; return false; } bool isTypePack(Type* type) { if (as(type)) return true; return isAbstractTypePack(type); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type* Type::_createCanonicalTypeOverride() { return as(defaultResolveImpl()); } Val* Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; auto canonicalType = getCanonicalType(); // If canonicalType is identical to this, then we shouldn't try to call // canonicalType->substituteImpl because that would lead to infinite recursion. if (canonicalType == this) return this; auto canSubst = canonicalType->substituteImpl(astBuilder, subst, &diff); // If nothing changed, then don't drop any sugar that is applied if (!diff) return this; // If the canonical type changed, then we return a canonical type, // rather than try to re-construct any amount of sugar (*ioDiff)++; return canSubst; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! OverloadGroupType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void OverloadGroupType::_toTextOverride(StringBuilder& out) { out << toSlice("overload group"); } Type* OverloadGroupType::_createCanonicalTypeOverride() { return this; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! InitializerListType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void InitializerListType::_toTextOverride(StringBuilder& out) { out << toSlice("initializer list"); } Type* InitializerListType::_createCanonicalTypeOverride() { return this; } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
// Printed as the fixed text "error".
void ErrorType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("error");
}

// The error type is its own canonical form.
Type* ErrorType::_createCanonicalTypeOverride()
{
    return this;
}

// Substitution never changes the error type; `*ioDiff` is left untouched.
Val* ErrorType::_substituteImplOverride(
    ASTBuilder* /* astBuilder */,
    SubstitutionSet /*subst*/,
    int* /*ioDiff*/
)
{
    return this;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// The bottom type is printed as "never".
void BottomType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("never");
}

// Substitution never changes the bottom type; `*ioDiff` is left untouched.
Val* BottomType::_substituteImplOverride(
    ASTBuilder* /* astBuilder */,
    SubstitutionSet /*subst*/,
    int* /*ioDiff*/
)
{
    return this;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Printed as the underlying decl-ref.
void DeclRefType::_toTextOverride(StringBuilder& out)
{
    out << getDeclRef();
}

// Forward declaration; the definition is not part of this section of the file.
Val* maybeSubstituteGenericParam(
    Val* paramVal,
    Decl* paramDecl,
    SubstitutionSet subst,
    int* ioDiff);

// Applies `subst` to the type's decl-ref.
// - Returns `this` when `subst` is empty or nothing changed.
// - For a reference to ThisType or a generic parameter, returns the argument
//   `subst` supplies for it (when it supplies one).
// - Otherwise rebuilds a DeclRefType from the substituted decl-ref.
// `*ioDiff` is increased whenever the result differs from `this`.
Val* DeclRefType::_substituteImplOverride(
    ASTBuilder* astBuilder,
    SubstitutionSet subst,
    int* ioDiff)
{
    if (!subst)
        return this;

    int diff = 0;
    DeclRef substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);

    // If this declref type is a direct reference to ThisType or a Generic parameter,
    // and `subst` provides an argument for it, then we should just return that argument.
    //
    if (as(substDeclRef.declRefBase) || as(substDeclRef.declRefBase))
    {
        if (as(substDeclRef.getDecl()))
        {
            // Only substitute when the lookup decl-ref is for the decl that
            // directly contains the referenced decl.
            auto lookupDeclRef = subst.findLookupDeclRef();
            if (lookupDeclRef && lookupDeclRef->getSupDecl() == substDeclRef.getDecl()->parentDecl)
            {
                (*ioDiff)++;
                return lookupDeclRef->getLookupSource();
            }
        }
        else if (
            as(substDeclRef.getDecl()) || as(substDeclRef.getDecl()))
        {
            auto resultVal =
                maybeSubstituteGenericParam(nullptr, substDeclRef.getDecl(), subst, ioDiff);
            if (resultVal)
            {
                (*ioDiff)++;
                return resultVal;
            }
        }
    }

    // If this type is a reference to an associated type declaration,
    // and the substitutions provide a "this type" substitution for
    // the outer interface, then try to replace the type with the
    // actual value of the associated type for the given implementation.
    //
    // NOTE(review): the resolved value is only used for the comparison; the
    // returned type is rebuilt from `substDeclRef`, not from `satisfyingVal` —
    // confirm this is intended.
    if (auto satisfyingVal = substDeclRef.declRefBase->resolve())
    {
        if (satisfyingVal != getDeclRef())
        {
            *ioDiff += 1;
            return DeclRefType::create(astBuilder, substDeclRef);
        }
    }

    if (!diff)
        return this;

    // Make sure to record the difference!
    *ioDiff += diff;

    // Re-construct the type in case we are using a specialized sub-class
    return DeclRefType::create(astBuilder, substDeclRef);
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArithmeticExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Dispatches to the subclass's `_getScalarTypeOverride` via the AST virtual-call macro.
BasicExpressionType* ArithmeticExpressionType::getScalarType()
{
    SLANG_AST_NODE_VIRTUAL_CALL(ArithmeticExpressionType, getScalarType, ())
}

// Base implementation: every concrete arithmetic type is expected to override
// this, so reaching it is reported as unexpected.
BasicExpressionType* ArithmeticExpressionType::_getScalarTypeOverride()
{
    SLANG_UNEXPECTED("ArithmeticExpressionType::_getScalarTypeOverride not overridden");
    // return nullptr;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BasicExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// A basic (scalar) type is its own scalar type.
BasicExpressionType* BasicExpressionType::_getScalarTypeOverride()
{
    return this;
}

// Returns argument `i` of the innermost generic argument list found in the
// substitutions of `declRef`, or nullptr when that list has `i` or fewer entries.
static Val* _getGenericTypeArg(DeclRefBase* declRef, Index i)
{
    auto args = findInnerMostGenericArgs(SubstitutionSet(declRef));
    if (args.getCount() <= i)
        return nullptr;
    return args[i];
}

// Convenience overload: reads generic argument `i` from a DeclRefType's decl-ref.
static Val* _getGenericTypeArg(DeclRefType* declRefType, Index i)
{
    return _getGenericTypeArg(declRefType->getDeclRefBase(), i);
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Element type = generic argument 0.
Type* TensorViewType::getElementType()
{
    return as(_getGenericTypeArg(this, 0));
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Element type = generic argument 0.
Type* VectorExpressionType::getElementType()
{
    return as(_getGenericTypeArg(this, 0));
}

// Element count = generic argument 1.
IntVal* VectorExpressionType::getElementCount()
{
    return as(_getGenericTypeArg(this, 1));
}

// Printed as `vector<Element,Count>`.
void VectorExpressionType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("vector<") << getElementType() << toSlice(",") << getElementCount()
        << toSlice(">");
}

// The scalar type of a vector is its element type (cast result may be null
// if the element is not a basic type).
BasicExpressionType* VectorExpressionType::_getScalarTypeOverride()
{
    return as(getElementType());
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MatrixExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Printed as `matrix<Element,Rows,Columns>` (the layout argument is not printed).
void MatrixExpressionType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("matrix<") << getElementType() << toSlice(",") << getRowCount()
        << toSlice(",") << getColumnCount() << toSlice(">");
}

// The scalar type of a matrix is its element type.
BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride()
{
    return as(getElementType());
}

// Element type = generic argument 0.
Type* MatrixExpressionType::getElementType()
{
    return as(_getGenericTypeArg(this, 0));
}

// Row count = generic argument 1.
IntVal* MatrixExpressionType::getRowCount()
{
    return as(_getGenericTypeArg(this, 1));
}

// Column count = generic argument 2.
IntVal* MatrixExpressionType::getColumnCount()
{
    return as(_getGenericTypeArg(this, 2));
}

// Layout = generic argument 3.
IntVal* MatrixExpressionType::getLayout()
{
    return as(_getGenericTypeArg(this, 3));
}

// Lazily builds (and caches in `rowType`) the vector type of one row:
// getColumnCount() elements of the matrix's element type.
Type* MatrixExpressionType::getRowType()
{
    if (!rowType)
    {
        rowType = getCurrentASTBuilder()->getVectorType(getElementType(), getColumnCount());
    }
    return rowType;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TupleType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Returns member `i` of the tuple's type pack (generic argument 0), or nullptr
// when that argument is not a type pack of the expected kind.
// NOTE(review): `i` is not range-checked here; presumably getElementType(i)
// handles that — confirm.
Type* TupleType::getMember(Index i) const
{
    if (auto typePack = as(_getGenericTypeArg(getDeclRefBase(), 0)))
        return typePack->getElementType(i);
    return nullptr;
}

// Number of types in the tuple's type pack, or 0 when generic argument 0 is
// not a type pack of the expected kind.
Index TupleType::getMemberCount() const
{
    if (auto typePack = as(_getGenericTypeArg(getDeclRefBase(), 0)))
        return typePack->getTypeCount();
    return 0;
}

// The raw type pack = generic argument 0.
Type* TupleType::getTypePack() const
{
    return as(_getGenericTypeArg(getDeclRefBase(), 0));
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArrayExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Element type = generic argument 0.
Type* ArrayExpressionType::getElementType()
{
    return as(_getGenericTypeArg(this, 0));
}

// Element count = generic argument 1.
IntVal* ArrayExpressionType::getElementCount()
{
    return as(_getGenericTypeArg(this, 1));
}

// Printed as `Element[Count]`, or `Element[]` for unsized arrays.
void ArrayExpressionType::_toTextOverride(StringBuilder& out)
{
    out << getElementType();
    out.appendChar('[');
    if (!isUnsized())
    {
        out << getElementCount();
    }
    out.appendChar(']');
}

// An array is unsized when its element count is a constant equal to
// kUnsizedArrayMagicLength; a non-constant count counts as sized.
bool ArrayExpressionType::isUnsized()
{
    if (auto constSize = as(getElementCount()))
    {
        if (constSize->getValue() == kUnsizedArrayMagicLength)
            return true;
    }
    return false;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// AtomicType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Element type = generic argument 0.
Type* AtomicType::getElementType()
{
    return as(_getGenericTypeArg(this, 0));
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! CoopVectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Element type = generic argument 0.
Type* CoopVectorExpressionType::getElementType()
{
    return as(_getGenericTypeArg(this, 0));
}

// Element count = generic argument 1.
IntVal* CoopVectorExpressionType::getElementCount()
{
    return as(_getGenericTypeArg(this, 1));
}

// Printed as `CoopVector<Element,Count>`.
void CoopVectorExpressionType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("CoopVector<") << getElementType() << toSlice(",") << getElementCount()
        << toSlice(">");
}

// The scalar type of a cooperative vector is its element type.
BasicExpressionType* CoopVectorExpressionType::_getScalarTypeOverride()
{
    return as(getElementType());
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Printed as `typeof(T)`.
void TypeType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("typeof(") << getType() << toSlice(")");
}

// Canonical form wraps the canonical form of the described type.
Type* TypeType::_createCanonicalTypeOverride()
{
    return getCurrentASTBuilder()->getTypeType(getType()->getCanonicalType());
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericDeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Printed as the referenced generic decl-ref.
void GenericDeclRefType::_toTextOverride(StringBuilder& out)
{
    out << getDeclRef();
}

// A generic decl-ref type is its own canonical form.
Type* GenericDeclRefType::_createCanonicalTypeOverride()
{
    return this;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamespaceType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Printed as `namespace <declref>`.
void NamespaceType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("namespace ") << getDeclRef();
}

// A namespace type is its own canonical form.
Type* NamespaceType::_createCanonicalTypeOverride()
{
    return this;
}

// Primal type = generic argument 0.
Type* DifferentialPairType::getPrimalType()
{
    return as(_getGenericTypeArg(this, 0));
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! PtrTypeBase !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Pointee type = generic argument 0.
Type* PtrTypeBase::getValueType()
{
    return as(_getGenericTypeArg(this, 0));
}

// Wrapped value type = generic argument 0.
Type* OptionalType::getValueType()
{
    return as(_getGenericTypeArg(this, 0));
}

// Referenced type = generic argument 0.
Type* NativeRefType::getValueType()
{
    return as(_getGenericTypeArg(this, 0));
}

// Access-qualifier argument (generic argument 1), returned un-decoded.
Val* PtrTypeBase::getAccessQualifier()
{
    return _getGenericTypeArg(this, 1);
}

// Address-space argument (generic argument 2), returned un-decoded.
Val* PtrTypeBase::getAddressSpace()
{
    return _getGenericTypeArg(this, 2);
}

// Decodes `val` as an AccessQualifier when it is a constant (its getValue()
// is converted directly); returns an empty optional otherwise.
std::optional tryGetAccessQualifierValue(Val* val)
{
    if (auto cintVal = as(val))
    {
        return AccessQualifier(cintVal->getValue());
    }
    return std::optional();
}

// Decodes this pointer type's access-qualifier argument, if it is constant.
std::optional PtrTypeBase::tryGetAccessQualifierValue()
{
    auto accessQualifierArg = this->getAccessQualifier();
    return Slang::tryGetAccessQualifierValue(accessQualifierArg);
}

// Decodes `addrSpaceVal` as an AddressSpace when it is a constant.
// Despite the name this never fails: non-constant (or null) values yield
// AddressSpace::Generic.
AddressSpace tryGetAddressSpaceValue(Val* addrSpaceVal)
{
    AddressSpace addrSpace = AddressSpace::Generic;
    if (auto cintVal = as(addrSpaceVal))
    {
        addrSpace = (AddressSpace)(cintVal->getValue());
    }
    return addrSpace;
}

// Appends `, AddressSpace::<name>` for the address spaces listed below;
// any other value appends nothing.
void maybePrintAddrSpaceOperand(StringBuilder& out, AddressSpace addrSpace)
{
    switch (addrSpace)
    {
    case AddressSpace::Generic:
        out << toSlice(", AddressSpace::Generic");
        break;
    case AddressSpace::UserPointer:
        // We expose UserPointer as Device to users
        out << toSlice(", AddressSpace::Device");
        break;
    case AddressSpace::GroupShared:
        out << toSlice(", AddressSpace::GroupShared");
        break;
    case AddressSpace::Global:
        out << toSlice(", AddressSpace::Global");
        break;
    case AddressSpace::ThreadLocal:
        out << toSlice(", AddressSpace::ThreadLocal");
        break;
    case AddressSpace::Uniform:
        out << toSlice(", AddressSpace::Uniform");
        break;
    default:
        break;
    }
}

// Appends `, Access::ReadWrite` or `, Access::Read`; any other qualifier
// appends nothing.
void maybePrintAccessQualifierOperand(StringBuilder& out, AccessQualifier accessQualifier)
{
    switch (accessQualifier)
    {
    case AccessQualifier::ReadWrite:
        out << toSlice(", Access::ReadWrite");
        break;
    case AccessQualifier::Read:
        out << toSlice(", Access::Read");
        break;
    default:
        break;
    }
}

// Printed as `Ptr<T[, Access::...][, AddressSpace::...]>`.
void PtrType::_toTextOverride(StringBuilder& out)
{
    auto addrSpace = tryGetAddressSpaceValue(getAddressSpace());
    out << toSlice("Ptr<") << getValueType();
    if (auto optionalAccessQualifier = tryGetAccessQualifierValue())
        maybePrintAccessQualifierOperand(out, *optionalAccessQualifier);
    maybePrintAddrSpaceOperand(out, addrSpace);
    out << toSlice(">");
}

// Printed as `Ref<T[, Access::...][, AddressSpace::...]>`.
void ExplicitRefType::_toTextOverride(StringBuilder& out)
{
    auto addrSpace = tryGetAddressSpaceValue(getAddressSpace());
    out << toSlice("Ref<") << getValueType();
    if (auto optionalAccessQualifier = tryGetAccessQualifierValue())
        maybePrintAccessQualifierOperand(out, *optionalAccessQualifier);
    maybePrintAddrSpaceOperand(out, addrSpace);
    out << toSlice(">");
}

// Parameter-direction wrappers print as `<keyword> <value type>`.
void OutParamType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("out ") << getValueType();
}

void InOutParamType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("inout ") << getValueType();
}

void RefParamType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("ref ") << getValueType();
}

// Note: const-ref parameters are spelled `borrow` in printed output.
void ConstRefParamType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("borrow ") << getValueType();
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Printed as the typedef's decl-ref; prints nothing when there is no decl.
void NamedExpressionType::_toTextOverride(StringBuilder& out)
{
    if (getDeclRef().getDecl())
    {
        getDeclRef().declRefBase->toText(out);
    }
}

// Canonical form is the canonical form of the aliased type, or the error
// type when the alias's type cannot be obtained.
Type* NamedExpressionType::_createCanonicalTypeOverride()
{
    auto canType = getType(getCurrentASTBuilder(), getDeclRef());
    if (canType)
        return canType->getCanonicalType();
    return getCurrentASTBuilder()->getErrorType();
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Maps a (possibly direction-wrapped) parameter type to its passing mode;
// a type that matches none of the wrapper checks is passed `in`.
// NOTE(review): the checks are order-sensitive if the wrapper types are
// related by inheritance; presumably why the in-out check precedes the out
// check — confirm before reordering.
ParameterDirection getParamPassingModeFromPossiblyWrappedParamType(Type* paramType)
{
    if (as(paramType))
    {
        return kParameterDirection_Ref;
    }
    else if (as(paramType))
    {
        return kParameterDirection_ConstRef;
    }
    else if (as(paramType))
    {
        return kParameterDirection_InOut;
    }
    else if (as(paramType))
    {
        return kParameterDirection_Out;
    }
    else
    {
        return kParameterDirection_In;
    }
}

// Passing mode of parameter `index`, derived from its direction wrapper.
ParameterDirection FuncType::getParamDirection(Index index)
{
    auto paramType = getParamTypeWithDirectionWrapper(index);
    return getParamPassingModeFromPossiblyWrappedParamType(paramType);
}

// Type of parameter `index` with a direction wrapper (if any) removed.
Type* FuncType::getParamValueType(Index index)
{
    auto paramType = getParamTypeWithDirectionWrapper(index);
    if (auto wrappedParamType = as(paramType))
        return wrappedParamType->getValueType();
    return paramType;
}

// Printed as `(P0, P1, ...) -> R`, followed by ` throws E` when the error
// type is not the bottom type.
void FuncType::_toTextOverride(StringBuilder& out)
{
    Index paramCount = getParamCount();
    out << toSlice("(");
    for (Index pp = 0; pp < paramCount; ++pp)
    {
        if (pp != 0)
        {
            out << toSlice(", ");
        }
        out << getParamTypeWithDirectionWrapper(pp);
    }
    out << ") -> " << getResultType();

    if (!getErrorType()->equals(getCurrentASTBuilder()->getBottomType()))
    {
        out << " throws " << getErrorType();
    }
}

// Substitutes into the result, error and parameter types. A parameter type
// that substitutes to a type pack (see `typePack` below) is flattened into
// one parameter per pack element. Returns `this` if nothing changed.
Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
    int diff = 0;

    // result type
    Type* substResultType = as(getResultType()->substituteImpl(astBuilder, subst, &diff));

    // error type
    Type* substErrorType = as(getErrorType()->substituteImpl(astBuilder, subst, &diff));

    // parameter types
    List substParamTypes;
    for (Index pp = 0; pp < getParamCount(); pp++)
    {
        auto substParamType = as(
            getParamTypeWithDirectionWrapper(pp)->substituteImpl(astBuilder, subst, &diff));
        if (auto typePack = as(substParamType))
        {
            // Unwrap the ConcreteTypePack and add each element as a parameter
            for (Index i = 0; i < typePack->getTypeCount(); ++i)
            {
                substParamTypes.add(typePack->getElementType(i));
            }
        }
        else
        {
            substParamTypes.add(substParamType);
        }
    }

    // early exit for no change...
    if (!diff)
        return this;

    (*ioDiff)++;
    FuncType* substType =
        astBuilder->getFuncType(substParamTypes.getArrayView(), substResultType, substErrorType);
    return substType;
}

// Canonical form: a function type built from the canonical result, error and
// (direction-wrapped) parameter types.
Type* FuncType::_createCanonicalTypeOverride()
{
    // result type
    Type* canResultType = getResultType()->getCanonicalType();
    Type* canErrorType = getErrorType()->getCanonicalType();

    // parameter types
    List canParamTypes;
    for (Index pp = 0; pp < getParamCount(); pp++)
    {
        canParamTypes.add(getParamTypeWithDirectionWrapper(pp)->getCanonicalType());
    }

    FuncType* canType = getCurrentASTBuilder()->getFuncType(
        canParamTypes.getArrayView(),
        canResultType,
        canErrorType);

    return canType;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! EachType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Printed as `each <element type>`.
// NOTE(review): with no element type this appends an empty string, so only
// "each " is printed; a placeholder text may have been lost here — confirm.
void EachType::_toTextOverride(StringBuilder& out)
{
    out << "each ";
    if (getElementType())
    {
        getElementType()->toText(out);
    }
    else
    {
        out << "";
    }
}

// An each-type is treated as its own canonical form.
Type* EachType::_createCanonicalTypeOverride()
{
    return this;
}

// Substitutes into the element type. Returns `this` if nothing changed; otherwise:
// - if the element substitutes to a type pack and `subst.packExpansionIndex`
//   is a valid index into it, returns that pack element;
// - if it substitutes to an expand type whose pattern passes the inner check,
//   returns that pattern (`each expand X` collapses to X);
// - otherwise wraps the substituted element in a new each-type.
Val* EachType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
    int diff = 0;
    auto substElementType = as(getElementType()->substituteImpl(astBuilder, subst, &diff));
    if (!diff)
        return this;
    if (auto typePack = as(substElementType))
    {
        if (subst.packExpansionIndex >= 0 && subst.packExpansionIndex < typePack->getTypeCount())
        {
            (*ioDiff)++;
            return typePack->getElementType(subst.packExpansionIndex);
        }
    }
    else if (auto expandType = as(substElementType))
    {
        if (auto innerEach = as(expandType->getPatternType()))
        {
            (*ioDiff)++;
            return innerEach;
        }
    }
    (*ioDiff)++;
    return astBuilder->getEachType(substElementType);
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExpandType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ExpandType::_toTextOverride(StringBuilder& out) { out << "expand "; getPatternType()->toText(out); } Type* ExpandType::_createCanonicalTypeOverride() { auto canonicalPatternType = getPatternType()->getCanonicalType(); if (canonicalPatternType == getPatternType()) return this; ShortList capturedPacks; for (Index i = 0; i < getCapturedTypePackCount(); i++) { capturedPacks.add(getCapturedTypePack(i)); } return getCurrentASTBuilder()->getExpandType( canonicalPatternType, capturedPacks.getArrayView().arrayView); } Val* ExpandType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; // Given ExpandType(PatternType, CapturedTypePackParams), we first need to know // if all captured GenericTypePackParams can be substituted into concrete type packs. // We can't expand the ExpandType into a concrete type pack, if any of the captured type // pack parameters aren't concrete themselves. // ShortList capturedPacks; ShortList concreteTypePacks; for (Index i = 0; i < getCapturedTypePackCount(); i++) { auto substCapturedTypePack = getCapturedTypePack(i)->substituteImpl(astBuilder, subst, &diff); if (auto expandType = as(substCapturedTypePack)) { for (Index j = 0; j < expandType->getCapturedTypePackCount(); j++) capturedPacks.add(expandType->getCapturedTypePack(j)); } else { capturedPacks.add(as(substCapturedTypePack)); if (auto pack = as(capturedPacks.getLast())) { concreteTypePacks.add(pack); } } } if (!diff || concreteTypePacks.getCount() != capturedPacks.getCount()) { auto substPatternType = getPatternType()->substituteImpl(astBuilder, subst, &diff); if (!diff) return this; // If some part of pattern type or captured type can be substituted into something else, // but not all of the captured types are resolved to concrete type packs yet, we will just // create a new ExpandType with the substituted pattern/capture types, instead of actually // expanding into a concrete type pack. 
(*ioDiff)++; return astBuilder->getExpandType( as(substPatternType), capturedPacks.getArrayView().arrayView); } else { // All type pack parameters are now concrete type packs, so we can construct a concrete type // pack by substituting the pattern type with each element of the captured type pack. ShortList expandedTypes; SLANG_ASSERT(capturedPacks.getCount() != 0); for (int i = 0; i < (int)concreteTypePacks[0]->getTypeCount(); i++) { subst.packExpansionIndex = i; auto substElementType = getPatternType()->substituteImpl(astBuilder, subst, &diff); expandedTypes.add(as(substElementType)); } if (!diff) return this; (*ioDiff)++; return astBuilder->getTypePack(expandedTypes.getArrayView().arrayView); } } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ConcreteTypePack !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ConcreteTypePack::_toTextOverride(StringBuilder& out) { for (Index i = 0; i < getTypeCount(); i++) { if (i != 0) out << ", "; getElementType(i)->toText(out); } } Type* ConcreteTypePack::_createCanonicalTypeOverride() { ShortList canonicalElementTypes; for (Index i = 0; i < getTypeCount(); i++) { canonicalElementTypes.add(getElementType(i)->getCanonicalType()); } return getCurrentASTBuilder()->getTypePack(canonicalElementTypes.getArrayView().arrayView); } Val* ConcreteTypePack::_substituteImplOverride( ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; ShortList substElementTypes; for (Index i = 0; i < getTypeCount(); i++) { auto substType = as(getElementType(i)->substituteImpl(astBuilder, subst, &diff)); if (auto typePack = as(substType)) { // Unwrap the ConcreteTypePack and add each element as a parameter for (Index j = 0; j < typePack->getTypeCount(); ++j) { substElementTypes.add(typePack->getElementType(j)); } } else { substElementTypes.add(substType); } } if (!diff) return this; (*ioDiff)++; return getCurrentASTBuilder()->getTypePack(substElementTypes.getArrayView().arrayView); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
// ExtractExistentialType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// Printed as `<declref>.This`.
void ExtractExistentialType::_toTextOverride(StringBuilder& out)
{
    out << getDeclRef() << toSlice(".This");
}

// An extracted existential type is its own canonical form.
Type* ExtractExistentialType::_createCanonicalTypeOverride()
{
    return this;
}

// Substitutes into the decl-ref, the original interface type and the original
// interface decl-ref; returns `this` if none changed, otherwise a new
// ExtractExistentialType built from the substituted parts.
Val* ExtractExistentialType::_substituteImplOverride(
    ASTBuilder* astBuilder,
    SubstitutionSet subst,
    int* ioDiff)
{
    int diff = 0;
    auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
    auto substOriginalInterfaceType =
        getOriginalInterfaceType()->substituteImpl(astBuilder, subst, &diff);
    auto substOriginalInterfaceDeclRef =
        getOriginalInterfaceDeclRef().substituteImpl(astBuilder, subst, &diff);
    if (!diff)
        return this;

    (*ioDiff)++;

    ExtractExistentialType* substValue = astBuilder->getOrCreate(
        substDeclRef,
        as(substOriginalInterfaceType),
        substOriginalInterfaceDeclRef);
    return substValue;
}

// Returns (and caches in `cachedSubtypeWitness`) a witness built from this type,
// the original interface type and the decl-ref.
SubtypeWitness* ExtractExistentialType::getSubtypeWitness()
{
    if (auto cachedValue = this->cachedSubtypeWitness)
        return cachedValue;

    ExtractExistentialSubtypeWitness* openedWitness =
        getCurrentASTBuilder()->getOrCreate(
            this,
            getOriginalInterfaceType(),
            getDeclRef());

    this->cachedSubtypeWitness = openedWitness;
    return openedWitness;
}

// Returns (and caches in `cachedThisTypeDeclRef`) a lookup decl-ref of the
// original interface's ThisType decl through the subtype witness above.
// Asserts that the interface has a ThisType decl.
DeclRef ExtractExistentialType::getThisTypeDeclRef()
{
    if (auto cachedValue = this->cachedThisTypeDeclRef)
        return cachedValue;

    auto interfaceDecl = getOriginalInterfaceDeclRef().getDecl();

    SubtypeWitness* openedWitness = getSubtypeWitness();

    ThisTypeDecl* thisTypeDecl = interfaceDecl->getThisTypeDecl();
    SLANG_ASSERT(thisTypeDecl);

    DeclRef specialiedInterfaceDeclRef =
        getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl).as();

    this->cachedThisTypeDeclRef = specialiedInterfaceDeclRef;
    return specialiedInterfaceDeclRef;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExistentialSpecializedType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Printed as `__ExistentialSpecializedType(Base, arg0, arg1, ...)`, using each
// argument's `val` (witnesses are not printed).
void ExistentialSpecializedType::_toTextOverride(StringBuilder& out)
{
    out << toSlice("__ExistentialSpecializedType(") << getBaseType();
    for (Index i = 0; i < getArgCount(); i++)
    {
        out << toSlice(", ") << getArg(i).val;
    }
    out << toSlice(")");
}

// Null-tolerant canonicalization of a value: types are replaced by their
// canonical type, other values are returned unchanged.
static Val* _getCanonicalValue(Val* val)
{
    if (!val)
        return nullptr;
    if (auto type = as(val))
    {
        return type->getCanonicalType();
    }
    // TODO: We may eventually need/want some sort of canonicalization
    // for non-type values, but for now there is nothing to do.
    return val;
}

// Canonical form: canonical base type plus canonicalized argument values and witnesses.
Type* ExistentialSpecializedType::_createCanonicalTypeOverride()
{
    ExpandedSpecializationArgs newArgs;
    for (Index ii = 0; ii < getArgCount(); ++ii)
    {
        auto arg = getArg(ii);

        ExpandedSpecializationArg canArg;
        canArg.val = _getCanonicalValue(arg.val);
        canArg.witness = _getCanonicalValue(arg.witness);
        newArgs.add(canArg);
    }
    ExistentialSpecializedType* canType = getCurrentASTBuilder()->getOrCreate(
        getBaseType()->getCanonicalType(),
        newArgs);
    return canType;
}

// Null-tolerant wrapper around Val::substituteImpl (a null value stays null).
static Val* _substituteImpl(ASTBuilder* astBuilder, Val* val, SubstitutionSet subst, int* ioDiff)
{
    if (!val)
        return nullptr;
    return val->substituteImpl(astBuilder, subst, ioDiff);
}

// Substitutes into the base type and every argument value/witness; returns
// `this` if nothing changed.
Val* ExistentialSpecializedType::_substituteImplOverride(
    ASTBuilder* astBuilder,
    SubstitutionSet subst,
    int* ioDiff)
{
    int diff = 0;

    auto substBaseType = as(getBaseType()->substituteImpl(astBuilder, subst, &diff));

    ExpandedSpecializationArgs substArgs;
    for (Index ii = 0; ii < getArgCount(); ++ii)
    {
        auto arg = getArg(ii);

        ExpandedSpecializationArg substArg;
        substArg.val = _substituteImpl(astBuilder, arg.val, subst, &diff);
        substArg.witness = _substituteImpl(astBuilder, arg.witness, subst, &diff);
        substArgs.add(substArg);
    }

    if (!diff)
        return this;

    (*ioDiff)++;

    ExistentialSpecializedType* substType =
        astBuilder->getOrCreate(substBaseType, substArgs);
    return substType;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
DeclRef ThisType::getInterfaceDeclRef() { return DeclRef(getDeclRefBase()->getParent()).template as(); } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AndType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void AndType::_toTextOverride(StringBuilder& out) { out << getLeft() << toSlice(" & ") << getRight(); } Type* AndType::_createCanonicalTypeOverride() { // TODO: proper canonicalization of an `&` type relies on // several different things: // // * We need to re-associate types that might involve // nesting of `&`, such as `(A & B) & (C & D)`, into // a canonical form where the nesting is consistent // (i.e., always left- or right-associative). // // * We need to commute types so that they are in a // consistent order, so that `A & B` and `B & A` both // result in the same canonicalization. This requirement // implies that we must invent a total order on types. // // * We need to canonicalize `&` types where one of the // elements might be implied by another. E.g., if we // have `interface IDerived : IBase`, then a type like // `IDerived & IBase` is equivalent to just `IDerived` // because the presence of an `IBase` conformance is // implied. A special case of the above is the possibility // of duplicates in the list of types (e.g., `A & B & A`). // // * The previous requirement raises the problem that // the relationships between `interface`s might either // evolve over time, or be subject to `extension` // declarations in other modules. The canonicalization // algorithm must be clear about what information it // is allowed to make use of, as this can/will affect // binary interfaces (via mangled names). // // We are going to completely ignore these issues for // right now, in the name of getting something up and running. 
// auto canLeft = getLeft()->getCanonicalType(); auto canRight = getRight()->getCanonicalType(); auto canType = getCurrentASTBuilder()->getAndType(canLeft, canRight); return canType; } Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; auto substLeft = as(getLeft()->substituteImpl(astBuilder, subst, &diff)); auto substRight = as(getRight()->substituteImpl(astBuilder, subst, &diff)); if (!diff) return this; (*ioDiff)++; auto substType = astBuilder->getAndType(substLeft, substRight); return substType; } // ModifiedType void ModifiedType::_toTextOverride(StringBuilder& out) { for (Index i = 0; i < getModifierCount(); i++) { getModifier(i)->toText(out); out.appendChar(' '); } getBase()->toText(out); } Type* ModifiedType::_createCanonicalTypeOverride() { List modifiers; for (Index i = 0; i < getModifierCount(); ++i) { auto modifier = this->getModifier(i); modifiers.add(modifier); } ModifiedType* canonical = getCurrentASTBuilder()->getOrCreate( getBase()->getCanonicalType(), modifiers.getArrayView()); return canonical; } Val* ModifiedType::_substituteImplOverride( ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; Type* substBase = as(getBase()->substituteImpl(astBuilder, subst, &diff)); List substModifiers; for (Index i = 0; i < getModifierCount(); ++i) { auto modifier = this->getModifier(i); auto substModifier = modifier->substituteImpl(astBuilder, subst, &diff); substModifiers.add(substModifier); } if (!diff) return this; *ioDiff = 1; ModifiedType* substType = getCurrentASTBuilder()->getOrCreate(substBase, substModifiers.getArrayView()); return substType; } BaseType BasicExpressionType::getBaseType() const { auto builtinType = getDeclRef().getDecl()->findModifier(); return builtinType->tag; } FeedbackType::Kind FeedbackType::getKind() const { auto magicMod = getDeclRef().getDecl()->findModifier(); return FeedbackType::Kind(magicMod->tag); } SlangResourceShape 
ResourceType::getBaseShape() { auto shape = _getGenericTypeArg(getDeclRefBase(), 1); if (as(shape)) return SLANG_TEXTURE_1D; else if (as(shape)) return SLANG_TEXTURE_2D; else if (as(shape)) return SLANG_TEXTURE_3D; else if (as(shape)) return SLANG_TEXTURE_CUBE; else if (as(shape)) return SLANG_TEXTURE_BUFFER; return SLANG_RESOURCE_NONE; } SlangResourceShape ResourceType::getShape() { auto baseShape = (SlangResourceShape)getBaseShape(); if (isArray()) baseShape = (SlangResourceShape)((uint32_t)baseShape | SLANG_TEXTURE_ARRAY_FLAG); if (isMultisample()) baseShape = (SlangResourceShape)((uint32_t)baseShape | SLANG_TEXTURE_MULTISAMPLE_FLAG); if (isShadow()) baseShape = (SlangResourceShape)((uint32_t)baseShape | SLANG_TEXTURE_SHADOW_FLAG); if (isFeedback()) baseShape = (SlangResourceShape)((uint32_t)baseShape | SLANG_TEXTURE_FEEDBACK_FLAG); if (isCombined()) baseShape = (SlangResourceShape)((uint32_t)baseShape | SLANG_TEXTURE_COMBINED_FLAG); return baseShape; } bool ResourceType::isArray() { auto isArray = _getGenericTypeArg(this, kCoreModule_TextureIsArrayParameterIndex); if (auto constIntVal = as(isArray)) return constIntVal->getValue() != 0; return false; } bool ResourceType::isMultisample() { auto isMS = _getGenericTypeArg(this, kCoreModule_TextureIsMultisampleParameterIndex); if (auto constIntVal = as(isMS)) return constIntVal->getValue() != 0; return false; } bool ResourceType::isShadow() { auto isShadow = _getGenericTypeArg(this, kCoreModule_TextureIsShadowParameterIndex); if (auto constIntVal = as(isShadow)) return constIntVal->getValue() != 0; return false; } bool ResourceType::isFeedback() { auto access = _getGenericTypeArg(this, kCoreModule_TextureAccessParameterIndex); if (auto constIntVal = as(access)) return constIntVal->getValue() == kCoreModule_ResourceAccessFeedback; return false; } bool ResourceType::isCombined() { auto combined = _getGenericTypeArg(this, kCoreModule_TextureIsCombinedParameterIndex); if (auto constIntVal = as(combined)) return 
constIntVal->getValue() != 0; return false; } Type* SubpassInputType::getElementType() { return as(_getGenericTypeArg(this, 0)); } bool SubpassInputType::isMultisample() { auto isMS = _getGenericTypeArg(this, 1); if (auto constIntVal = as(isMS)) return constIntVal->getValue() != 0; return false; } SlangResourceAccess ResourceType::getAccess() { auto access = _getGenericTypeArg(this, kCoreModule_TextureAccessParameterIndex); if (auto constIntVal = as(access)) { switch (constIntVal->getValue()) { case kCoreModule_ResourceAccessReadOnly: return SLANG_RESOURCE_ACCESS_READ; case kCoreModule_ResourceAccessReadWrite: return SLANG_RESOURCE_ACCESS_READ_WRITE; case kCoreModule_ResourceAccessWriteOnly: return SLANG_RESOURCE_ACCESS_WRITE; case kCoreModule_ResourceAccessRasterizerOrdered: return SLANG_RESOURCE_ACCESS_RASTER_ORDERED; case kCoreModule_ResourceAccessFeedback: return SLANG_RESOURCE_ACCESS_FEEDBACK; default: break; } } return SLANG_RESOURCE_ACCESS_NONE; } SamplerStateFlavor SamplerStateType::getFlavor() const { auto magicMod = getDeclRef().getDecl()->findModifier(); return SamplerStateFlavor(magicMod->tag); } Type* BuiltinGenericType::getElementType() const { return as(_getGenericTypeArg(getDeclRefBase(), 0)); } Type* ResourceType::getElementType() { return as(_getGenericTypeArg(this, 0)); } void ResourceType::_toTextOverride(StringBuilder& out) { auto tryPrintSimpleName = [&](String& outString) -> bool { StringBuilder resultSB; auto access = getAccess(); switch (access) { case SLANG_RESOURCE_ACCESS_READ: break; case SLANG_RESOURCE_ACCESS_READ_WRITE: resultSB << "RW"; ; break; case SLANG_RESOURCE_ACCESS_RASTER_ORDERED: resultSB << "RasterizerOrdered"; break; case SLANG_RESOURCE_ACCESS_FEEDBACK: resultSB << "Feedback"; break; default: return false; } auto combined = as(_getGenericTypeArg(this, 7)); auto shapeVal = _getGenericTypeArg(this, 1); if (!as(shapeVal)) return false; auto shape = getBaseShape(); if (!combined) return false; if (combined->getValue() != 0) 
resultSB << "Sampler"; else { if (shape == SLANG_TEXTURE_BUFFER) resultSB << "Buffer"; else resultSB << "Texture"; } switch (shape) { case SLANG_TEXTURE_1D: resultSB << "1D"; break; case SLANG_TEXTURE_2D: resultSB << "2D"; break; case SLANG_TEXTURE_3D: resultSB << "3D"; break; case SLANG_TEXTURE_CUBE: resultSB << "Cube"; break; } auto isArrayVal = as(_getGenericTypeArg(this, 2)); if (!isArrayVal) return false; if (isArray()) resultSB << "Array"; auto isMultisampleVal = as(_getGenericTypeArg(this, 3)); if (!isMultisampleVal) return false; if (isMultisample()) resultSB << "MS"; auto isShadowVal = as(_getGenericTypeArg(this, 6)); if (!isShadowVal) return false; if (isShadow()) return false; auto elementType = getElementType(); if (elementType) { resultSB << "<"; resultSB << elementType->toString(); auto sampleCount = _getGenericTypeArg(this, 4); if (auto constIntVal = as(sampleCount)) { if (constIntVal->getValue() != 0) resultSB << ", " << constIntVal->getValue(); } else { return false; } resultSB << ">"; } outString = resultSB.toString(); return true; }; String simpleName; if (tryPrintSimpleName(simpleName)) out << simpleName; else DeclRefType::_toTextOverride(out); } Val* TextureTypeBase::getSampleCount() { return as(_getGenericTypeArg(this, 4)); } Val* TextureTypeBase::getFormat() { return as(_getGenericTypeArg(this, 8)); } Type* removeParamDirType(Type* type) { for (auto paramDirType = as(type); paramDirType;) { type = paramDirType->getValueType(); paramDirType = as(type); } return type; } bool isNonCopyableType(Type* type) { auto declRefType = as(type); if (!declRefType) return false; if (declRefType->getDeclRef().getDecl()->findModifier()) return true; return false; } } // namespace Slang