diff options
25 files changed, 668 insertions, 436 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index b6185344b..cce6a545a 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -8354,12 +8354,8 @@ namespace Slang } } - RefPtr<Expr> visitGenericAppExpr(GenericAppExpr * genericAppExpr) + RefPtr<Expr> visitGenericAppExpr(GenericAppExpr* genericAppExpr) { - // We are applying a generic to arguments, but there might be multiple generic - // declarations with the same name, so this becomes a specialized case of - // overload resolution. - // Start by checking the base expression and arguments. auto& baseExpr = genericAppExpr->FunctionExpr; baseExpr = CheckTerm(baseExpr); @@ -8369,6 +8365,19 @@ namespace Slang arg = CheckTerm(arg); } + return checkGenericAppWithCheckedArgs(genericAppExpr); + } + + /// Check a generic application where the operands have already been checked. + RefPtr<Expr> checkGenericAppWithCheckedArgs(GenericAppExpr* genericAppExpr) + { + // We are applying a generic to arguments, but there might be multiple generic + // declarations with the same name, so this becomes a specialized case of + // overload resolution. + + auto& baseExpr = genericAppExpr->FunctionExpr; + auto& args = genericAppExpr->Arguments; + // If there was an error in the base expression, or in any of // the arguments, then just bail. if (IsErrorExpr(baseExpr)) @@ -9256,7 +9265,7 @@ namespace Slang // if( entryPoint->getStage() == Stage::Unknown ) { - sink->diagnose(entryPoint->decl, Diagnostics::entryPointHasNoStage, entryPoint->name); + sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::entryPointHasNoStage, entryPoint->name); } if (entryPoint->getStage() == Stage::Hull) @@ -9264,7 +9273,7 @@ namespace Slang auto translationUnit = entryPoint->getTranslationUnit(); auto translationUnitSyntax = translationUnit->SyntaxNode; - auto attr = entryPoint->decl->FindModifier<PatchConstantFuncAttribute>(); + auto attr = entryPoint->getFuncDecl()->FindModifier<PatchConstantFuncAttribute>(); if (attr) { @@ -9306,6 +9315,11 @@ namespace Slang { // The first step in validating the entry point is to find // the (unique) function declaration that matches its name. + // + // TODO: We will eventually need to update this logic + // to work by parsing the provided `entryPoint->name` string + // as an expression, so that we can handle more complex + // names like `foo<int>` or `SomeType.vs`. auto translationUnit = entryPoint->getTranslationUnit(); auto sink = &entryPoint->compileRequest->mSink; @@ -9333,11 +9347,20 @@ namespace Slang // We'll walk the linked list of declarations with the same name, // to see what we find. Along the way we'll keep track of the // first function declaration we find, if any: + // FuncDecl* entryPointFuncDecl = nullptr; for(auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName) { + // We want to support the case where the declaration is + // a generic function, so we will automatically + // unwrap any outer `GenericDecl` we find here. + // + auto decl = ee; + if(auto genericDecl = as<GenericDecl>(decl)) + decl = genericDecl->inner; + // Is this declaration a function? - if (auto funcDecl = as<FuncDecl>(ee)) + if (auto funcDecl = as<FuncDecl>(decl)) { // Skip non-primary declarations, so that // we don't give an error when an entry @@ -9409,34 +9432,65 @@ namespace Slang // Phew, we have at least found a suitable decl. // Let's record that in the entry-point request so // that we don't have to re-do this effort again later. - entryPoint->decl = entryPointFuncDecl; + // + // Note: we may replace the decl-ref we store at this point + // later in this function, when we (potentially) specialize + // a generic entry point to generic arguments provided + // via the API. + // + entryPoint->funcDeclRef = makeDeclRef(entryPointFuncDecl); - // Lookup generic parameter types in global scope + // If the user specified generic arguments for the entry point, + // then we will want to parse those arguments as expressions + // in a scope that includes the tanslation unit that holds + // the entry point, as well as any other modules that got + // transitively loaded via `import`. + // + // TODO: This would be better handled by giving the user + // more explicit ways to parse/build types at the API level, + // rather than keeping things string-based this far along. + // + // TODO: Building a list of `scopesToTry` here shouldn't + // be required, since the `Scope` type itself has the ability + // for form chains for lookup purposes (e.g., the way that + // `import` is handled by modifying a scope). + // List<RefPtr<Scope>> scopesToTry; scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope); for (auto & module : entryPoint->compileRequest->loadedModulesList) scopesToTry.Add(module->moduleDecl->scope); - List<RefPtr<Type>> globalGenericArgs; + // We are going to do some semantic checking, so we need to + // set up a `SemanticsVistitor` that we can use. + // + SemanticsVisitor semantics( + &entryPoint->compileRequest->mSink, + entryPoint->compileRequest, + entryPoint->getTranslationUnit()); + + // We will be looping over the generic argument strings + // that the user provided via the API (or command line), + // and parsing+checking each into an `Expr`. + // + // This loop will *not* handle coercing the arguments + // to be types. + // + List<RefPtr<Expr>> genericArgs; for (auto name : entryPoint->genericArgStrings) { - // parse type name - RefPtr<Type> type; + RefPtr<Expr> argExpr; for (auto & s : scopesToTry) { - RefPtr<Expr> typeExpr = entryPoint->compileRequest->parseTypeString(entryPoint->getTranslationUnit(), - name, s); - type = checkProperType(translationUnit, TypeExp(typeExpr)); - if (type) + argExpr = entryPoint->compileRequest->parseTypeString( + entryPoint->getTranslationUnit(), + name, + s); + argExpr = semantics.CheckTerm(argExpr); + if( argExpr ) { break; } } - if (!type) - { - sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name); - return; - } // The following is a bit of a hack. // @@ -9452,165 +9506,257 @@ namespace Slang // as a (top-level) argument for a generic type parameter, so that we // can check for them here and cache them on the entry point request. // - if( auto taggedUnionType = as<TaggedUnionType>(type) ) + if( auto typeType = as<TypeType>(argExpr->type) ) { - entryPoint->taggedUnionTypes.Add(taggedUnionType); + auto type = typeType->type; + if( auto taggedUnionType = as<TaggedUnionType>(type) ) + { + entryPoint->taggedUnionTypes.Add(taggedUnionType); + } } - globalGenericArgs.Add(type); + genericArgs.Add(argExpr); } - // validate global type arguments only when we are generating code - if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + // There are two cases we care about here, and we are going to treat them + // as mutually exclusive for simplicity. + // + // The first case is when the entry point function is itself generic, + // in which case we will assume that `genericArgs` lines up one-to-one + // with the explicit generic parameters of the entry point. + // + if( auto genericDecl = as<GenericDecl>(entryPointFuncDecl->ParentDecl) ) { - // check that user-provided type arguments conforms to the generic type - // parameter declaration of this translation unit + // We will construct a suitable `GenericAppExpr` to represent + // the user-specified `genericDecl` being applied to the + // supplied `genericArgs`, and then use the existing + // semantic checking logic that would apply to an explicit + // generic application like `F<A,B,C>` if it were + // encountered in the source code. - // collect global generic parameters from all imported modules - List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; - // add current translation unit first - { - auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - // add imported modules - for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) - { - auto moduleDecl = loadedModule->moduleDecl; - auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } + auto session = entryPoint->compileRequest->mSession; + auto genericDeclRef = makeDeclRef(genericDecl); - if (globalGenericParams.Count() != globalGenericArgs.Count()) - { - sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, - globalGenericParams.Count(), - globalGenericArgs.Count()); - return; - } - - // We have an appropriate number of arguments for the global generic parameters, - // and now we need to check that the arguments conform to the declared constraints. + // The first pieces is a `VarExpr` that refers to `genericDecl`. // - // Along the way, we will build up an appropriate set of substitutions to represent - // the generic arguments and their conformances. + // TODO: This would not be needed if we instead parsed + // the supplied entry-point name into an expression + // earlier in this function. // - RefPtr<Substitutions> globalGenericSubsts; - auto globalGenericSubstLink = &globalGenericSubsts; - // - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: + RefPtr<VarExpr> genericExpr = new VarExpr(); + genericExpr->declRef = genericDeclRef; + genericExpr->type.type = getTypeForDeclRef(session, genericDeclRef); + + // Next we construct the actual `GenericAppExpr` // - // type_param A; - // type_param B : ISidekick<A>; + RefPtr<GenericAppExpr> genericAppExpr = new GenericAppExpr(); + genericAppExpr->FunctionExpr = genericExpr; + genericAppExpr->Arguments = genericArgs; + + // We use the semantics visitor to perform the + // actual checking logic (this might report + // errors) // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor<B>`). + auto checkedExpr = semantics.checkGenericAppWithCheckedArgs(genericAppExpr); + + // Now we need to extract an appropriate decl-ref for the entry + // point from the `checkedExpr`. // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. + if( auto declRefExpr = checkedExpr.as<DeclRefExpr>() ) + { + // TODO: We should eventually check for the case + // where we have a `MemberExpr` or another case of + // `DeclRefExpr` that cannot be summarized as just + // its decl-ref. + // + // The basic `VarExpr` and `StaticMemberExpr` cases + // should be allow-able. + + entryPoint->funcDeclRef = declRefExpr->declRef.as<FuncDecl>(); + } + else if( semantics.IsErrorExpr(checkedExpr) ) + { + // Any semantic error that occured should have been + // reported already. + } + else + { + // The result of specializing a reference to a generic + // function should always be a `DeclRefExpr` + // + SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); + } + } + else + { + // The other case is when the entry point function is *not* itself + // generic, so we assume that any generic arguments must have been intended + // to match up with global generic parameters instead. // - // We will punt on this for now, and just check each constraint in isolation. + // We will only validate global generic type arguments when we are going + // to generate code, since in a no-codegen pass we will typically *not* + // have arguments to associate with the parameters. // - UInt argCounter = 0; - for(auto& globalGenericParam : globalGenericParams) + if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) { - // Get the argument that matches this parameter. - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < globalGenericArgs.Count()); - auto globalGenericArg = globalGenericArgs[argIndex]; + // check that user-provioded type arguments conforms to the generic type + // parameter declaration of this translation unit - // As a quick sanity check, see if the argument that is being supplied for a parameter - // is just the parameter itself, because this should always be an error: - // - if( auto argDeclRefType = as<DeclRefType>(globalGenericArg) ) + // collect global generic parameters from all imported modules + List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; + // add current translation unit first + { + auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); + for (auto p : globalGenParams) + globalGenericParams.Add(p); + } + // add imported modules + for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) + { + auto moduleDecl = loadedModule->moduleDecl; + auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); + for (auto p : globalGenParams) + globalGenericParams.Add(p); + } + + if (globalGenericParams.Count() != genericArgs.Count()) { - auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) + sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::mismatchEntryPointTypeArgument, + globalGenericParams.Count(), + genericArgs.Count()); + return; + } + + // We have an appropriate number of arguments for the global generic parameters, + // and now we need to check that the arguments conform to the declared constraints. + // + // Along the way, we will build up an appropriate set of substitutions to represent + // the generic arguments and their conformances. + // + RefPtr<Substitutions> globalGenericSubsts; + auto globalGenericSubstLink = &globalGenericSubsts; + // + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: + // + // type_param A; + // type_param B : ISidekick<A>; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor<B>`). + // + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + // + UInt argCounter = 0; + for(auto& globalGenericParam : globalGenericParams) + { + // Get the argument that matches this parameter. + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < genericArgs.Count()); + auto globalGenericArg = checkProperType(translationUnit, TypeExp(genericArgs[argIndex])); + if (!globalGenericArg) { - if(argGenericParamDeclRef.getDecl() == globalGenericParam) - { - // We are trying to specialize a generic parameter using itself. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToItself, - globalGenericParam->getName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; - } - else + sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, entryPoint->genericArgStrings[argIndex]); + return; + } + + // As a quick sanity check, see if the argument that is being supplied for a parameter + // is just the parameter itself, because this should always be an error: + // + if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) + { + auto argDeclRef = argDeclRefType->declRef; + if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - globalGenericParam->getName(), - argGenericParamDeclRef.GetName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; + if(argGenericParamDeclRef.getDecl() == globalGenericParam) + { + // We are trying to specialize a generic parameter using itself. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToItself, + globalGenericParam->getName()); + sink->diagnose(entryPointFuncDecl, + Diagnostics::noteWhenCompilingEntryPoint, + entryPointFuncDecl->getName()); + continue; + } + else + { + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + globalGenericParam->getName(), + argGenericParamDeclRef.GetName()); + sink->diagnose(entryPointFuncDecl, + Diagnostics::noteWhenCompilingEntryPoint, + entryPointFuncDecl->getName()); + continue; + } } } - } + // Create a substitution for this parameter/argument. + RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); + subst->paramDecl = globalGenericParam; + subst->actualType = globalGenericArg; - // Create a substitution for this parameter/argument. - RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); - subst->paramDecl = globalGenericParam; - subst->actualType = globalGenericArg; + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) + { + // Get the type that the constraint is enforcing conformance to + auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. - for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + // Use our semantic-checking logic to search for a witness to the required conformance + SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); + auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); + if (!witness) + { + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(globalGenericParam, + Diagnostics::typeArgumentDoesNotConformToInterface, + globalGenericParam->nameAndLoc.name, + globalGenericArg, + interfaceType); + } - // Use our semantic-checking logic to search for a witness to the required conformance - SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); - auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(globalGenericParam, - Diagnostics::typeArgumentDoesNotConformToInterface, - globalGenericParam->nameAndLoc.name, - globalGenericArg, - interfaceType); + // Attach the concrete witness for this conformance to the + // substutiton + GlobalGenericParamSubstitution::ConstraintArg constraintArg; + constraintArg.decl = constraint; + constraintArg.val = witness; + subst->constraintArgs.Add(constraintArg); } - // Attach the concrete witness for this conformance to the - // substutiton - GlobalGenericParamSubstitution::ConstraintArg constraintArg; - constraintArg.decl = constraint; - constraintArg.val = witness; - subst->constraintArgs.Add(constraintArg); - } + // Add the substitution for this parameter to the global substitution + // set that we are building. - // Add the substitution for this parameter to the global substitution - // set that we are building. + *globalGenericSubstLink = subst; + globalGenericSubstLink = &subst->outer; + } - *globalGenericSubstLink = subst; - globalGenericSubstLink = &subst->outer; + entryPoint->globalGenericSubst = globalGenericSubsts; } - - entryPoint->globalGenericSubst = globalGenericSubsts; } + + // If any errors occured while we were checking the generic arguments + // of the entry point, then we should bail out rather than try to + // perform the next step of validation. + // if (sink->errorCount != 0) return; // Now that we've *found* the entry point, it is time to validate // that it actually meets the constraints for the chosen stage/profile. // - // TODO: This validation should be performed "under" any global generic + // TODO: This validation should (probably?) be performed "under" any global generic // parameter substitution we might have created, so that we can validate // based on knowledge of actual types. // @@ -9707,7 +9853,7 @@ namespace Slang RefPtr<EntryPointRequest> entryPointReq = new EntryPointRequest(); entryPointReq->compileRequest = compileRequest; entryPointReq->translationUnitIndex = int(tt); - entryPointReq->decl = funcDecl; + entryPointReq->funcDeclRef = makeDeclRef(funcDecl); entryPointReq->name = funcDecl->getName(); entryPointReq->profile = profile; diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index 29f7a95d9..c93f3f70c 100644 --- a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -127,6 +127,17 @@ namespace Slang return compileRequest->translationUnits[translationUnitIndex].Ptr(); } + DeclRef<FuncDecl> EntryPointRequest::getFuncDeclRef() + { + return funcDeclRef; + } + + RefPtr<FuncDecl> EntryPointRequest::getFuncDecl() + { + return getFuncDeclRef().getDecl(); + } + + // Profile Profile::LookUp(char const* name) diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 9e8f146b9..39199a62f 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -152,7 +152,11 @@ namespace Slang // This will be filled in as part of semantic analysis; // it should not be assumed to be available in cases // where any errors were diagnosed. - RefPtr<FuncDecl> decl; + // + DeclRef<FuncDecl> funcDeclRef; + + DeclRef<FuncDecl> getFuncDeclRef(); + RefPtr<FuncDecl> getFuncDecl(); RefPtr<Substitutions> globalGenericSubst; diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index c0d31365e..0480eb934 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -175,7 +175,7 @@ SIMPLE_SYNTAX_CLASS(TypeAliasDecl, TypeDefDecl) SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl) END_SYNTAX_CLASS() -// A '__generic_param' declaration, which defines a generic +// A 'type_param' declaration, which defines a generic // entry-point parameter. Is a container of GenericTypeConstraintDecl SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl) END_SYNTAX_CLASS() diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 68790d4ab..9a8446a03 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -274,7 +274,7 @@ DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' // 303xx: interfaces and associated types DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.") -DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'__generic_param' can only be defined global scope.") +DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.") // TODO: need to assign numbers to all these extra diagnostics... DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") DIAGNOSTIC(39999, Fatal, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.") diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index d48390d31..d551ba1b9 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -6623,26 +6623,25 @@ String emitEntryPoint( EmitVisitor visitor(&context); - // We are going to create a fresh IR module that we will use to - // clone any code needed by the user's entry point. - IRSpecializationState* irSpecializationState = createIRSpecializationState( - entryPoint, - programLayout, - target, - targetRequest); - { - IRModule* irModule = getIRModule(irSpecializationState); + { auto compileRequest = translationUnit->compileRequest; auto session = compileRequest->mSession; - TypeLegalizationContext typeLegalizationContext; - initialize(&typeLegalizationContext, - session, - irModule); - - auto irEntryPoint = specializeIRForEntryPoint( - irSpecializationState, - entryPoint); + // We start out by performing "linking" at the level of the IR. + // This step will create a fresh IR module to be used for + // code generation, and will copy in any IR definitions that + // the desired entry point requires. Along the way it will + // resolve references to imported/exported symbols across + // modules, and also select between the definitions of + // any "profile-overloaded" symbols. + // + auto linkedIR = linkIR( + entryPoint, + programLayout, + target, + targetRequest); + auto irModule = linkedIR.module; + auto irEntryPoint = linkedIR.entryPoint; #if 0 dumpIRIfEnabled(compileRequest, irModule, "CLONED"); @@ -6708,9 +6707,22 @@ String emitEntryPoint( // we need to ensure that the code only uses types // that are legal on the chosen target. // - legalizeTypes( - &typeLegalizationContext, - irModule); + { + // TODO: The presence of `TypeLegalizationContext` + // in the public API of the `legalizeTypes` function + // is a throwback to when there was AST-level + // type legalization and all the complications it + // created. The pass should be refactored to not + // expose these details. + // + TypeLegalizationContext typeLegalizationContext; + initialize(&typeLegalizationContext, + session, + irModule); + legalizeTypes( + &typeLegalizationContext, + irModule); + } // Debugging output of legalization #if 0 @@ -6810,7 +6822,6 @@ String emitEntryPoint( // GlobalGenericParamSubstitution implementation may reference ir objects targetRequest->compileRequest->compiledModules.Add(irModule); } - destroyIRSpecializationState(irSpecializationState); // Deal with cases where a particular stage requires certain GLSL versions // and/or extensions. diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp index dba4fc2d1..35e0f46b8 100644 --- a/source/slang/ir-link.cpp +++ b/source/slang/ir-link.cpp @@ -659,13 +659,29 @@ void cloneFunctionCommon( } } +// We will forward-declare the subroutine for eagerly specializing +// an IR-level generic to argument values, because `specializeIRForEntryPoint` +// needs to perform this operation even though it is logically part of +// the later generic specialization pass. +// +IRInst* specializeGeneric( + IRSpecialize* specializeInst); + IRFunc* specializeIRForEntryPoint( IRSpecContext* context, EntryPointRequest* entryPointRequest, EntryPointLayout* entryPointLayout) { - // Look up the IR symbol by name - auto mangledName = getMangledName(entryPointRequest->decl); + // We start by looking up the IR symbol that + // matches the mangled name given to the + // function we want to emit. + // + // Note: the function decl-ref may refer to + // a specialization of a generic function, + // so that the mangled name of the decl-ref is + // not the same as the mangled name of the decl. + // + auto mangledName = getMangledName(entryPointRequest->getFuncDeclRef()); RefPtr<IRSpecSymbol> sym; if (!context->getSymbols().TryGetValue(mangledName, sym)) { @@ -674,40 +690,68 @@ IRFunc* specializeIRForEntryPoint( } // TODO: deal with the case where we might - // have multiple versions... + // have multiple (profile-overloaded) versions... + // + auto originalVal = sym->irGlobalValue; + + // We will start by cloning the entry point reference + // like any other global value. + // + auto clonedVal = cloneGlobalValue(context, originalVal); - auto globalValue = sym->irGlobalValue; - if (globalValue->op != kIROp_Func) + // In the case where the user is requesting a specialization + // of a generic entry point, we have a bit of a problem. + // + // This function is expected to return an `IRFunc` and + // subsequent passes expect to find, e.g., layout information + // attached to the parameters of such a func. + // + // In the generic case, the `clonedValue` won't be an + // `IRFunc`, but instead an `IRSpecialize`. + // + if(auto clonedSpec = as<IRSpecialize>(clonedVal)) { - SLANG_UNEXPECTED("expected an IR function"); + // The Right Thing to do here is to perform some + // amount of generic specialization, at least + // until we get back an `IRFunc`. + // + // The dangerous thing is that the generic specialization + // pass can, in principle, change the signature of + // functions, so that attaching parameter layout + // information *after* specialization might not work. + // + // The compromise we make here is to directly + // invoke the logic for specializing a generic. + // + // In theory this isn't valid, because there is no + // way we can register the specialized function we + // create so that it would be re-used by other instantiations + // with the same arguments (because we cannot be + // sure the generic arguments are themselves fully specialized) + // + // In practice this isn't really a problem, because + // we don't want to share the definition between + // an entry point and an ordinary function anyway. + // + clonedVal = specializeGeneric(clonedSpec); + } + + auto clonedFunc = as<IRFunc>(clonedVal); + if(!clonedFunc) + { + SLANG_UNEXPECTED("expected entry point to be a function"); return nullptr; } - auto originalFunc = (IRFunc*)globalValue; - - // Create a clone for the IR function - auto clonedFunc = context->builder->createFunc(); - - // Note: we do *not* register this cloned declaration - // as the cloned value for the original symbol. - // This is kind of a kludge, but it ensures that - // in the unlikely case that the function is both - // used as an entry point and a callable function - // (yes, this would imply recursion...) we actually - // have two copies, which lets us arbitrarily - // transform the entry point to meet target requirements. - // - // TODO: The above statement is kind of bunk, though, - // because both versions of the function would have - // the same mangled name... :( - // We need to clone all the properties of the original - // function, including any blocks, their parameters, - // and their instructions. - cloneFunctionCommon(context, clonedFunc, originalFunc); + if( !clonedFunc->findDecorationImpl(kIROp_EntryPointDecoration) ) + { + context->builder->addEntryPointDecoration(clonedFunc); + } // We need to attach the layout information for // the entry point to this declaration, so that // we can use it to inform downstream code emit. + // context->builder->addLayoutDecoration( clonedFunc, entryPointLayout); @@ -1166,13 +1210,14 @@ struct IRSpecializationState } }; -IRSpecializationState* createIRSpecializationState( +LinkedIR linkIR( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, CodeGenTarget target, TargetRequest* targetReq) { - IRSpecializationState* state = new IRSpecializationState(); + IRSpecializationState stateStorage; + auto state = &stateStorage; state->programLayout = programLayout; state->target = target; @@ -1232,6 +1277,8 @@ IRSpecializationState* createIRSpecializationState( context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } + context->builder->setInsertInto(context->getModule()->getModuleInst()); + // for now, clone all unreferenced witness tables // // TODO: This step should *not* be needed with the current IR @@ -1242,39 +1289,9 @@ IRSpecializationState* createIRSpecializationState( if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); } - return state; -} - -void destroyIRSpecializationState(IRSpecializationState* state) -{ - delete state; -} - -IRModule* getIRModule(IRSpecializationState* state) -{ - return state->irModule; -} - -IRFunc* specializeIRForEntryPoint( - IRSpecializationState* state, - EntryPointRequest* entryPointRequest) -{ - auto translationUnit = entryPointRequest->getTranslationUnit(); - auto originalIRModule = translationUnit->irModule; - if (!originalIRModule) - { - // We should already have emitted IR for the original - // translation unit, and it we don't have it, then - // we are now in trouble. - return nullptr; - } - - auto context = state->getContext(); - auto newProgramLayout = state->newProgramLayout; auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); - // Next, we make sure to clone the global value for // the entry point function itself, and rely on // this step to recursively copy over anything else @@ -1317,13 +1334,19 @@ IRFunc* specializeIRForEntryPoint( context->builder->addLayoutDecoration(clonedType, taggedUnionTypeLayout); } - - // TODO: *technically* we should consider the case where // we have global variables with initializers, since // these should get run whether or not the entry point // references them. - return irEntryPoint; + + // Now that we've cloned the entry point and everything + // it refers to, we can package up the data we return + // to the caller. + // + LinkedIR linkedIR; + linkedIR.module = state->irModule; + linkedIR.entryPoint = irEntryPoint; + return linkedIR; } diff --git a/source/slang/ir-link.h b/source/slang/ir-link.h index 18940f825..4fcdb4618 100644 --- a/source/slang/ir-link.h +++ b/source/slang/ir-link.h @@ -5,29 +5,22 @@ namespace Slang { - // Interface to IR specialization for use when cloning target-specific - // IR as part of compiling an entry point. + struct LinkedIR + { + RefPtr<IRModule> module; + IRFunc* entryPoint; + }; - // `IRSpecializationState` is used as an opaque type to wrap up all - // the data needed to perform IR specialization, without exposing - // implementation details. - struct IRSpecializationState; - IRSpecializationState* createIRSpecializationState( - EntryPointRequest* entryPointRequest, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetReq); - void destroyIRSpecializationState(IRSpecializationState* state); - IRModule* getIRModule(IRSpecializationState* state); - - struct ExtensionUsageTracker; // Clone the IR values reachable from the given entry point // into the IR module associated with the specialization state. // When multiple definitions of a symbol are found, the one // that is best specialized for the given `targetReq` will be // used. - IRFunc* specializeIRForEntryPoint( - IRSpecializationState* state, - EntryPointRequest* entryPointRequest); + // + LinkedIR linkIR( + EntryPointRequest* entryPointRequest, + ProgramLayout* programLayout, + CodeGenTarget target, + TargetRequest* targetReq); } diff --git a/source/slang/ir-specialize.cpp b/source/slang/ir-specialize.cpp index 144cc008f..0da06580f 100644 --- a/source/slang/ir-specialize.cpp +++ b/source/slang/ir-specialize.cpp @@ -29,6 +29,14 @@ namespace Slang // simplifications/specializations of one category can open // up opportunities for transformations in the other categories. +struct SpecializationContext; + +IRInst* specializeGenericImpl( + IRGeneric* genericVal, + IRSpecialize* specializeInst, + IRModule* module, + SpecializationContext* context); + struct SpecializationContext { // For convenience, we will keep a pointer to the module @@ -203,112 +211,26 @@ struct SpecializationContext // If no existing specialization is found, we need // to create the specialization instead. + // This mostly amounts to evaluating the generic as + // if it were a function being called. // - // Effectively this amounts to "calling" the generic - // on its concrete argument values and computing the - // result it returns. - // - // For now, all of our generics consist of a single - // basic block, so we can "call" them just by - // cloning the instructions in their single block - // into the global scope, using an environment for - // cloning that maps the generic parameters to - // the concrete arguments that were provided - // by the `specialize(...)` instruction. - // - IRCloneEnv env; - - // We will walk through the parameters of the generic and - // register the corresponding argument of the `specialize` - // instruction to be used as the "cloned" value for each - // parameter. - // - // Suppose we are looking at `specialize(g, a, b, c)` and `g` has - // three generic parameters: `T`, `U`, and `V`. Then we will - // be initializing our environment to map `T -> a`, `U -> b`, - // and `V -> c`. + // We will use a free function to do the actual work + // of evaluating the generic, so that the logic + // can be re-used in other cases that need to + // do one-off specialization. // - UInt argCounter = 0; - for( auto param : genericVal->getParams() ) - { - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < specializeInst->getArgCount()); - - IRInst* arg = specializeInst->getArg(argIndex); - - env.mapOldValToNew.Add(param, arg); - } + IRInst* specializedVal = specializeGenericImpl(genericVal, specializeInst, module, this); - // We will set up an IR builder for insertion - // into the global scope, at the same location - // as the original generic. - // - IRBuilder builderStorage; - IRBuilder* builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(genericVal); - // Now we will run through the body of the generic and - // clone each of its instructions into the global scope, - // until we reach a `return` instruction. + // The value that was returned from evaluating + // the generic is the specialized value, and we + // need to remember it in our dictionary of + // specializations so that we don't instantiate + // this generic again for the same arguments. // - for( auto bb : genericVal->getBlocks() ) - { - // We expect a generic to only ever contain a single block. - // - SLANG_ASSERT(bb == genericVal->getFirstBlock()); - - // We will iterate over the non-parameter ("ordinary") - // instructions only, because parameters were dealt - // with explictly at an earlier point. - // - for( auto ii : bb->getOrdinaryInsts() ) - { - // The last block of the generic is expected to end with - // a `return` instruction for the specialized value that - // comes out of the abstraction. - // - // We thus use that cloned value as the result of the - // specialization step. - // - if( auto returnValInst = as<IRReturnVal>(ii) ) - { - auto specializedVal = findCloneForOperand(&env, returnValInst->getVal()); - - // The value that was returned from evaluating - // the generic is the specialized value, and we - // need to remember it in our dictionary of - // specializations so that we don't instantiate - // this generic again for the same arguments. - // - genericSpecializations.Add(key, specializedVal); - - return specializedVal; - } - - // For any instruction other than a `return`, we will - // simply clone it completely into the global scope. - // - IRInst* clonedInst = cloneInst(&env, builder, ii); - - // Any new instructions we create during cloning were - // not present when we initially built our work list, - // so we need to make sure to consider them now. - // - // This is important for the cases where one generic - // invokes another, because there will be `specialize` - // operations nested inside the first generic that refer - // to the second. - // - addToWorkList(clonedInst); - } - } + genericSpecializations.Add(key, specializedVal); - // If we reach this point, something went wrong, because we - // never encountered a `return` inside the body of the generic. - // - SLANG_UNEXPECTED("no return from generic"); - UNREACHABLE_RETURN(nullptr); + return specializedVal; } // The logic for generating a specialization of an IR generic @@ -1302,4 +1224,131 @@ void specializeModule( context.processModule(); } + +IRInst* specializeGenericImpl( + IRGeneric* genericVal, + IRSpecialize* specializeInst, + IRModule* module, + SpecializationContext* context) +{ + // Effectively, specializing a generic amounts to "calling" the generic + // on its concrete argument values and computing the + // result it returns. + // + // For now, all of our generics consist of a single + // basic block, so we can "call" them just by + // cloning the instructions in their single block + // into the global scope, using an environment for + // cloning that maps the generic parameters to + // the concrete arguments that were provided + // by the `specialize(...)` instruction. + // + IRCloneEnv env; + + // We will walk through the parameters of the generic and + // register the corresponding argument of the `specialize` + // instruction to be used as the "cloned" value for each + // parameter. + // + // Suppose we are looking at `specialize(g, a, b, c)` and `g` has + // three generic parameters: `T`, `U`, and `V`. Then we will + // be initializing our environment to map `T -> a`, `U -> b`, + // and `V -> c`. + // + UInt argCounter = 0; + for( auto param : genericVal->getParams() ) + { + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < specializeInst->getArgCount()); + + IRInst* arg = specializeInst->getArg(argIndex); + + env.mapOldValToNew.Add(param, arg); + } + + // We will set up an IR builder for insertion + // into the global scope, at the same location + // as the original generic. + // + SharedIRBuilder sharedBuilderStorage; + sharedBuilderStorage.module = module; + sharedBuilderStorage.session = module->getSession(); + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = &sharedBuilderStorage; + builder->setInsertBefore(genericVal); + + // Now we will run through the body of the generic and + // clone each of its instructions into the global scope, + // until we reach a `return` instruction. + // + for( auto bb : genericVal->getBlocks() ) + { + // We expect a generic to only ever contain a single block. + // + SLANG_ASSERT(bb == genericVal->getFirstBlock()); + + // We will iterate over the non-parameter ("ordinary") + // instructions only, because parameters were dealt + // with explictly at an earlier point. + // + for( auto ii : bb->getOrdinaryInsts() ) + { + // The last block of the generic is expected to end with + // a `return` instruction for the specialized value that + // comes out of the abstraction. + // + // We thus use that cloned value as the result of the + // specialization step. + // + if( auto returnValInst = as<IRReturnVal>(ii) ) + { + auto specializedVal = findCloneForOperand(&env, returnValInst->getVal()); + return specializedVal; + } + + // For any instruction other than a `return`, we will + // simply clone it completely into the global scope. + // + IRInst* clonedInst = cloneInst(&env, builder, ii); + + // Any new instructions we create during cloning were + // not present when we initially built our work list, + // so we need to make sure to consider them now. + // + // This is important for the cases where one generic + // invokes another, because there will be `specialize` + // operations nested inside the first generic that refer + // to the second. + // + if( context ) + { + context->addToWorkList(clonedInst); + } + } + } + + // If we reach this point, something went wrong, because we + // never encountered a `return` inside the body of the generic. + // + SLANG_UNEXPECTED("no return from generic"); + UNREACHABLE_RETURN(nullptr); +} + +IRInst* specializeGeneric( + IRSpecialize* specializeInst) +{ + auto baseGeneric = as<IRGeneric>(specializeInst->getBase()); + SLANG_ASSERT(baseGeneric); + if(!baseGeneric) return specializeInst; + + auto module = specializeInst->getModule(); + SLANG_ASSERT(module); + if(!module) return specializeInst; + + return specializeGenericImpl(baseGeneric, specializeInst, module, nullptr); +} + + } // namespace Slang diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 3a5e0cd2e..665904969 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -6113,25 +6113,31 @@ static void lowerEntryPointToIR( EntryPointRequest* entryPointRequest) { // First, lower the entry point like an ordinary function - auto entryPointFuncDecl = entryPointRequest->decl; - if (!entryPointFuncDecl) - { - // Something must have gone wrong earlier, if we - // weren't able to associate a declaration with - // the entry point request. - return; - } - auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl); + + auto session = context->getSession(); + auto entryPointFuncDeclRef = entryPointRequest->getFuncDeclRef(); + auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); + + auto builder = context->irBuilder; + builder->setInsertInto(builder->getModule()->getModuleInst()); + + auto loweredEntryPointFunc = getSimpleVal(context, + emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); // Attach a marker decoration so that we recognize // this as an entry point. - auto builder = context->irBuilder; - builder->addEntryPointDecoration(getSimpleVal(context, loweredEntryPointFunc)); + // + builder->addEntryPointDecoration(loweredEntryPointFunc); + + // + if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>()) + { + builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); + } // Now lower all the arguments supplied for global generic // type parameters. // - builder->setInsertInto(builder->getModule()->getModuleInst()); for (RefPtr<Substitutions> subst = entryPointRequest->globalGenericSubst; subst; subst = subst->outer) { auto gSubst = subst.as<GlobalGenericParamSubstitution>(); diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 2fb571db2..583d4fe54 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -1790,7 +1790,7 @@ static void collectGlobalScopeParameters( // First enumerate parameters at global scope // We collect two things here: // 1. A shader parameter, which is always a variable - // 2. A global entry-point generic parameter type (`__generic_param`), + // 2. A global entry-point generic parameter type (`type_param`), // which is a GlobalGenericParamDecl // We collect global generic type parameters in the first pass, // So we can fill in the correct index into ordinary type layouts @@ -2255,13 +2255,13 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout( ParameterBindingContext* context, SubstitutionSet typeSubst, - RefPtr<ParamDecl> paramDecl, + DeclRef<ParamDecl> paramDeclRef, RefPtr<VarLayout> paramVarLayout, EntryPointParameterState& state) { - auto paramType = paramDecl->type.type->Substitute(typeSubst).as<Type>(); + auto paramType = GetType(paramDeclRef)->Substitute(typeSubst).as<Type>(); - if( paramDecl->HasModifier<HLSLUniformModifier>() ) + if( paramDeclRef.getDecl()->HasModifier<HLSLUniformModifier>() ) { // An entry-point parameter that is explicitly marked `uniform` represents // a uniform shader parameter passed via the implicitly-defined @@ -2283,21 +2283,24 @@ static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout( state.directionMask = 0; // If it appears to be an input, process it as such. - if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() ) + if( paramDeclRef.getDecl()->HasModifier<InModifier>() + || paramDeclRef.getDecl()->HasModifier<InOutModifier>() + || !paramDeclRef.getDecl()->HasModifier<OutModifier>() ) { state.directionMask |= kEntryPointParameterDirection_Input; } // If it appears to be an output, process it as such. - if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>()) + if(paramDeclRef.getDecl()->HasModifier<OutModifier>() + || paramDeclRef.getDecl()->HasModifier<InOutModifier>()) { state.directionMask |= kEntryPointParameterDirection_Output; } return processEntryPointVaryingParameterDecl( context, - paramDecl.Ptr(), - paramDecl->type.type->Substitute(typeSubst).as<Type>(), + paramDeclRef.getDecl(), + paramType, state, paramVarLayout); } @@ -2460,22 +2463,14 @@ static void collectEntryPointParameters( EntryPointRequest* entryPoint, SubstitutionSet typeSubst) { - FuncDecl* entryPointFuncDecl = entryPoint->decl; - if (!entryPointFuncDecl) - { - // Something must have failed earlier, so that - // we didn't find a declaration to match this - // entry point request. - // - return; - } + DeclRef<FuncDecl> entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); // We will take responsibility for creating and filling in // the `EntryPointLayout` object here. // RefPtr<EntryPointLayout> entryPointLayout = new EntryPointLayout(); entryPointLayout->profile = entryPoint->profile; - entryPointLayout->entryPoint = entryPointFuncDecl; + entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl(); // The entry point layout must be added to the output // program layout so that it can be accessed by reflection. @@ -2522,12 +2517,12 @@ static void collectEntryPointParameters( scopeBuilder.beginLayout(context); auto paramsStructLayout = scopeBuilder.m_structLayout; - for( auto paramDecl : entryPointFuncDecl->getMembersOfType<ParamDecl>() ) + for( auto paramDeclRef : getMembersOfType<ParamDecl>(entryPointFuncDeclRef) ) { // Any error messages we emit during the process should // refer to the location of this parameter. // - state.loc = paramDecl->loc; + state.loc = paramDeclRef.getLoc(); // We are going to construct the variable layout for this // parameter *before* computing the type layout, because @@ -2536,13 +2531,13 @@ static void collectEntryPointParameters( // back onto the `VarLayout`. // RefPtr<VarLayout> paramVarLayout = new VarLayout(); - paramVarLayout->varDecl = makeDeclRef(paramDecl.Ptr()); + paramVarLayout->varDecl = paramDeclRef; paramVarLayout->stage = state.stage; auto paramTypeLayout = computeEntryPointParameterTypeLayout( context, typeSubst, - paramDecl, + paramDeclRef, paramVarLayout, state); paramVarLayout->typeLayout = paramTypeLayout; @@ -2598,10 +2593,10 @@ static void collectEntryPointParameters( // TODO: Ideally we should make the layout process more robust to empty/void // types and apply this logic unconditionally. // - auto resultType = entryPointFuncDecl->ReturnType.type; + auto resultType = GetResultType(entryPointFuncDeclRef)->Substitute(typeSubst).as<Type>(); if( !resultType->Equals(resultType->getSession()->getVoidType()) ) { - state.loc = entryPointFuncDecl->loc; + state.loc = entryPointFuncDeclRef.getLoc(); state.directionMask = kEntryPointParameterDirection_Output; RefPtr<VarLayout> resultLayout = new VarLayout(); @@ -2609,7 +2604,7 @@ static void collectEntryPointParameters( auto resultTypeLayout = processEntryPointVaryingParameterDecl( context, - entryPointFuncDecl, + entryPointFuncDeclRef.getDecl(), resultType->Substitute(typeSubst).as<Type>(), state, resultLayout); diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 076111711..b9e3786a2 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -4781,8 +4781,7 @@ namespace Slang addBuiltinSyntax<Decl>(session, scope, #KEYWORD, &CALLBACK) DECL(typedef, ParseTypeDef); DECL(associatedtype, parseAssocType); - DECL(__generic_param, parseGlobalGenericParamDecl); - DECL(type_param, parseGlobalGenericParamDecl); + DECL(type_param, parseGlobalGenericParamDecl); DECL(cbuffer, parseHLSLCBufferDecl); DECL(tbuffer, parseHLSLTBufferDecl); DECL(__generic, ParseGenericDecl); diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h index 93fe687fd..b1faf6776 100644 --- a/source/slang/syntax-base-defs.h +++ b/source/slang/syntax-base-defs.h @@ -178,7 +178,7 @@ SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) END_SYNTAX_CLASS() SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions) - // the __generic_param decl to be substituted + // the type_param decl to be substituted DECL_FIELD(GlobalGenericParamDecl*, paramDecl) // the actual type to substitute in diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 6b58809a6..0af318c9e 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1427,7 +1427,7 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<Substitutions> GlobalGenericParamSubstitution::applySubstitutionsShallow(SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) { - // if we find a GlobalGenericParamSubstitution in subst that references the same __generic_param decl + // if we find a GlobalGenericParamSubstitution in subst that references the same type_param decl // return a copy of that GlobalGenericParamSubstitution int diff = 0; @@ -1884,6 +1884,11 @@ void Type::accept(IValVisitor* visitor, void* extra) return decl->nameAndLoc.name; } + SourceLoc DeclRefBase::getLoc() const + { + return decl->loc; + } + DeclRefBase DeclRefBase::GetParent() const { // Want access to the free function (the 'as' method by default gets priority) diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 2151ed764..344e94ff9 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -482,6 +482,7 @@ namespace Slang // Convenience accessors for common properties of declarations Name* GetName() const; + SourceLoc getLoc() const; DeclRefBase GetParent() const; int GetHashCode() const; diff --git a/tests/bugs/gh-357.slang b/tests/bugs/gh-357.slang index be2ba95ed..043eebf17 100644 --- a/tests/bugs/gh-357.slang +++ b/tests/bugs/gh-357.slang @@ -25,11 +25,9 @@ struct AssocImpl : IAssoc typedef BaseImpl TBase; }; -__generic_param T : IAssoc; - - [numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain<T:IAssoc>( + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; T.TBase base; diff --git a/tests/compute/assoctype-generic-arg.slang b/tests/compute/assoctype-generic-arg.slang index 4bc77c925..dd183ea5d 100644 --- a/tests/compute/assoctype-generic-arg.slang +++ b/tests/compute/assoctype-generic-arg.slang @@ -25,11 +25,9 @@ struct AssocImpl : IAssoc typedef BaseImpl TBase; }; -__generic_param T : IAssoc; - - [numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain<T : IAssoc>( + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; T.TBase base; diff --git a/tests/compute/global-type-param-array.slang b/tests/compute/global-type-param-array.slang index 87236d8f6..d801efe2c 100644 --- a/tests/compute/global-type-param-array.slang +++ b/tests/compute/global-type-param-array.slang @@ -6,17 +6,16 @@ RWStructuredBuffer<float> outputBuffer; import globalTypeParamArrayShared; -__generic_param TImpl : IBase; - -ParameterBlock<TImpl> impl; - float doCompute<T:IBase>(T t) { return t.compute(1.0); } [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain< + TImpl : IBase>( + uniform ParameterBlock<TImpl> impl, + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; float outVal = doCompute<TImpl>(impl); diff --git a/tests/compute/global-type-param-in-entrypoint.slang b/tests/compute/global-type-param-in-entrypoint.slang index 4bcf4cbca..9a1e9b054 100644 --- a/tests/compute/global-type-param-in-entrypoint.slang +++ b/tests/compute/global-type-param-in-entrypoint.slang @@ -8,7 +8,7 @@ interface IVertInterpolant float4 getColor(); } -__generic_param TVertInterpolant : IVertInterpolant; +type_param TVertInterpolant : IVertInterpolant; struct VertImpl : IVertInterpolant { diff --git a/tests/compute/global-type-param.slang b/tests/compute/global-type-param.slang index f177dcb1d..7621f8961 100644 --- a/tests/compute/global-type-param.slang +++ b/tests/compute/global-type-param.slang @@ -26,10 +26,8 @@ struct Impl : IBase } }; -__generic_param TImpl : IBase; - [numthreads(1, 1, 1)] -void computeMain( +void computeMain<TImpl:IBase>( uniform TImpl impl, uint3 dispatchThreadID : SV_DispatchThreadID) { diff --git a/tests/compute/global-type-param1.slang b/tests/compute/global-type-param1.slang index e16ffa9da..f33be8ec7 100644 --- a/tests/compute/global-type-param1.slang +++ b/tests/compute/global-type-param1.slang @@ -26,10 +26,6 @@ struct Impl : IBase } }; -__generic_param TImpl : IBase; - -ParameterBlock<TImpl> impl; - cbuffer C { float base0; // = 0.5 @@ -39,7 +35,10 @@ Texture2D tex1; // = 0.0 SamplerState sampler; [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain< + TImpl : IBase>( + uniform ParameterBlock<TImpl> impl, + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; float b0 = tex1.SampleLevel(sampler, float2(0.0), 0.0).x + base0; // = 0.5 diff --git a/tests/compute/global-type-param2.slang b/tests/compute/global-type-param2.slang index f29d01407..976a31df8 100644 --- a/tests/compute/global-type-param2.slang +++ b/tests/compute/global-type-param2.slang @@ -38,10 +38,6 @@ struct Impl : IBase } }; -__generic_param TImpl : IBase; - -ParameterBlock<TImpl> impl; - // at binding c0: cbuffer existingBuffer { @@ -51,7 +47,10 @@ Texture2D tex1; // = 0.0 SamplerState sampler; [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain< + TImpl : IBase>( + uniform ParameterBlock<TImpl> impl, + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; float b0 = tex1.SampleLevel(sampler, float2(0.0), 0.0).x + base0; // = 0.5 diff --git a/tests/compute/int-generic.slang b/tests/compute/int-generic.slang index 6bb63df8c..d9eb85f82 100644 --- a/tests/compute/int-generic.slang +++ b/tests/compute/int-generic.slang @@ -29,14 +29,12 @@ struct Material<let A:int, let B: int> : IMaterial TBRDF getBRDF() { TBRDF a; a.c = 0; return a; } }; -type_param TMaterial : IMaterial; - -TMaterial material; - [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain<M : IMaterial>( + uniform M material, + uint3 dispatchThreadID : SV_DispatchThreadID) { - TMaterial.TBRDF brdf = material.getBRDF(); + M.TBRDF brdf = material.getBRDF(); int outVal = brdf.compute(); outputBuffer[dispatchThreadID.x] = outVal; }
\ No newline at end of file diff --git a/tests/compute/tagged-union.slang b/tests/compute/tagged-union.slang index 5089ec5a7..de69232f9 100644 --- a/tests/compute/tagged-union.slang +++ b/tests/compute/tagged-union.slang @@ -44,11 +44,13 @@ struct B : IFrobnicator } } +[numthreads(4, 1, 1)] +void computeMain // Now we will define the generic type parameter for our shader, // which will be constraints to be a type that implements our // `IFrobnicator` interface. // -type_param T : IFrobnicator; + <T : IFrobnicator> // // For the actual test runner, we will instruct it to plug in // a tagged-union type over the two concrete implemetnations. @@ -57,7 +59,7 @@ type_param T : IFrobnicator; // our intention when it is informed via the API. // //TEST_INPUT: type __TaggedUnion(A,B) - + ( // Next we need to pass in the actual parameter data for our // chosen `IFrobnicator` implementation. The decalration of // the constant buffer follows the conventional approach for @@ -68,7 +70,7 @@ type_param T : IFrobnicator; // the `render-test` tool doesn't yet support code that // uses multiple descriptor tables/sets. // -ConstantBuffer<T> gFrobnicator; + uniform ConstantBuffer<T> gFrobnicator, // Where things get interesting is when we go to provide the // data that will be used by the parameter block. @@ -92,17 +94,15 @@ ConstantBuffer<T> gFrobnicator; // //TEST_INPUT: cbuffer(data=[16 9 1 0 0], stride=4):dxbinding(0),glbinding(0) -int test(int val) -{ - return gFrobnicator.frobnicate(val); -} //TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(1),out -RWStructuredBuffer<int> gOutputBuffer; - -[numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) + uniform RWStructuredBuffer<int> gOutputBuffer, + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; - gOutputBuffer[tid] = test(tid); + + int val = tid; + val = gFrobnicator.frobnicate(val); + + gOutputBuffer[tid] = val; } diff --git a/tests/reflection/global-type-params.slang b/tests/reflection/global-type-params.slang index 74961b7cc..290e6353a 100644 --- a/tests/reflection/global-type-params.slang +++ b/tests/reflection/global-type-params.slang @@ -7,8 +7,8 @@ interface IBase {}; -__generic_param TParam : IBase; -__generic_param TParam2 : IBase; +type_param TParam : IBase; +type_param TParam2 : IBase; struct S { |
