diff options
25 files changed, 1254 insertions, 893 deletions
diff --git a/docs/user-guide/10-link-time-specialization.md b/docs/user-guide/10-link-time-specialization.md index 516fd19b5..e5bc97930 100644 --- a/docs/user-guide/10-link-time-specialization.md +++ b/docs/user-guide/10-link-time-specialization.md @@ -163,16 +163,9 @@ import common; export struct Sampler : ISampler = FooSampler; ``` -The `=` syntax is a syntactic sugar that expands to the following code: - -```csharp -export struct Sampler : ISampler -{ - FooSampler inner; - int getSampleCount() { return inner.getSampleCount(); } - float sample(int index) { return inner.sample(index); } -} -``` +The `=` syntax defines a typealias that allows `Sampler` to resolve to `FooSampler` at link-time. +Note that both the name and type conformance clauses must match exactly between an `export` and an `extern` declaration +for link-time types to resolve correctly. Link-time types can also be generic, and may conform to generic interfaces. When all these three modules are linked, we will produce a specialized shader that uses the `FooSampler`. @@ -196,17 +189,6 @@ void main(uint tid : SV_DispatchThreadID) } ``` -## Restrictions - -Unlike preprocessors, link-time constants and types can only be used in places where shader parameter layout cannot be -affected. This means that link-time constants and types are subject to the following restrictions: -- Link-time constants cannot be used to define array sizes. -- Link-time types are considered "incomplete" types. A struct or array type that has incomplete typed element is also an incomplete type. - Incomplete types cannot be used as `ConstantBuffer` or `ParameterBlock` element type, and cannot be used directly as the type of - a uniform variable. - -However it is allowed to use incomplete types as the element type of `StructuredBuffer` or `GLSLStorageBuffer`. - ## Using Precompiling Modules with the API In addition to using `slangc` for precompiling Slang modules, the `IModule` class provides a method to serialize itself to disk: diff --git a/docs/user-guide/toc.html b/docs/user-guide/toc.html index b32ad71da..89ddf34d9 100644 --- a/docs/user-guide/toc.html +++ b/docs/user-guide/toc.html @@ -148,7 +148,6 @@ <li data-link="link-time-specialization#link-time-constants"><span>Link-time Constants</span></li> <li data-link="link-time-specialization#link-time-types"><span>Link-time Types</span></li> <li data-link="link-time-specialization#providing-default-settings"><span>Providing Default Settings</span></li> -<li data-link="link-time-specialization#restrictions"><span>Restrictions</span></li> <li data-link="link-time-specialization#using-precompiling-modules-with-the-api"><span>Using Precompiling Modules with the API</span></li> <li data-link="link-time-specialization#additional-remarks"><span>Additional Remarks</span></li> </ul> diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 40c55ec44..1a850da0d 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -386,8 +386,16 @@ class AggTypeDecl : public AggTypeDeclBase FIDDLE(...) FIDDLE() TypeTag typeTags = TypeTag::None; - // Used if this type declaration is a wrapper, i.e. struct FooWrapper:IFoo = Foo; - TypeExp wrappedType; + // When user defines an agg type in the syntax of + // `struct FooAlias : IFoo = Foo;` + // The user is defining a link-time type alias. In contrast + // to an ordinary typealias, a link-time alias is not folded in + // the front-end, and resolved during linking. + // `aliasedType` is used to store the alised type (in this case `Foo`) + // when the agg type decl is declared in the link-time alias syntax. + // + TypeExp aliasedType; + bool hasBody = true; void unionTagsWith(TypeTag other); @@ -506,6 +514,13 @@ class InheritanceDecl : public TypeConstraintDecl // this inheritance declaration. FIDDLE() RefPtr<WitnessTable> witnessTable; + // If the inheritance decl is in a link-time type declaration + // (e.g. `export struct Foo : IFoo = FooImpl;`), then we will + // store the witness that `FooImpl:IFoo` here. + // TODO: If we made `WitnessTable` a `Val`, we should be able + // to unify these two cases. + FIDDLE() Witness* witnessVal = nullptr; + // Overrides should be public so base classes can access const TypeExp& _getSupOverride() const { return base; } }; @@ -1015,4 +1030,8 @@ void addSiblingScopeForContainerDecl( ContainerDecl* source); void addSiblingScopeForContainerDecl(ASTBuilder* builder, Scope* destScope, ContainerDecl* source); +// Cast `decl` to a valid `ContainerDecl*` if its members will become global scope symbols after +// lowering to IR. This currently includes: `NamespaceDecl`, `ModuleDecl` and `FileDecl`. +ContainerDecl* isStaticScopeDecl(Decl* decl); + } // namespace Slang diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 379de0560..dfb203153 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -543,7 +543,7 @@ void ASTIterator<CallbackFunc, FilterFunc>::visitDecl(DeclBase* decl) visitDecl(member); } if (auto aggTypeDecl = as<AggTypeDecl>(decl)) - visitExpr(aggTypeDecl->wrappedType.exp); + visitExpr(aggTypeDecl->aliasedType.exp); } for (auto modifier : decl->modifiers) { diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 9dd481acb..3ccbd4ff7 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1260,6 +1260,7 @@ FIDDLE() namespace Slang Type* Ptr() { return type; } operator Type*() { return type; } Type* operator->() { return Ptr(); } + explicit operator bool() const { return type != nullptr; } ThisType& operator=(const ThisType& rhs) = default; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 6a4e3668f..fa31c54bd 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2553,24 +2553,17 @@ void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl) structDecl->addTag(TypeTag::Incomplete); } - // Slang supports a convenient syntax to create a wrapper type from + // Slang supports a convenient syntax to create a link-time aliased type from // an existing type that implements a given interface. For example, // the user can write: struct FooWrapper:IFoo = Foo; - // In this case we will synthesize the FooWrapper type with an inner - // member of type `Foo`, and use it to implement all requirements of - // IFoo. - // If this is a wrapper struct, synthesize the inner member now. - if (structDecl->wrappedType.exp) - { - structDecl->wrappedType = CheckProperType(structDecl->wrappedType); - auto member = m_astBuilder->create<VarDecl>(); - member->type = structDecl->wrappedType; - member->nameAndLoc.name = getName("inner"); - member->nameAndLoc.loc = structDecl->wrappedType.exp->loc; - member->loc = member->nameAndLoc.loc; - addModifier(member, m_astBuilder->create<SynthesizedModifier>()); - structDecl->addMember(member); + // In this case we need to check the aliasedType expr. + if (structDecl->aliasedType.exp) + { + SemanticsVisitor visitor(withDeclToExcludeFromLookup(structDecl)); + structDecl->aliasedType = visitor.CheckProperType(structDecl->aliasedType); + structDecl->addTag(getTypeTags(structDecl->aliasedType)); } + checkVisibility(structDecl); } @@ -5348,16 +5341,6 @@ void SemanticsVisitor::_addMethodWitness( witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); } -static bool isWrapperTypeDecl(Decl* decl) -{ - if (auto aggTypeDecl = as<AggTypeDecl>(decl)) - { - if (aggTypeDecl->wrappedType) - return true; - } - return false; -} - // Is it allowed to have an interface method parameter whose direction is `reqDir`, and an // implementing method parameter whose direction is `implDir`? // @@ -5487,16 +5470,14 @@ bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( // With the big picture spelled out, we can settle into // the work of constructing our synthesized method. // - bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); - // First, we check that the differentiabliity of the method matches the requirement, // and we don't attempt to synthesize a method if they don't match. if (lookupResult.isValid()) { - if (!isInWrapperType && getShared()->getFuncDifferentiableLevel( - as<FunctionDeclBase>(lookupResult.item.declRef.getDecl())) < - getShared()->getFuncDifferentiableLevel( - as<FunctionDeclBase>(requiredMemberDeclRef.getDecl()))) + if (getShared()->getFuncDifferentiableLevel( + as<FunctionDeclBase>(lookupResult.item.declRef.getDecl())) < + getShared()->getFuncDifferentiableLevel( + as<FunctionDeclBase>(requiredMemberDeclRef.getDecl()))) { return false; } @@ -5524,25 +5505,7 @@ bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( auto baseOverloadedExpr = m_astBuilder->create<OverloadedExpr>(); baseOverloadedExpr->name = requiredMemberDeclRef.getDecl()->getName(); - if (isInWrapperType) - { - auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl); - baseOverloadedExpr->lookupResult2 = lookUpMember( - m_astBuilder, - this, - baseOverloadedExpr->name, - aggTypeDecl->wrappedType.type, - aggTypeDecl->ownedScope, - LookupMask::Default, - LookupOptions::IgnoreBaseInterfaces); - addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>()); - - synFuncDecl->parentDecl = aggTypeDecl; - } - else - { - baseOverloadedExpr->lookupResult2 = lookupResult; - } + baseOverloadedExpr->lookupResult2 = lookupResult; // Non-static methods cannot implement static methods, remove them. if (requiredMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>()) @@ -5557,24 +5520,7 @@ bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( // if (synThis) { - if (isInWrapperType) - { - // If this is a wrapper type, then use the inner - // object as the actual this parameter for the redirected - // call. - auto innerExpr = m_astBuilder->create<VarExpr>(); - innerExpr->scope = synThis->scope; - innerExpr->name = getName("inner"); - baseOverloadedExpr->base = CheckExpr(innerExpr); - SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); - bodyVisitor.maybeRegisterDifferentiableType( - m_astBuilder, - baseOverloadedExpr->base->type); - } - else - { - baseOverloadedExpr->base = synThis; - } + baseOverloadedExpr->base = synThis; } @@ -5823,8 +5769,7 @@ bool SemanticsVisitor::trySynthesizeConstructorRequirementWitness( bool isDefaultInitializableType = requiredMemberDeclRef.getParent() == getASTBuilder()->getDefaultInitializableTypeInterfaceDecl(); - bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); - if (!isInWrapperType && !isDefaultInitializableType && !satisfyingMemberLookupResult.isValid()) + if (!isDefaultInitializableType && !satisfyingMemberLookupResult.isValid()) { return false; } @@ -5846,53 +5791,7 @@ bool SemanticsVisitor::trySynthesizeConstructorRequirementWitness( auto seqStmt = m_astBuilder->create<SeqStmt>(); ctorDecl->body = seqStmt; - - if (isInWrapperType) - { - SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(ctorDecl)); - bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, context->conformingType); - - if (auto varDecl = context->parentDecl->findFirstDirectMemberDeclOfType<VarDeclBase>()) - { - auto varExpr = m_astBuilder->create<VarExpr>(); - varExpr->scope = ctorDecl->ownedScope; - varExpr->name = varDecl->getName(); - auto checkedVarExpr = CheckTerm(varExpr); - if (!checkedVarExpr) - return false; - if (as<ErrorType>(checkedVarExpr->type.type)) - return false; - auto assign = m_astBuilder->create<AssignExpr>(); - assign->left = checkedVarExpr; - auto temp = m_astBuilder->create<InvokeExpr>(); - auto lookupResult = lookUpMember( - m_astBuilder, - this, - ctorName, - varDecl->type.type, - ctorDecl->ownedScope, - LookupMask::Function, - LookupOptions::IgnoreBaseInterfaces); - temp->functionExpr = createLookupResultExpr( - ctorName, - lookupResult, - nullptr, - context->parentDecl->loc, - nullptr); - temp->arguments.addRange(synArgs); - auto resolvedVar = ResolveInvoke(temp); - if (!resolvedVar) - return false; - assign->right = resolvedVar; - assign->type = m_astBuilder->getVoidType(); - bodyVisitor.maybeRegisterDifferentiableType(m_astBuilder, varDecl->type.type); - - auto stmt = m_astBuilder->create<ExpressionStmt>(); - stmt->expression = assign; - seqStmt->stmts.add(stmt); - } - } - else if (synArgs.getCount()) + if (synArgs.getCount()) { // The body of our synthesized method is going to try to // make a ctor call with the specified arguments (e.g., @@ -5965,12 +5864,6 @@ bool SemanticsVisitor::trySynthesizePropertyRequirementWitness( DeclRef<PropertyDecl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable) { - if (isWrapperTypeDecl(context->parentDecl)) - return trySynthesizeWrapperTypePropertyRequirementWitness( - context, - requiredMemberDeclRef, - witnessTable); - // The situation here is that the context of an inheritance // declaration didn't provide an exact match for a required // property. E.g.: @@ -6130,244 +6023,6 @@ bool SemanticsVisitor::trySynthesizePropertyRequirementWitness( return true; } -bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( - ConformanceCheckingContext* context, - DeclRef<PropertyDecl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable) -{ - // We are synthesizing a property requirement for a wrapper type: - // - // interface IFoo { property value : int { get; set; } } - // struct Foo : IFoo = FooImpl; - // - // We need to synthesize Foo to: - // - // struct Foo : IFoo - // { - // FooImpl inner; - // property value : int { get { return inner.value; } - // set { inner.value = newValue; } - // } - // } - // - // To do so, we need to grab the witness table of FooImpl:IFoo, and create - // wrapper property in Foo that forwards the accessors to the inner object. - // - // We get started by constructing a synthesized `PropertyDecl`. - // - auto synPropertyDecl = m_astBuilder->create<PropertyDecl>(); - synPropertyDecl->parentDecl = context->parentDecl; - - // Synthesize the property name with a prefix to avoid name clashing. - // - synPropertyDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; - synPropertyDecl->nameAndLoc.name = - getName(String("$syn_property_") + getText(requiredMemberDeclRef.getName())); - - // Find the witness that FooImpl : IFoo. - auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl); - auto innerType = aggTypeDecl->wrappedType.type; - DeclRef<Decl> innerProperty; - auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); - if (!innerWitness) - return false; - - for (auto requiredAccessorDeclRef : - getMembersOfType<AccessorDecl>(m_astBuilder, requiredMemberDeclRef)) - { - auto innerEntry = tryLookUpRequirementWitness( - m_astBuilder, - innerWitness, - requiredAccessorDeclRef.getDecl()); - if (innerEntry.getFlavor() != RequirementWitness::Flavor::declRef) - return false; - auto innerAccessorDeclRef = as<AccessorDecl>(innerEntry.getDeclRef()); - if (!innerAccessorDeclRef) - return false; - - // The synthesized accessor will be an AST node of the same class as - // the required accessor. - // - auto synAccessorDecl = (AccessorDecl*)m_astBuilder->createByNodeType( - requiredAccessorDeclRef.getDecl()->astNodeType); - synAccessorDecl->ownedScope = m_astBuilder->create<Scope>(); - synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; - synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); - - // The return type should be the same as the inner object's accessor return type. - // - synAccessorDecl->returnType.type = getResultType(m_astBuilder, innerAccessorDeclRef); - - // Similarly, our synthesized accessor will have parameters matching those of the inner - // accessor. - // - List<Expr*> synArgs; - for (auto innerParamDeclRef : getParameters(m_astBuilder, innerAccessorDeclRef)) - { - auto paramType = getType(m_astBuilder, innerParamDeclRef); - - // The synthesized parameter will ahve the same name and - // type as the parameter of the requirement. - // - auto synParamDecl = m_astBuilder->create<ParamDecl>(); - synParamDecl->nameAndLoc = innerParamDeclRef.getDecl()->nameAndLoc; - synParamDecl->type.type = paramType; - - // We need to add the parameter as a child declaration of - // the accessor we are building. - // - synAccessorDecl->addMember(synParamDecl); - - // For each paramter, we will create an argument expression - // to represent it in the body of the accessor. - // - auto synArg = m_astBuilder->create<VarExpr>(); - synArg->declRef = makeDeclRef(synParamDecl); - synArg->type = paramType; - synArgs.add(synArg); - } - - // Now synthesize the body of the property accessor. - // The body of the accessor will depend on the class of the accessor - // we are synthesizing (e.g., `get` vs. `set`). - // - Stmt* synBodyStmt = nullptr; - auto propertyRef = m_astBuilder->create<MemberExpr>(); - propertyRef->scope = synAccessorDecl->ownedScope; - auto base = m_astBuilder->create<VarExpr>(); - base->scope = propertyRef->scope; - base->name = getName("inner"); - propertyRef->baseExpression = base; - innerProperty = innerAccessorDeclRef.getParent(); - propertyRef->name = requiredMemberDeclRef.getName(); - auto checkedPropertyRefExpr = CheckExpr(propertyRef); - - if (as<GetterDecl>(requiredAccessorDeclRef)) - { - auto synReturn = m_astBuilder->create<ReturnStmt>(); - synReturn->expression = checkedPropertyRefExpr; - - synBodyStmt = synReturn; - } - else if (as<SetterDecl>(requiredAccessorDeclRef)) - { - auto synAssign = m_astBuilder->create<AssignExpr>(); - synAssign->left = checkedPropertyRefExpr; - synAssign->right = synArgs[0]; - - auto synCheckedAssign = checkAssignWithCheckedOperands(synAssign); - - auto synExprStmt = m_astBuilder->create<ExpressionStmt>(); - synExprStmt->expression = synCheckedAssign; - - synBodyStmt = synExprStmt; - } - else - { - // While there are other kinds of accessors than `get` and `set`, - // those are currently only reserved for the internal use in the core module. - // We will not bother with synthesis for those cases. - // - return false; - } - - addModifier(synAccessorDecl, m_astBuilder->create<ForceInlineAttribute>()); - synAccessorDecl->body = synBodyStmt; - - synPropertyDecl->addMember(synAccessorDecl); - - // Register the synthesized accessor. - // - witnessTable->add( - requiredAccessorDeclRef.getDecl(), - RequirementWitness(makeDeclRef(synAccessorDecl))); - } - - // The type of our synthesized property will be the same as the inner property. - // - auto propertyType = getType(m_astBuilder, as<PropertyDecl>(innerProperty)); - synPropertyDecl->type.type = propertyType; - - // The visibility of synthesized decl should be the same as the inner requirement - if (innerProperty.getDecl()->findModifier<VisibilityModifier>()) - { - auto vis = getDeclVisibility(innerProperty.getDecl()); - addVisibilityModifier(synPropertyDecl, vis); - } - - context->parentDecl->addMember(synPropertyDecl); - witnessTable->add( - requiredMemberDeclRef.getDecl(), - RequirementWitness(makeDeclRef(synPropertyDecl))); - return true; -} - -bool SemanticsVisitor::trySynthesizeAssociatedTypeRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& inLookupResult, - DeclRef<AssocTypeDecl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable) -{ - SLANG_UNUSED(inLookupResult); - - // The only case we can synthesize for now is when the conformant type - // is a wrapper type. - if (!isWrapperTypeDecl(context->parentDecl)) - return false; - auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl); - auto lookupResult = lookUpMember( - m_astBuilder, - this, - requiredMemberDeclRef.getName(), - aggTypeDecl->wrappedType.type, - aggTypeDecl->ownedScope, - LookupMask::Default, - LookupOptions::IgnoreBaseInterfaces); - if (!lookupResult.isValid() || lookupResult.isOverloaded()) - return false; - auto assocType = DeclRefType::create(m_astBuilder, lookupResult.item.declRef); - witnessTable->add(requiredMemberDeclRef.getDecl(), assocType); - for (auto typeConstraintDecl : - getMembersOfType<TypeConstraintDecl>(m_astBuilder, requiredMemberDeclRef)) - { - auto witness = tryGetSubtypeWitness(assocType, getSup(m_astBuilder, typeConstraintDecl)); - if (!witness) - return false; - witnessTable->add(typeConstraintDecl.getDecl(), witness); - } - return true; -} - -bool SemanticsVisitor::trySynthesizeAssociatedConstantRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& inLookupResult, - DeclRef<VarDeclBase> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable) -{ - SLANG_UNUSED(inLookupResult); - - // The only case we can synthesize for now is when the conformant type - // is a wrapper type, i.e. - // struct Foo:IFoo = FooImpl; - if (!isWrapperTypeDecl(context->parentDecl)) - return false; - - // Find the witness that FooImpl : IFoo. - auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl); - auto innerType = aggTypeDecl->wrappedType.type; - DeclRef<Decl> innerProperty; - auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); - if (!innerWitness) - return false; - - auto witness = - tryLookUpRequirementWitness(m_astBuilder, innerWitness, requiredMemberDeclRef.getDecl()); - if (witness.getFlavor() != RequirementWitness::Flavor::val) - return false; - witnessTable->add(requiredMemberDeclRef.getDecl(), witness.getVal()); - return true; -} - bool SemanticsVisitor::synthesizeAccessorRequirements( ConformanceCheckingContext* context, DeclRef<ContainerDecl> requiredMemberDeclRef, @@ -6609,92 +6264,12 @@ bool SemanticsVisitor::synthesizeAccessorRequirements( return true; } -bool SemanticsVisitor::trySynthesizeWrapperTypeSubscriptRequirementWitness( - ConformanceCheckingContext* context, - DeclRef<SubscriptDecl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable) -{ - // We are synthesizing the subscript requirement for a wrapper type: - // struct Wrapper - // { - // Inner inner; - // subscript(int index)->int { get { return inner[index]; } - // set { inner[index] = newValue; } - // } - // } - // - // // Find the witness that FooImpl : IFoo. - auto aggTypeDecl = as<AggTypeDecl>(context->parentDecl); - auto innerType = aggTypeDecl->wrappedType.type; - DeclRef<Decl> innerProperty; - auto innerWitness = tryGetSubtypeWitness(innerType, witnessTable->baseType); - if (!innerWitness) - return false; - // - List<Expr*> synArgs; - ThisExpr* synThis; - auto synSubscriptDecl = synthesizeMethodSignatureForRequirementWitness( - context, - requiredMemberDeclRef, - synArgs, - synThis); - auto declType = getType(m_astBuilder, getDefaultDeclRef(synSubscriptDecl).as<SubscriptDecl>()); - synThis->checked = true; - - // Form a `this[args...]` expression that we will use to coerce from - // in the synthesized subscript accessors. - // - DiagnosticSink tempSink(getSourceManager(), nullptr); - SemanticsVisitor subVisitor(withSink(&tempSink)); - auto base = m_astBuilder->create<VarExpr>(); - base->scope = synThis->scope; - base->name = getName("inner"); - - IndexExpr* indexExpr = m_astBuilder->create<IndexExpr>(); - indexExpr->baseExpression = base; - indexExpr->indexExprs = _Move(synArgs); - auto synBaseStorageExpr = subVisitor.CheckTerm(indexExpr); - - if (tempSink.getErrorCount() != 0) - return false; - - // Our synthesized subscript will have an accessor declaration for - // each accessor of the requirement. - // - bool canSynAccessors = synthesizeAccessorRequirements( - context, - requiredMemberDeclRef, - declType, - synBaseStorageExpr, - synSubscriptDecl, - witnessTable); - if (!canSynAccessors) - return false; - - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requiredMemberDeclRef.getDecl()->findModifier<VisibilityModifier>()) - { - auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(synSubscriptDecl, visibility); - } - - return true; -} - bool SemanticsVisitor::trySynthesizeSubscriptRequirementWitness( ConformanceCheckingContext* context, const LookupResult& lookupResult, DeclRef<SubscriptDecl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable) { - if (isWrapperTypeDecl(context->parentDecl)) - return trySynthesizeWrapperTypeSubscriptRequirementWitness( - context, - requiredMemberDeclRef, - witnessTable); - // The situation here is that the context of an inheritance // declaration didn't provide an exact match for a required // subscript. E.g.: @@ -6926,21 +6501,13 @@ bool SemanticsVisitor::trySynthesizeRequirementWitness( } else { - return trySynthesizeAssociatedTypeRequirementWitness( - context, - lookupResult, - requiredAssocTypeDeclRef, - witnessTable); + return false; } } if (auto requiredConstantDeclRef = requiredMemberDeclRef.as<VarDeclBase>()) { - return trySynthesizeAssociatedConstantRequirementWitness( - context, - lookupResult, - requiredConstantDeclRef, - witnessTable); + return false; } if (auto requiredCtor = requiredMemberDeclRef.as<ConstructorDecl>()) @@ -7543,54 +7110,51 @@ bool SemanticsVisitor::findWitnessForInterfaceRequirement( // lookup results that might be usable, but not as-is. // LookupResult lookupResult; - if (!isWrapperTypeDecl(context->parentDecl)) + lookupResult = lookUpMember( + m_astBuilder, + this, + name, + subType, + nullptr, + LookupMask::Default, + LookupOptions::IgnoreBaseInterfaces); + + if (!lookupResult.isValid()) { - lookupResult = lookUpMember( - m_astBuilder, - this, - name, - subType, - nullptr, - LookupMask::Default, - LookupOptions::IgnoreBaseInterfaces); + // If we failed to look up a member with the name of the + // requirement, it may be possible that we can still synthesis the + // implementation if this is one of the known builtin requirements, + // or if the interface method contains a default impl. + // Otherwise, report diagnostic now. - if (!lookupResult.isValid()) + if (requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>() || + (requiredMemberDeclRef.as<GenericDecl>() && + getInner(requiredMemberDeclRef.as<GenericDecl>()) + ->hasModifier<BuiltinRequirementModifier>())) + { + } + else if ( + requiredMemberDeclRef.as<SubscriptDecl>() && + (as<ArrayExpressionType>(context->conformingType) || + as<VectorExpressionType>(context->conformingType) || + as<MatrixExpressionType>(context->conformingType))) { - // If we failed to look up a member with the name of the - // requirement, it may be possible that we can still synthesis the - // implementation if this is one of the known builtin requirements, - // or if the interface method contains a default impl. - // Otherwise, report diagnostic now. - - if (requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>() || - (requiredMemberDeclRef.as<GenericDecl>() && - getInner(requiredMemberDeclRef.as<GenericDecl>()) - ->hasModifier<BuiltinRequirementModifier>())) - { - } - else if ( - requiredMemberDeclRef.as<SubscriptDecl>() && - (as<ArrayExpressionType>(context->conformingType) || - as<VectorExpressionType>(context->conformingType) || - as<MatrixExpressionType>(context->conformingType))) - { - } - else if (hasDefaultImpl(requiredMemberDeclRef)) - { - } - else - { - getSink()->diagnose( - inheritanceDecl, - Diagnostics::typeDoesntImplementInterfaceRequirement, - subType, - requiredMemberDeclRef); - getSink()->diagnose( - requiredMemberDeclRef, - Diagnostics::seeDeclarationOf, - requiredMemberDeclRef); - return false; - } + } + else if (hasDefaultImpl(requiredMemberDeclRef)) + { + } + else + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::typeDoesntImplementInterfaceRequirement, + subType, + requiredMemberDeclRef); + getSink()->diagnose( + requiredMemberDeclRef, + Diagnostics::seeDeclarationOf, + requiredMemberDeclRef); + return false; } } if (lookupResult.isOverloaded()) @@ -7649,10 +7213,6 @@ bool SemanticsVisitor::findWitnessForInterfaceRequirement( // code required to handle all the conversions that might be // required on `this`. // - // Another situation that will get us here is that we are dealing with - // a wrapper type (struct Foo:IFoo=FooImpl), and we will synthesize - // wrappers that redirects the call into the inner element. - // MethodWitnessSynthesisFailureDetails failureDetails = {}; if (trySynthesizeRequirementWitness( context, @@ -8056,6 +7616,32 @@ bool SemanticsVisitor::checkConformance( // code to work. return true; } + + // If sub type is a link-time-resolved wrapper type (e.g. `extern struct Foo : IFoo = + // FooImpl;`), fill in `inheritanceDecl->witnessVal` with a Witness val that shows `FooImpl + // : IFoo`. + auto aggTypeDecl = as<AggTypeDecl>(declRef.getDecl()); + + if (aggTypeDecl && aggTypeDecl->aliasedType) + { + auto witness = tryGetSubtypeWitness(aggTypeDecl->aliasedType, superType); + if (witness) + { + inheritanceDecl->witnessVal = witness; + } + else + { + if (!as<ErrorType>(aggTypeDecl->aliasedType)) + { + getSink()->diagnose( + inheritanceDecl, + Diagnostics::typeArgumentDoesNotConformToInterface, + aggTypeDecl->aliasedType, + superType); + } + } + return witness != nullptr; + } } // Look at the type being inherited from, and validate diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 511834cef..b8e68e28b 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -290,6 +290,13 @@ void addSiblingScopeForContainerDecl(ASTBuilder* builder, Scope* destScope, Cont destScope->nextSibling = subScope; } +ContainerDecl* isStaticScopeDecl(Decl* decl) +{ + if (as<NamespaceDeclBase>(decl) || as<FileDecl>(decl)) + return as<ContainerDecl>(decl); + return nullptr; +} + void SemanticsVisitor::diagnoseDeprecatedDeclRefUsage( DeclRef<Decl> declRef, SourceLoc loc, diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index d5d9c2372..599eed12c 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2009,33 +2009,12 @@ public: DeclRef<PropertyDecl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable); - bool trySynthesizeWrapperTypePropertyRequirementWitness( - ConformanceCheckingContext* context, - DeclRef<PropertyDecl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable); - bool trySynthesizeSubscriptRequirementWitness( ConformanceCheckingContext* context, const LookupResult& lookupResult, DeclRef<SubscriptDecl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable); - bool trySynthesizeWrapperTypeSubscriptRequirementWitness( - ConformanceCheckingContext* context, - DeclRef<SubscriptDecl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable); - - bool trySynthesizeAssociatedTypeRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef<AssocTypeDecl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable); - - bool trySynthesizeAssociatedConstantRequirementWitness( - ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef<VarDeclBase> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable); /// Attempt to synthesize a declartion that can satisfy `requiredMemberDeclRef` using /// `lookupResult`. diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 4b56cc52f..7032f2016 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -677,4 +677,5 @@ return { ["Decoration.DisableCopyEliminationDecoration"] = 673, ["Decoration.TempCallArgImmutableVar"] = 674, ["CastResourceToDescriptorHandle"] = 675, + ["SymbolAlias"] = 676, } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 7b27f0b56..42db9cb44 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3903,6 +3903,8 @@ public: return (IRMetalMeshType*)getType(kIROp_MetalMeshType, 5, ops); } + IRInst* emitSymbolAlias(IRInst* aliasedSymbol); + IRInst* emitDebugSource( UnownedStringSlice fileName, UnownedStringSlice source, diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 8b8515424..4093db8fa 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -479,6 +479,12 @@ local insts = { module = { struct_name = "ModuleInst", parent = true }, }, { block = { parent = true } }, + + -- A global inst representing an alias of another symbol, under a different mangled name. + -- This inst should be completely eliminated after linking, with its references replaced + -- to use the canonical symbol being aliased. + { SymbolAlias = { min_operands = 1 } }, + -- IRConstant { Constant = { diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index c46c57043..8ba1f354d 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -302,6 +302,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) case kIROp_GlobalGenericParam: case kIROp_WitnessTable: case kIROp_InterfaceType: + case kIROp_SymbolAlias: return cloneGlobalValue(this, originalValue); case kIROp_BoolLit: @@ -346,7 +347,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) return builder->getVoidValue(); } break; - default: { // In the default case, assume that we have some sort of "hoistable" @@ -1411,7 +1411,10 @@ IRInst* cloneInst( builder, cast<IRGlobalGenericParam>(originalInst), originalValues); - + case kIROp_SymbolAlias: + // If we encounter a symbol alias, we want to clone + // the value it refers to instead of the alias itself. + return context->maybeCloneValue(cast<IRSymbolAlias>(originalInst)->getOperand(0)); default: break; } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index d114a9a40..ebaebcc8a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3394,6 +3394,11 @@ IRInst* IRBuilder::emitOutImplicitCast(IRInst* type, IRInst* value) { return emitIntrinsicInst((IRType*)type, kIROp_OutImplicitCast, 1, &value); } +IRInst* IRBuilder::emitSymbolAlias(IRInst* aliasedSymbol) +{ + return emitIntrinsicInst(aliasedSymbol->getFullType(), kIROp_SymbolAlias, 1, &aliasedSymbol); +} + IRInst* IRBuilder::emitDebugSource( UnownedStringSlice fileName, UnownedStringSlice source, diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index cfb8e4d61..0c3f5f7c1 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -855,7 +855,7 @@ bool _findAstNodeImpl(ASTLookupContext& context, SyntaxNode* node) if (auto aggTypeDecl = as<AggTypeDecl>(container)) { ASTLookupExprVisitor visitor(&context); - if (visitor.dispatchIfNotNull(aggTypeDecl->wrappedType.exp)) + if (visitor.dispatchIfNotNull(aggTypeDecl->aliasedType.exp)) return true; } } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 12545ad0d..d3f7d4e4c 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8593,12 +8593,48 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl) { + // If the inheritance decl is nested inside a link-time type alias declaration, + // e.g. in `export struct Foo:IFoo = FooImpl`, + // then we need to emit a symbo alias to the `FooImpl:IFoo`. + // + auto parentDecl = inheritanceDecl->parentDecl; + auto aggTypeParentDecl = as<AggTypeDecl>(parentDecl); + if (aggTypeParentDecl && aggTypeParentDecl->aliasedType.type && inheritanceDecl->witnessVal) + { + NestedContext nested(this); + auto subBuilder = nested.getBuilder(); + auto subContext = nested.getContext(); + auto outerGeneric = emitOuterGenerics(subContext, inheritanceDecl, inheritanceDecl); + + auto wrappedWitness = lowerVal(subContext, inheritanceDecl->witnessVal); + IRInst* alias = nullptr; + if (outerGeneric) + { + alias = finishOuterGenerics(subBuilder, wrappedWitness.val, outerGeneric); + } + else + { + alias = getBuilder()->emitSymbolAlias(wrappedWitness.val); + } + auto mangledName = getMangledNameForConformanceWitness( + context->astBuilder, + parentDecl, + inheritanceDecl->base.type); + bool explicitExtern = false; + if (isImportedDecl(context, parentDecl, explicitExtern)) + getBuilder()->addImportDecoration(alias, mangledName.getUnownedSlice()); + else + getBuilder()->addExportDecoration(alias, mangledName.getUnownedSlice()); + + context->setGlobalValue(inheritanceDecl, LoweredValInfo::simple(alias)); + return LoweredValInfo::simple(alias); + } + // An inheritance clause inside of an `interface` // declaration should not give rise to a witness // table, because it represents something the // interface requires, and not what it provides. // - auto parentDecl = inheritanceDecl->parentDecl; if (const auto parentInterfaceDecl = as<InterfaceDecl>(parentDecl)) { return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); @@ -9724,10 +9760,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> SLANG_UNREACHABLE("associatedtype should have been handled by visitAssocTypeDecl."); } - // TODO(JS): - // Not clear what to do around HLSLExportModifier. - // The HLSL spec says it only applies to functions, so we ignore for now. - // We are going to create nested IR building state // to use when emitting the members of the type. // @@ -9738,6 +9770,35 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Emit any generics that should wrap the actual type. auto outerGeneric = emitOuterGenerics(subContext, decl, decl); + if (decl->aliasedType) + { + // If the type decl is an alias of another type, then we lower it into + // a IRSymbolAlias. + auto loweredType = lowerType(subContext, decl->aliasedType); + if (loweredType) + { + IRInst* alias = nullptr; + if (outerGeneric) + { + alias = finishOuterGenerics(subBuilder, loweredType, outerGeneric); + } + else + { + alias = subBuilder->emitSymbolAlias(loweredType); + } + addLinkageDecoration(subContext, alias, decl); + + // Enumerate all witnesses and lower IRSymbolAlias for them as well. + for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>()) + { + if (!inheritanceDecl->witnessVal) + continue; + ensureDecl(subContext, inheritanceDecl); + } + return LoweredValInfo::simple(alias); + } + } + IRType* irAggType = nullptr; if (as<StructDecl>(decl)) { diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 605e50d41..3c4c23a42 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2205,386 +2205,413 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( varLayout->flags |= VarLayoutFlag::HasSemantic; } - // Scalar and vector types are treated as outputs directly - if (auto basicType = as<BasicExpressionType>(type)) + // We use a lambda here to process the parameter based on `type`. + // We need to be able to recurse on the lambda if need to translate/resolve + // `type` to something else, in that case we simply call the lambda recursively. + auto processParamOfType = [&](auto&& processParamOfTypeFunc, Type* type) -> RefPtr<TypeLayout> { - return processSimpleEntryPointParameter(context, basicType, state, varLayout); - } - else if (auto vectorType = as<VectorExpressionType>(type)) - { - return processSimpleEntryPointParameter(context, vectorType, state, varLayout); - } - // A matrix is processed as if it was an array of rows - else if (auto matrixType = as<MatrixExpressionType>(type)) - { - auto foldedRowCountVal = - context->getTargetProgram()->getProgram()->tryFoldIntVal(matrixType->getRowCount()); - IntegerLiteralValue rowCount = 0; - if (!foldedRowCountVal) + // Scalar and vector types are treated as outputs directly + if (auto basicType = as<BasicExpressionType>(type)) { - rowCount = getIntVal(foldedRowCountVal); + return processSimpleEntryPointParameter(context, basicType, state, varLayout); } - return processSimpleEntryPointParameter( - context, - matrixType, - state, - varLayout, - (int)rowCount); - } - else if (auto arrayType = as<ArrayExpressionType>(type)) - { - // Note: Bad Things will happen if we have an array input - // without a semantic already being enforced. - UInt elementCount = 0; - - if (!arrayType->isUnsized()) + else if (auto vectorType = as<VectorExpressionType>(type)) { - auto intVal = context->getTargetProgram()->getProgram()->tryFoldIntVal( - arrayType->getElementCount()); - if (intVal) - elementCount = (UInt)getIntVal(intVal); + return processSimpleEntryPointParameter(context, vectorType, state, varLayout); } - - // We use the first element to derive the layout for the element type - auto elementTypeLayout = processEntryPointVaryingParameter( - context, - arrayType->getElementType(), - state, - varLayout); - - // We still walk over subsequent elements to make sure they consume resources - // as needed - for (UInt ii = 1; ii < elementCount; ++ii) + // A matrix is processed as if it was an array of rows + else if (auto matrixType = as<MatrixExpressionType>(type)) { - processEntryPointVaryingParameter(context, arrayType->getElementType(), state, nullptr); + auto foldedRowCountVal = + context->getTargetProgram()->getProgram()->tryFoldIntVal(matrixType->getRowCount()); + IntegerLiteralValue rowCount = 0; + if (!foldedRowCountVal) + { + rowCount = getIntVal(foldedRowCountVal); + } + return processSimpleEntryPointParameter( + context, + matrixType, + state, + varLayout, + (int)rowCount); } - - RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); - arrayTypeLayout->elementTypeLayout = elementTypeLayout; - arrayTypeLayout->type = arrayType; - - for (auto rr : elementTypeLayout->resourceInfos) + else if (auto arrayType = as<ArrayExpressionType>(type)) { - arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count * elementCount; - } - - return arrayTypeLayout; - } - else if (auto meshOutputType = as<MeshOutputType>(type)) - { - // TODO: Ellie, revisit - // Note: Bad Things will happen if we have an array input - // without a semantic already being enforced. + // Note: Bad Things will happen if we have an array input + // without a semantic already being enforced. + UInt elementCount = 0; - // We use the first element to derive the layout for the element type - auto elementTypeLayout = processEntryPointVaryingParameter( - context, - meshOutputType->getElementType(), - state, - varLayout); + if (!arrayType->isUnsized()) + { + auto intVal = context->getTargetProgram()->getProgram()->tryFoldIntVal( + arrayType->getElementCount()); + if (intVal) + elementCount = (UInt)getIntVal(intVal); + } - RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); - arrayTypeLayout->elementTypeLayout = elementTypeLayout; - arrayTypeLayout->type = arrayType; + // We use the first element to derive the layout for the element type + auto elementTypeLayout = processEntryPointVaryingParameter( + context, + arrayType->getElementType(), + state, + varLayout); - // TODO: Ellie, this is probably not the right place to handle this - // On GLSL the indices type is built in and as such doesn't consume - // resources. - if (!isKhronosTarget(context->getTargetRequest()) || !as<IndicesType>(type)) - { - for (auto rr : elementTypeLayout->resourceInfos) + // We still walk over subsequent elements to make sure they consume resources + // as needed + for (UInt ii = 1; ii < elementCount; ++ii) { - // TODO: Ellie, explain why only one slot is consumed here - arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count; + processEntryPointVaryingParameter( + context, + arrayType->getElementType(), + state, + nullptr); } - } - return arrayTypeLayout; - } - else if (auto patchType = as<HLSLPatchType>(type)) - { - // Similar to the MeshOutput case, a `InputPatch` or `OutputPatch` type is just like an - // array. - // - auto elementTypeLayout = processEntryPointVaryingParameter( - context, - patchType->getElementType(), - state, - varLayout); + RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); + arrayTypeLayout->elementTypeLayout = elementTypeLayout; + arrayTypeLayout->type = arrayType; - RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); - arrayTypeLayout->elementTypeLayout = elementTypeLayout; - arrayTypeLayout->type = arrayType; + for (auto rr : elementTypeLayout->resourceInfos) + { + arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count * elementCount; + } - for (auto rr : elementTypeLayout->resourceInfos) - { - arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count; + return arrayTypeLayout; } - - return arrayTypeLayout; - } - // Ignore a bunch of types that don't make sense here... - else if (const auto subpassType = as<SubpassInputType>(type)) - { - return nullptr; - } - else if (const auto textureType = as<TextureType>(type)) - { - return nullptr; - } - else if (const auto samplerStateType = as<SamplerStateType>(type)) - { - return nullptr; - } - else if (const auto constantBufferType = as<ConstantBufferType>(type)) - { - return nullptr; - } - else if (auto ptrType = as<PtrType>(type)) - { - SLANG_ASSERT(ptrType->astNodeType == ASTNodeType::PtrType); - - auto typeLayout = processSimpleEntryPointParameter(context, ptrType, state, varLayout); - RefPtr<PointerTypeLayout> ptrTypeLayout = typeLayout.as<PointerTypeLayout>(); - - // Work out the layout for the value/target type - auto valueTypeLayout = - processEntryPointVaryingParameter(context, ptrType->getValueType(), state, varLayout); - ptrTypeLayout->valueTypeLayout = valueTypeLayout; - return ptrTypeLayout; - } - else if (auto optionalType = as<OptionalType>(type)) - { - Array<Type*, 2> types = - makeArray(optionalType->getValueType(), context->getASTBuilder()->getBoolType()); - auto tupleType = context->getASTBuilder()->getTupleType(types.getView()); - return processEntryPointVaryingParameter(context, tupleType, state, varLayout); - } - else if (auto tupleType = as<TupleType>(type)) - { - RefPtr<StructTypeLayout> structLayout = new StructTypeLayout(); - structLayout->type = type; - for (Index i = 0; i < tupleType->getMemberCount(); i++) + else if (auto meshOutputType = as<MeshOutputType>(type)) { - auto fieldType = tupleType->getMember(i); - RefPtr<VarLayout> fieldVarLayout = new VarLayout(); - - // We don't really have a "field" decl, so just use the tuple-typed decl - // itself as the varDecl of the elements. - auto fieldDecl = (VarDeclBase*)varLayout->varDecl.getDecl(); - fieldVarLayout->varDecl = fieldDecl; + // TODO: Ellie, revisit + // Note: Bad Things will happen if we have an array input + // without a semantic already being enforced. - structLayout->fields.add(fieldVarLayout); - - auto fieldTypeLayout = processEntryPointVaryingParameterDecl( + // We use the first element to derive the layout for the element type + auto elementTypeLayout = processEntryPointVaryingParameter( context, - fieldDecl, - fieldType, + meshOutputType->getElementType(), state, - fieldVarLayout); + varLayout); + + RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); + arrayTypeLayout->elementTypeLayout = elementTypeLayout; + arrayTypeLayout->type = arrayType; - if (!fieldTypeLayout) + // TODO: Ellie, this is probably not the right place to handle this + // On GLSL the indices type is built in and as such doesn't consume + // resources. + if (!isKhronosTarget(context->getTargetRequest()) || !as<IndicesType>(type)) { - getSink(context)->diagnose( - varLayout->varDecl, - Diagnostics::notValidVaryingParameter, - fieldType); - continue; + for (auto rr : elementTypeLayout->resourceInfos) + { + // TODO: Ellie, explain why only one slot is consumed here + arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count; + } } - fieldVarLayout->typeLayout = fieldTypeLayout; - // Assign offsets in var layout for each resource kind of the type. - for (auto fieldTypeResInfo : fieldTypeLayout->resourceInfos) + return arrayTypeLayout; + } + else if (auto patchType = as<HLSLPatchType>(type)) + { + // Similar to the MeshOutput case, a `InputPatch` or `OutputPatch` type is just like an + // array. + // + auto elementTypeLayout = processEntryPointVaryingParameter( + context, + patchType->getElementType(), + state, + varLayout); + + RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); + arrayTypeLayout->elementTypeLayout = elementTypeLayout; + arrayTypeLayout->type = arrayType; + + for (auto rr : elementTypeLayout->resourceInfos) { - auto kind = fieldTypeResInfo.kind; - auto structTypeResInfo = structLayout->findOrAddResourceInfo(kind); - auto fieldResInfo = fieldVarLayout->findOrAddResourceInfo(kind); - fieldResInfo->index = structTypeResInfo->count.getFiniteValue(); - structTypeResInfo->count += fieldTypeResInfo.count; + arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count; } - } - return structLayout; - } - // Catch declaration-reference types late in the sequence, since - // otherwise they will include all of the above cases... - else if (auto declRefType = as<DeclRefType>(type)) - { - // If we are trying to get the layout of some extern type, do our best - // to look it up in other loaded modules and generate the type layout - // based on that. - declRefType = context->layoutContext.lookupExternDeclRefType(declRefType); - auto declRef = declRefType->getDeclRef(); + return arrayTypeLayout; + } + // Ignore a bunch of types that don't make sense here... + else if (const auto subpassType = as<SubpassInputType>(type)) + { + return nullptr; + } + else if (const auto textureType = as<TextureType>(type)) + { + return nullptr; + } + else if (const auto samplerStateType = as<SamplerStateType>(type)) + { + return nullptr; + } + else if (const auto constantBufferType = as<ConstantBufferType>(type)) + { + return nullptr; + } + else if (auto ptrType = as<PtrType>(type)) + { + SLANG_ASSERT(ptrType->astNodeType == ASTNodeType::PtrType); + auto typeLayout = processSimpleEntryPointParameter(context, ptrType, state, varLayout); + RefPtr<PointerTypeLayout> ptrTypeLayout = typeLayout.as<PointerTypeLayout>(); - if (auto structDeclRef = declRef.as<StructDecl>()) + // Work out the layout for the value/target type + auto valueTypeLayout = processEntryPointVaryingParameter( + context, + ptrType->getValueType(), + state, + varLayout); + ptrTypeLayout->valueTypeLayout = valueTypeLayout; + return ptrTypeLayout; + } + else if (auto optionalType = as<OptionalType>(type)) + { + Array<Type*, 2> types = + makeArray(optionalType->getValueType(), context->getASTBuilder()->getBoolType()); + auto tupleType = context->getASTBuilder()->getTupleType(types.getView()); + return processEntryPointVaryingParameter(context, tupleType, state, varLayout); + } + else if (auto tupleType = as<TupleType>(type)) { RefPtr<StructTypeLayout> structLayout = new StructTypeLayout(); - structLayout->type = declRefType; - - // We will recursively walk the fields of a `struct` type - // to compute layouts for those fields. - // - // Along the way, we may find fields with explicit layout - // annotations, along with fields that have no explicit - // layout. We will consider it an error to have a mix of - // the two. - // - // TODO: We could support a mix of implicit and explicit - // layout by performing layout on fields in two passes, - // much like is done for the global scope. This would - // complicate layout significantly for little practical - // benefit, so it is very much a "nice to have" rather - // than a "must have" feature. - // - Decl* firstExplicit = nullptr; - Decl* firstImplicit = nullptr; - for (auto field : - getFields(context->getASTBuilder(), structDeclRef, MemberFilterStyle::Instance)) + structLayout->type = type; + for (Index i = 0; i < tupleType->getMemberCount(); i++) { + auto fieldType = tupleType->getMember(i); RefPtr<VarLayout> fieldVarLayout = new VarLayout(); - fieldVarLayout->varDecl = field; + + // We don't really have a "field" decl, so just use the tuple-typed decl + // itself as the varDecl of the elements. + auto fieldDecl = (VarDeclBase*)varLayout->varDecl.getDecl(); + fieldVarLayout->varDecl = fieldDecl; structLayout->fields.add(fieldVarLayout); - structLayout->mapVarToLayout.add(field.getDecl(), fieldVarLayout); auto fieldTypeLayout = processEntryPointVaryingParameterDecl( context, - field.getDecl(), - getType(context->getASTBuilder(), field), + fieldDecl, + fieldType, state, fieldVarLayout); if (!fieldTypeLayout) { - getSink(context)->diagnose(field, Diagnostics::notValidVaryingParameter, field); + getSink(context)->diagnose( + varLayout->varDecl, + Diagnostics::notValidVaryingParameter, + fieldType); continue; } fieldVarLayout->typeLayout = fieldTypeLayout; - // The field needs to have offset information stored - // in `fieldVarLayout` for every kind of resource - // consumed by `fieldTypeLayout`. - // + // Assign offsets in var layout for each resource kind of the type. for (auto fieldTypeResInfo : fieldTypeLayout->resourceInfos) { - // If the field is a Conditional<T, false> type, then it could have 0 size. - // We should skip this field if it has no use of layout units. - if (fieldTypeResInfo.count == 0) - continue; - auto kind = fieldTypeResInfo.kind; - auto structTypeResInfo = structLayout->findOrAddResourceInfo(kind); + auto fieldResInfo = fieldVarLayout->findOrAddResourceInfo(kind); + fieldResInfo->index = structTypeResInfo->count.getFiniteValue(); + structTypeResInfo->count += fieldTypeResInfo.count; + } + } + return structLayout; + } + // Catch declaration-reference types late in the sequence, since + // otherwise they will include all of the above cases... + else if (auto declRefType = as<DeclRefType>(type)) + { + // If we are trying to get the layout of some extern type, do our best + // to look it up in other loaded modules and generate the type layout + // based on that. + auto lookedUpType = context->layoutContext.lookupExternDeclRefType(declRefType); + + // If the link-time type resolved to something concrete, process the param as if it is + // of the concrete type by recursively calling this lambda. + if (type != lookedUpType) + return processParamOfTypeFunc(_Move(processParamOfTypeFunc), lookedUpType); + + auto declRef = declRefType->getDeclRef(); + + if (auto structDeclRef = declRef.as<StructDecl>()) + { + RefPtr<StructTypeLayout> structLayout = new StructTypeLayout(); + structLayout->type = declRefType; + + // We will recursively walk the fields of a `struct` type + // to compute layouts for those fields. + // + // Along the way, we may find fields with explicit layout + // annotations, along with fields that have no explicit + // layout. We will consider it an error to have a mix of + // the two. + // + // TODO: We could support a mix of implicit and explicit + // layout by performing layout on fields in two passes, + // much like is done for the global scope. This would + // complicate layout significantly for little practical + // benefit, so it is very much a "nice to have" rather + // than a "must have" feature. + // + Decl* firstExplicit = nullptr; + Decl* firstImplicit = nullptr; + for (auto field : getFields( + context->getASTBuilder(), + structDeclRef, + MemberFilterStyle::Instance)) + { + RefPtr<VarLayout> fieldVarLayout = new VarLayout(); + fieldVarLayout->varDecl = field; + + structLayout->fields.add(fieldVarLayout); + structLayout->mapVarToLayout.add(field.getDecl(), fieldVarLayout); + + auto fieldTypeLayout = processEntryPointVaryingParameterDecl( + context, + field.getDecl(), + getType(context->getASTBuilder(), field), + state, + fieldVarLayout); - auto fieldResInfo = fieldVarLayout->FindResourceInfo(kind); - if (!fieldResInfo) + if (!fieldTypeLayout) { - if (!firstImplicit) - firstImplicit = field.getDecl(); - - // In the implicit-layout case, we assign the field - // the next available offset after the fields that - // have preceded it. - // - fieldResInfo = fieldVarLayout->findOrAddResourceInfo(kind); - fieldResInfo->index = structTypeResInfo->count.getFiniteValue(); - structTypeResInfo->count += fieldTypeResInfo.count; + getSink(context)->diagnose( + field, + Diagnostics::notValidVaryingParameter, + field); + continue; } - else + fieldVarLayout->typeLayout = fieldTypeLayout; + + // The field needs to have offset information stored + // in `fieldVarLayout` for every kind of resource + // consumed by `fieldTypeLayout`. + // + for (auto fieldTypeResInfo : fieldTypeLayout->resourceInfos) { - if (!firstExplicit) - firstExplicit = field.getDecl(); - - // In the explicit case, the field already has offset - // information, and we just need to update the computed - // size of the `struct` type to account for the field. - // - auto fieldEndOffset = fieldResInfo->index + fieldTypeResInfo.count; - structTypeResInfo->count = - maximum(structTypeResInfo->count, fieldEndOffset); + // If the field is a Conditional<T, false> type, then it could have 0 size. + // We should skip this field if it has no use of layout units. + if (fieldTypeResInfo.count == 0) + continue; + + auto kind = fieldTypeResInfo.kind; + + auto structTypeResInfo = structLayout->findOrAddResourceInfo(kind); + + auto fieldResInfo = fieldVarLayout->FindResourceInfo(kind); + if (!fieldResInfo) + { + if (!firstImplicit) + firstImplicit = field.getDecl(); + + // In the implicit-layout case, we assign the field + // the next available offset after the fields that + // have preceded it. + // + fieldResInfo = fieldVarLayout->findOrAddResourceInfo(kind); + fieldResInfo->index = structTypeResInfo->count.getFiniteValue(); + structTypeResInfo->count += fieldTypeResInfo.count; + } + else + { + if (!firstExplicit) + firstExplicit = field.getDecl(); + + // In the explicit case, the field already has offset + // information, and we just need to update the computed + // size of the `struct` type to account for the field. + // + auto fieldEndOffset = fieldResInfo->index + fieldTypeResInfo.count; + structTypeResInfo->count = + maximum(structTypeResInfo->count, fieldEndOffset); + } } } + if (firstImplicit && firstExplicit) + { + getSink(context)->diagnose( + firstImplicit, + Diagnostics::mixingImplicitAndExplicitBindingForVaryingParams, + firstImplicit->getName(), + firstExplicit->getName()); + } + + return structLayout; } - if (firstImplicit && firstExplicit) + else if (auto globalGenericParamDecl = declRef.as<GlobalGenericParamDecl>()) { - getSink(context)->diagnose( - firstImplicit, - Diagnostics::mixingImplicitAndExplicitBindingForVaryingParams, - firstImplicit->getName(), - firstExplicit->getName()); - } - - return structLayout; - } - else if (auto globalGenericParamDecl = declRef.as<GlobalGenericParamDecl>()) - { - auto& layoutContext = context->layoutContext; + auto& layoutContext = context->layoutContext; - if (auto concreteType = findGlobalGenericSpecializationArg( - layoutContext, - globalGenericParamDecl.getDecl())) + if (auto concreteType = findGlobalGenericSpecializationArg( + layoutContext, + globalGenericParamDecl.getDecl())) + { + // If we know what concrete type has been used to specialize + // the global generic type parameter, then we should use + // the concrete type instead. + // + // Note: it should be illegal for the user to use a generic + // type parameter in a varying parameter list without giving + // it an explicit user-defined semantic. Otherwise, it would be possible + // that the concrete type that gets plugged in is a user-defined + // `struct` that uses some `SV_` semantics in its definition, + // so that any static information about what system values + // the entry point uses would be incorrect. + // + return processEntryPointVaryingParameter( + context, + concreteType, + state, + varLayout); + } + else + { + // If we don't know a concrete type, then we aren't generating final + // code, so the reflection information should show the generic + // type parameter. + // + // We don't make any attempt to assign varying parameter resources + // to the generic type, since we can't know how many "slots" + // of varying input/output it would consume. + // + return createTypeLayoutForGlobalGenericTypeParam( + layoutContext, + type, + globalGenericParamDecl.getDecl()); + } + } + else if (auto enumDeclRef = declRef.as<EnumDecl>()) { - // If we know what concrete type has been used to specialize - // the global generic type parameter, then we should use - // the concrete type instead. + // We handle an enumeration type as its tag type for varying parameters. + // This allows enums to be used in vertex output/input similar to their + // underlying integer types. // - // Note: it should be illegal for the user to use a generic - // type parameter in a varying parameter list without giving - // it an explicit user-defined semantic. Otherwise, it would be possible - // that the concrete type that gets plugged in is a user-defined - // `struct` that uses some `SV_` semantics in its definition, - // so that any static information about what system values - // the entry point uses would be incorrect. - // - return processEntryPointVaryingParameter(context, concreteType, state, varLayout); + auto tagType = enumDeclRef.getDecl()->tagType; + SLANG_ASSERT(tagType); + return processEntryPointVaryingParameter(context, tagType, state, varLayout); + } + else if (auto associatedTypeParam = declRef.as<AssocTypeDecl>()) + { + RefPtr<TypeLayout> assocTypeLayout = new TypeLayout(); + assocTypeLayout->type = type; + return assocTypeLayout; } else { - // If we don't know a concrete type, then we aren't generating final - // code, so the reflection information should show the generic - // type parameter. - // - // We don't make any attempt to assign varying parameter resources - // to the generic type, since we can't know how many "slots" - // of varying input/output it would consume. - // - return createTypeLayoutForGlobalGenericTypeParam( - layoutContext, - type, - globalGenericParamDecl.getDecl()); + SLANG_UNEXPECTED("unhandled type kind"); } } - else if (auto enumDeclRef = declRef.as<EnumDecl>()) - { - // We handle an enumeration type as its tag type for varying parameters. - // This allows enums to be used in vertex output/input similar to their - // underlying integer types. - // - auto tagType = enumDeclRef.getDecl()->tagType; - SLANG_ASSERT(tagType); - return processEntryPointVaryingParameter(context, tagType, state, varLayout); - } - else if (auto associatedTypeParam = declRef.as<AssocTypeDecl>()) - { - RefPtr<TypeLayout> assocTypeLayout = new TypeLayout(); - assocTypeLayout->type = type; - return assocTypeLayout; - } - else + + // If we ran into an error in checking the user's code, then skip this parameter + else if (const auto errorType = as<ErrorType>(type)) { - SLANG_UNEXPECTED("unhandled type kind"); + return nullptr; } - } - // If we ran into an error in checking the user's code, then skip this parameter - else if (const auto errorType = as<ErrorType>(type)) - { - return nullptr; - } - - SLANG_UNEXPECTED("unhandled type kind"); - UNREACHABLE_RETURN(nullptr); + SLANG_UNEXPECTED("unhandled type kind"); + UNREACHABLE_RETURN(nullptr); + }; + return processParamOfType(_Move(processParamOfType), type); } /// Compute the type layout for a parameter declared directly on an entry point. @@ -2606,8 +2633,8 @@ static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout( LayoutRulesImpl* layoutRules = nullptr; if (isKhronosTarget(context->getTargetRequest())) { - // For Vulkan, entry point uniform parameters are laid out using push constant buffer - // rules (defaults to std430). + // For Vulkan, entry point uniform parameters are laid out using push constant + // buffer rules (defaults to std430). layoutRules = context->getRulesFamily()->getShaderStorageBufferRules( context->getTargetProgram()->getOptionSet()); } @@ -3559,15 +3586,16 @@ static void collectParameters(ParameterBindingContext* inContext, ComponentType* /// Emit a diagnostic about a uniform/ordinary parameter at global scope. void diagnoseGlobalUniform(SharedParameterBindingContext* sharedContext, VarDeclBase* varDecl) { - // Don't emit the implicit global shader parameter warning if the variable is explicitly marked - // as uniform + // Don't emit the implicit global shader parameter warning if the variable is explicitly + // marked as uniform if (!varDecl->hasModifier<HLSLUniformModifier>()) { getSink(sharedContext) ->diagnose(varDecl, Diagnostics::globalUniformNotExpected, varDecl->getName()); } - // Always check and warn about binding attributes being ignored, regardless of uniform modifier + // Always check and warn about binding attributes being ignored, regardless of uniform + // modifier if (varDecl->findModifier<GLSLBindingAttribute>()) { sharedContext->m_sink->diagnose( @@ -3635,7 +3663,8 @@ struct ParameterBindingVisitorCounters Index globalParamCounter = 0; }; -/// Recursive routine to "complete" all binding for parameters and entry points in `componentType`. +/// Recursive routine to "complete" all binding for parameters and entry points in +/// `componentType`. /// /// This includes allocation of as-yet-unused register/binding ranges to parameters (which /// will then affect the ranges of registers/bindings that are available to subsequent @@ -3973,8 +4002,8 @@ static bool _calcNeedsDefaultSpace(SharedParameterBindingContext& sharedContext) continue; case LayoutResourceKind::Uniform: { - // If it's uniform, but we have globals binding defined, we don't need a default - // space for it as it will go in the global binding specified + // If it's uniform, but we have globals binding defined, we don't need a + // default space for it as it will go in the global binding specified if (auto hlslToVulkanOptions = sharedContext.getTargetProgram()->getHLSLToVulkanLayoutOptions()) { @@ -4070,8 +4099,8 @@ static void _maybeApplyHLSLToVulkanShifts( return; } - // If the user specified -fvk-b-shift for the default space but not -fvk-bind-global, we want to - // apply the shift to the global constant buffer. + // If the user specified -fvk-b-shift for the default space but not -fvk-bind-global, we + // want to apply the shift to the global constant buffer. if (!vulkanOptions->hasGlobalsBinding()) { auto globalCBufferShift = vulkanOptions->getShift( @@ -4117,10 +4146,10 @@ static void _maybeApplyHLSLToVulkanShifts( // In essence we need to look for HLSL kinds which have inferance. // We assume all map to Descriptor, and look for descriptor overlaps - // We know there can't be a clash of HLSL layout kinds previously, otherwise that - // would have already produced an a warning. We also know the only change is either - // *all* of a set is shifted or none. That means post a shift there still can't be - // clash between HLSL types. + // We know there can't be a clash of HLSL layout kinds previously, otherwise + // that would have already produced an a warning. We also know the only change + // is either *all* of a set is shifted or none. That means post a shift there + // still can't be clash between HLSL types. // So clashes can only be between HLSL types and other bindings (regardless) diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index dde783ce8..3519b6d43 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -5509,7 +5509,7 @@ Decl* Parser::ParseStruct() parseOptionalInheritanceClause(this, rs); if (AdvanceIf(this, TokenType::OpAssign)) { - rs->wrappedType = ParseTypeExp(); + rs->aliasedType = ParseTypeExp(); PushScope(rs); PopScope(); if (!LookAheadToken(TokenType::Semicolon)) diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 640fa8fcd..50a319ea6 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -196,6 +196,11 @@ inline Type* getTargetType(ASTBuilder* astBuilder, DeclRef<ExtensionDecl> declRe return declRef.substitute(astBuilder, declRef.getDecl()->targetType.Ptr()); } +inline Type* getAliasedType(ASTBuilder* astBuilder, DeclRef<AggTypeDecl> declRef) +{ + return declRef.substitute(astBuilder, declRef.getDecl()->aliasedType.Ptr()); +} + inline FilteredMemberRefList<VarDecl> getFields( ASTBuilder* astBuilder, DeclRef<StructDecl> declRef, diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 0de6348bf..5bbcd2eb1 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -5191,9 +5191,11 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type // If we are trying to get the layout of some extern type, do our best // to look it up in other loaded modules and generate the type layout // based on that. - declRefType = context.lookupExternDeclRefType(declRefType); - auto declRef = declRefType->getDeclRef(); + auto resolvedType = context.lookupExternDeclRefType(declRefType); + if (resolvedType != type) + return _createTypeLayout(context, resolvedType); + auto declRef = declRefType->getDeclRef(); if (auto structDeclRef = declRef.as<StructDecl>()) { @@ -5888,26 +5890,71 @@ GlobalGenericParamDecl* GenericParamTypeLayout::getGlobalGenericParamDecl() return rsDeclRef.getDecl(); } -DeclRefType* TypeLayoutContext::lookupExternDeclRefType(DeclRefType* declRefType) +// Get the decl ref to the outer generic if the decl referenced by `declRef` is generic. +DeclRef<GenericDecl> getOuterGeneric(DeclRef<Decl> declRef) +{ + if (auto directDeclRef = as<DirectDeclRef>(declRef.declRefBase)) + { + if (as<GenericDecl>(directDeclRef->getDecl())) + return DeclRef<GenericDecl>(directDeclRef); + if (as<GenericDecl>(directDeclRef->getParent()->getDecl())) + return DeclRef<GenericDecl>(directDeclRef->getParent()); + } + else if (auto genAppDeclRef = as<GenericAppDeclRef>(declRef.declRefBase)) + { + return DeclRef<GenericDecl>(genAppDeclRef->getBase()); + } + return DeclRef<GenericDecl>(); +} + +Type* TypeLayoutContext::lookupExternDeclRefType(DeclRefType* declRefType) { const auto declRef = declRefType->getDeclRef(); const auto decl = declRef.getDecl(); const auto isExtern = decl->hasModifier<ExternAttribute>() || decl->hasModifier<ExternModifier>(); + Type* resultType = declRefType; if (isExtern) { if (!externTypeMap) buildExternTypeMap(); const auto mangledName = getMangledName(targetReq->getLinkage()->getASTBuilder(), decl); - externTypeMap->tryGetValue(mangledName, declRefType); + externTypeMap->tryGetValue(mangledName, resultType); + if (auto resolvedDeclRef = isDeclRefTypeOf<Decl>(resultType)) + { + if (resolvedDeclRef != declRef) + { + // If declRef is a GenericApp, we should replace the generic base to + // resolveDeclRef's base. + if (auto originalGenericApp = as<GenericAppDeclRef>(declRef.declRefBase)) + { + if (auto resolvedOuterGeneric = getOuterGeneric(resolvedDeclRef.getDecl())) + { + auto substGenericApp = astBuilder->getGenericAppDeclRef( + resolvedOuterGeneric, + originalGenericApp->getArgs()); + resultType = DeclRefType::create(astBuilder, substGenericApp); + } + } + } + } + } + + // If the type is an alias of another type, then we should create the type layout + // from the aliased type instead. + if (auto aggTypeDeclRef = isDeclRefTypeOf<AggTypeDecl>(resultType)) + { + if (auto aliasedType = as<Type>(getAliasedType(astBuilder, aggTypeDeclRef))) + { + return aliasedType; + } } - return declRefType; + return resultType; } void TypeLayoutContext::buildExternTypeMap() { externTypeMap.emplace(); - const auto linkage = targetReq->getLinkage(); HashSet<String> externNames; Dictionary<String, DeclRefType*> allTypes; @@ -5916,6 +5963,8 @@ void TypeLayoutContext::buildExternTypeMap() // We'll match them up later auto processDecl = [&](auto&& go, Decl* decl) -> void { + if (auto genericDecl = as<GenericDecl>(decl)) + decl = genericDecl->inner; const auto isExtern = decl->hasModifier<ExternAttribute>() || decl->hasModifier<ExternModifier>(); @@ -5933,7 +5982,7 @@ void TypeLayoutContext::buildExternTypeMap() } } - if (auto scopeDecl = as<ScopeDecl>(decl)) + if (auto scopeDecl = isStaticScopeDecl(decl)) { for (auto member : scopeDecl->getDirectMemberDecls()) { @@ -5942,7 +5991,7 @@ void TypeLayoutContext::buildExternTypeMap() } }; - for (const auto& m : linkage->loadedModulesList) + for (const auto& m : programLayout->getProgram()->getModuleDependencies()) { const auto& ast = m->getModuleDecl(); for (auto member : ast->getDirectMemberDecls()) diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index 84840c043..d47981904 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -1183,11 +1183,11 @@ struct TypeLayoutContext // Options passed to object layout ObjectLayoutRulesImpl::Options objectLayoutOptions; - // Mangled names to DeclRefType, this is used to match up 'extern' types to + // Mangled names to Type, this is used to match up 'extern' types to // their linked in definitions during layout generation - std::optional<Dictionary<String, DeclRefType*>> externTypeMap; + std::optional<Dictionary<String, Type*>> externTypeMap; - DeclRefType* lookupExternDeclRefType(DeclRefType* declRefType); + Type* lookupExternDeclRefType(DeclRefType* declRefType); void buildExternTypeMap(); LayoutRulesImpl* getRules() { return rules; } diff --git a/tests/language-feature/modules/wrapper-inout.slang b/tests/language-feature/modules/wrapper-inout.slang index 5b7b9fbce..ec42c9177 100644 --- a/tests/language-feature/modules/wrapper-inout.slang +++ b/tests/language-feature/modules/wrapper-inout.slang @@ -1,10 +1,12 @@ //TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type public interface ITest { + __init(); public int testDir(inout int a); }; public struct TestImpl : ITest { + __init(){} public int testDir(inout int a) { int oldA = a; a = 5; @@ -21,7 +23,7 @@ public struct Test : ITest = TestImpl; [numthreads(1,1,1)] void computeMain() { - Test data; + Test data = {}; int a = 516; int b = data.testDir(a); // CHECK: 5 diff --git a/tools/gfx-unit-test/link-time-type-generic.cpp b/tools/gfx-unit-test/link-time-type-generic.cpp new file mode 100644 index 000000000..ac1750518 --- /dev/null +++ b/tools/gfx-unit-test/link-time-type-generic.cpp @@ -0,0 +1,229 @@ +#include "core/slang-basic.h" +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "slang-rhi.h" +#include "slang-rhi/shader-cursor.h" +#include "unit-test/slang-unit-test.h" + +using namespace rhi; + +// Test that generic link time types conforming to a generic interface with generic +// methods/subscript members work correctly. +// Also test that global generic link-time functions works correctly. + +namespace gfx_test +{ +static Slang::Result loadProgram( + rhi::IDevice* device, + Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram, + slang::ProgramLayout*& slangReflection, + bool linkSpecialization = false) +{ + const char* moduleInterfaceSrc = R"( + interface ISimple { float getVal(); } + interface IHasProperty { property float val2{get;set;} } + interface IFoo<T:__BuiltinFloatingPointType> : IHasProperty + { + static const int offset; + [mutating] void setValue(float v); + + T getValue<U:ISimple>(U u); + + __subscript<U:__BuiltinIntegerType>(U index) -> T { get; } + } + struct FooImpl<T:__BuiltinFloatingPointType, int x> : IFoo<T> + { + T val; + static const int offset = x; + [mutating] void setValue(float v) { val = T(v); } + T getValue<U:ISimple>(U u){ return val + T(u.getVal()); } + property float val2 { + get { return __real_cast<float>(val) + 2.0; } + set { val = T(newValue); } + } + __subscript<U:__BuiltinIntegerType>(U index) -> T { get {return T(1.0); } } + }; + struct BarImpl<T:__BuiltinFloatingPointType, int x> : IFoo<T> + { + T val; + static const int offset = -x; + [mutating] void setValue(float v) { val = T(v); } + T getValue<U:ISimple>(U u){ return val - T(1.0); } + property float val2 { + get { return __real_cast<float>(val) + 2.0; } + set { val = T(newValue); } + } + __subscript<U:__BuiltinIntegerType>(U index) -> T { get {return T(2.0); } } + }; + )"; + const char* module0Src = R"( + import ifoo; + extern struct Foo<T:__BuiltinFloatingPointType, int i> : IFoo<T> = FooImpl<T, i+1>; + extern static const float c = 0.0; + extern int linkTimeFunc<int x>() { return x; } + struct SimpleImpl : ISimple + { + float getVal() { return 100.0; } + }; + + // Use an indirect generic function to retrieve val2, to make sure intermediate witness tables + // can be obtained correctly from link-time witnesses. + float getVal2<T:IHasProperty>(T t) { return t.val2; } + + [numthreads(1,1,1)] + void computeMain(uniform RWStructuredBuffer<float> buffer) + { + Foo<float, 0> foo; + foo.setValue(3.0); + buffer[0] = foo.getValue(SimpleImpl()) + getVal2(foo) + Foo<float, 0>.offset + c + foo[0] + linkTimeFunc<0>(); + } + )"; + const char* module1Src = R"( + import ifoo; + export struct Foo<T1:__BuiltinFloatingPointType, int i> : IFoo<T1> = BarImpl<T1, i+1>; + export static const float c = 1.0; + export int linkTimeFunc<int x>() { return x + 1; } + )"; + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + auto moduleInterfaceBlob = + Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc)); + auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src)); + auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src)); + slang::IModule* moduleInterface = + slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob); + slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob); + slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob); + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef())); + + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(moduleInterface); + componentTypes.add(module0); + if (linkSpecialization) + componentTypes.add(module1); + componentTypes.add(computeEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + ComPtr<slang::IComponentType> linkedProgram; + result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + ShaderProgramDesc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + + auto shaderProgram = device->createShaderProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +void linkTimeTypeGenericTestImpl(IDevice* device, UnitTestContext* context) +{ + // Create pipeline without linking a specialization override module, so we should + // see the default value of `extern Foo`. + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, false)); + + ComputePipelineDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<IComputePipeline> pipelineState; + GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef())); + + // Create pipeline with a specialization override module linked in, so we should + // see the result of using `BarImpl<T>` for `extern Foo<T>`. + ComPtr<IShaderProgram> shaderProgram1; + GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram1, slangReflection, true)); + + ComputePipelineDesc pipelineDesc1 = {}; + pipelineDesc1.program = shaderProgram1.get(); + ComPtr<IComputePipeline> pipelineState1; + GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc1, pipelineState1.writeRef())); + + const int numberCount = 4; + float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f}; + BufferDesc bufferDesc = {}; + bufferDesc.size = numberCount * sizeof(float); + bufferDesc.format = rhi::Format::Undefined; + bufferDesc.elementSize = sizeof(float); + bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess | + BufferUsage::CopyDestination | BufferUsage::CopySource; + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBuffer> numbersBuffer; + GFX_CHECK_CALL_ABORT( + device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef())); + + auto queue = device->getQueue(QueueType::Graphics); + + // We have done all the set up work, now it is time to start recording a command buffer for + // GPU execution. + { + auto commandEncoder = queue->createCommandEncoder(); + auto computePassEncoder = commandEncoder->beginComputePass(); + + auto rootObject = computePassEncoder->bindPipeline(pipelineState); + + ShaderCursor entryPointCursor( + rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer to the entry point. + entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer)); + + computePassEncoder->dispatchCompute(1, 1, 1); + computePassEncoder->end(); + auto commandBuffer = commandEncoder->finish(); + queue->submit(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult(device, numbersBuffer, std::array{110.0f}); + + // Now run again with the overrided program. + { + auto commandEncoder = queue->createCommandEncoder(); + auto computePassEncoder = commandEncoder->beginComputePass(); + + auto rootObject = computePassEncoder->bindPipeline(pipelineState1); + + ShaderCursor entryPointCursor( + rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer to the entry point. + entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer)); + + computePassEncoder->dispatchCompute(1, 1, 1); + computePassEncoder->end(); + auto commandBuffer = commandEncoder->finish(); + queue->submit(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult(device, numbersBuffer, std::array{10.0f}); +} + +SLANG_UNIT_TEST(linkTimeTypeGenericD3D12) +{ + runTestImpl(linkTimeTypeGenericTestImpl, unitTestContext, DeviceType::D3D12); +} + +SLANG_UNIT_TEST(linkTimeTypeGenerictVulkan) +{ + runTestImpl(linkTimeTypeGenericTestImpl, unitTestContext, DeviceType::Vulkan); +} + +} // namespace gfx_test diff --git a/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp b/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp new file mode 100644 index 000000000..c640389e5 --- /dev/null +++ b/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp @@ -0,0 +1,156 @@ +#include "core/slang-basic.h" +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "slang-rhi.h" +#include "slang-rhi/shader-cursor.h" +#include "unit-test/slang-unit-test.h" + +using namespace rhi; + +// Test that a generic type can be used to serve multiple link-time type requirements. + +namespace gfx_test +{ +static Slang::Result loadProgram( + rhi::IDevice* device, + Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram, + slang::ProgramLayout*& slangReflection, + bool linkSpecialization = false) +{ + const char* moduleInterfaceSrc = R"( + interface IFoo { int getFoo(); } + interface IBar { int getBar(); } + struct SimpleImpl<int y> : IFoo, IBar + { + int getFoo() { return y; } + int getBar() { return y * 2; } + } + )"; + const char* module0Src = R"( + import ifoo; + extern struct Foo<int x> : IFoo; + extern struct Bar<int x> : IBar; + uniform Foo<10> gFoo; + uniform Bar<20> gBar; + [numthreads(1,1,1)] + void computeMain(uniform RWStructuredBuffer<int> buffer) + { + buffer[0] = gFoo.getFoo() + gBar.getBar(); + } + )"; + const char* module1Src = R"( + import ifoo; + export struct Foo<int x> : IFoo = SimpleImpl<x>; + export struct Bar<int x> : IBar = SimpleImpl<x>; + )"; + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + auto moduleInterfaceBlob = + Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc)); + auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src)); + auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src)); + slang::IModule* moduleInterface = + slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob); + slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob); + slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob); + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef())); + + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(moduleInterface); + componentTypes.add(module0); + if (linkSpecialization) + componentTypes.add(module1); + componentTypes.add(computeEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + ComPtr<slang::IComponentType> linkedProgram; + result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + ShaderProgramDesc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + + auto shaderProgram = device->createShaderProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +void linkTimeTypeMultiUseGenericTestImpl(IDevice* device, UnitTestContext* context) +{ + // Create pipeline without both modules linked, specifying both Foo and Bar to be SimpleImpl. + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, true)); + + ComputePipelineDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<IComputePipeline> pipelineState; + GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef())); + + const int numberCount = 4; + float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f}; + BufferDesc bufferDesc = {}; + bufferDesc.size = numberCount * sizeof(float); + bufferDesc.format = rhi::Format::Undefined; + bufferDesc.elementSize = sizeof(float); + bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess | + BufferUsage::CopyDestination | BufferUsage::CopySource; + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBuffer> numbersBuffer; + GFX_CHECK_CALL_ABORT( + device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef())); + + auto queue = device->getQueue(QueueType::Graphics); + + // We have done all the set up work, now it is time to start recording a command buffer for + // GPU execution. + { + auto commandEncoder = queue->createCommandEncoder(); + auto computePassEncoder = commandEncoder->beginComputePass(); + + auto rootObject = computePassEncoder->bindPipeline(pipelineState); + + ShaderCursor entryPointCursor( + rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer to the entry point. + entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer)); + + computePassEncoder->dispatchCompute(1, 1, 1); + computePassEncoder->end(); + auto commandBuffer = commandEncoder->finish(); + queue->submit(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult(device, numbersBuffer, std::array{50}); +} + +SLANG_UNIT_TEST(linkTimeTypeMultiUseGenericD3D12) +{ + runTestImpl(linkTimeTypeMultiUseGenericTestImpl, unitTestContext, DeviceType::D3D12); +} + +SLANG_UNIT_TEST(linkTimeTypeMultiUseGenericVulkan) +{ + runTestImpl(linkTimeTypeMultiUseGenericTestImpl, unitTestContext, DeviceType::Vulkan); +} + +} // namespace gfx_test diff --git a/tools/gfx-unit-test/link-time-type-multi-use.cpp b/tools/gfx-unit-test/link-time-type-multi-use.cpp new file mode 100644 index 000000000..4dc6d085e --- /dev/null +++ b/tools/gfx-unit-test/link-time-type-multi-use.cpp @@ -0,0 +1,156 @@ +#include "core/slang-basic.h" +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "slang-rhi.h" +#include "slang-rhi/shader-cursor.h" +#include "unit-test/slang-unit-test.h" + +using namespace rhi; + +// Test that a type can be used to serve multiple link-time type requirements. + +namespace gfx_test +{ +static Slang::Result loadProgram( + rhi::IDevice* device, + Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram, + slang::ProgramLayout*& slangReflection, + bool linkSpecialization = false) +{ + const char* moduleInterfaceSrc = R"( + interface IFoo { int getFoo(); } + interface IBar { int getBar(); } + struct SimpleImpl : IFoo, IBar + { + int getFoo() { return 10; } + int getBar() { return 20; } + } + )"; + const char* module0Src = R"( + import ifoo; + extern struct Foo : IFoo; + extern struct Bar : IBar; + uniform Foo gFoo; + uniform Bar gBar; + [numthreads(1,1,1)] + void computeMain(uniform RWStructuredBuffer<int> buffer) + { + buffer[0] = gFoo.getFoo() + gBar.getBar(); + } + )"; + const char* module1Src = R"( + import ifoo; + export struct Foo : IFoo = SimpleImpl; + export struct Bar : IBar = SimpleImpl; + )"; + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + auto moduleInterfaceBlob = + Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc)); + auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src)); + auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src)); + slang::IModule* moduleInterface = + slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob); + slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob); + slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob); + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef())); + + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(moduleInterface); + componentTypes.add(module0); + if (linkSpecialization) + componentTypes.add(module1); + componentTypes.add(computeEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + ComPtr<slang::IComponentType> linkedProgram; + result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + ShaderProgramDesc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + + auto shaderProgram = device->createShaderProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +void linkTimeTypeMultiUseTestImpl(IDevice* device, UnitTestContext* context) +{ + // Create pipeline without both modules linked, specifying both Foo and Bar to be SimpleImpl. + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, true)); + + ComputePipelineDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<IComputePipeline> pipelineState; + GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef())); + + const int numberCount = 4; + float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f}; + BufferDesc bufferDesc = {}; + bufferDesc.size = numberCount * sizeof(float); + bufferDesc.format = rhi::Format::Undefined; + bufferDesc.elementSize = sizeof(float); + bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess | + BufferUsage::CopyDestination | BufferUsage::CopySource; + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBuffer> numbersBuffer; + GFX_CHECK_CALL_ABORT( + device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef())); + + auto queue = device->getQueue(QueueType::Graphics); + + // We have done all the set up work, now it is time to start recording a command buffer for + // GPU execution. + { + auto commandEncoder = queue->createCommandEncoder(); + auto computePassEncoder = commandEncoder->beginComputePass(); + + auto rootObject = computePassEncoder->bindPipeline(pipelineState); + + ShaderCursor entryPointCursor( + rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer to the entry point. + entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer)); + + computePassEncoder->dispatchCompute(1, 1, 1); + computePassEncoder->end(); + auto commandBuffer = commandEncoder->finish(); + queue->submit(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult(device, numbersBuffer, std::array{30}); +} + +SLANG_UNIT_TEST(linkTimeTypeMultiUseD3D12) +{ + runTestImpl(linkTimeTypeMultiUseTestImpl, unitTestContext, DeviceType::D3D12); +} + +SLANG_UNIT_TEST(linkTimeTypeMultiUseVulkan) +{ + runTestImpl(linkTimeTypeMultiUseTestImpl, unitTestContext, DeviceType::Vulkan); +} + +} // namespace gfx_test diff --git a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp index c42fd2f16..0bd580c84 100644 --- a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp +++ b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp @@ -27,16 +27,21 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) interface IMaterial { float4 load(); } extern struct Material : IMaterial; ConstantBuffer<Material> gMaterial; - + + interface IFoo { float getVal(); } + struct DefaultFoo : IFoo { float getVal() { return 0.0f; } } + extern struct Foo<T, int x> : IFoo = DefaultFoo; + RWTexture2D tex; extern static const int count; uniform uint4 buffers[count]; + uniform Foo<int4, 1> gFoo; [numthreads(1,1,1)] [shader("compute")] void computeMain() { - tex[uint2(0, 0)] = gMaterial.load(); + tex[uint2(0, 0)] = gMaterial.load() + gFoo.getVal(); } )"; @@ -65,7 +70,8 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) String configModuleSource = "import " + moduleName + ";\n" + R"( export struct Material : IMaterial = MyMaterial; export static const int count = 11; - + struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } } + export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>; struct MyMaterial : IMaterial { int data; Texture2D diffuse; @@ -110,6 +116,9 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) auto var2 = programLayout->getParameterByIndex(2); SLANG_CHECK(var2->getTypeLayout()->getSize() == 11 * 16); + auto var3 = programLayout->getParameterByIndex(3); + SLANG_CHECK(var3->getTypeLayout()->getSize() == 32); + ComPtr<slang::IBlob> codeBlob; linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef()); @@ -226,3 +235,78 @@ SLANG_UNIT_TEST(linkTimeConditionalReflection) SLANG_CHECK(spirvStr.indexOf(toSlice("Location 1")) != -1); SLANG_CHECK(spirvStr.indexOf(toSlice("Location 2")) == -1); } + +// Test that loading a module that defines an `export` type, but not linking with the module should +// not affect the type layout. + +SLANG_UNIT_TEST(linkTimeTypeReflectionWithLoadedButNotLinkedModule) +{ + // Source for a module that contains can be specialized with a link-time type. + const char* userSourceBody = R"( + interface IFoo { float getVal(); } + struct DefaultFoo : IFoo { float getVal() { return 0.0f; } } + extern struct Foo<T, int x> : IFoo = DefaultFoo; + + uniform Foo<int4, 1> gFoo; + RWTexture2D tex; + + [numthreads(1,1,1)] + [shader("compute")] + void computeMain() { + tex[uint2(0, 0)] = gFoo.getVal(); + } + )"; + + String moduleName = "linkTimeTypeReflection_Compute"; + + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_SPIRV_ASM; + targetDesc.profile = globalSession->findProfile("spirv_1_5"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + moduleName.getBuffer(), + (moduleName + ".slang").getBuffer(), + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(module != nullptr); + + // Source for a module that defines the link-time type, but we won't link with it. + String configModuleSource = "import " + moduleName + ";\n" + R"( + struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } } + export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>; + )"; + auto configModule = session->loadModuleFromSourceString( + "config", + "config.slang", + configModuleSource.getBuffer(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(configModule != nullptr); + + ComPtr<slang::IComponentType> linkedProgram; + module->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(linkedProgram != nullptr); + + auto programLayout = linkedProgram->getLayout(); + auto var0 = programLayout->getParameterByIndex(0); + + // Size of `gFoo` is 0, because the module that defines `Foo = FooImpl` is not linked. + // Therefore `gFoo`'s type is defaulted to `DefaultFoo`, which has no fields. + SLANG_CHECK(var0->getTypeLayout()->getSize() == 0); + + ComPtr<slang::IBlob> codeBlob; + linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK_ABORT(codeBlob.get()); + + auto spirvStr = UnownedStringSlice((const char*)codeBlob->getBufferPointer()); + + SLANG_CHECK(spirvStr.indexOf(toSlice("OpDecorate %tex Binding 0")) != -1); +} |
