diff options
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 86 | ||||
| -rw-r--r-- | source/slang/slang-compiler.cpp | 31 | ||||
| -rw-r--r-- | source/slang/slang-compiler.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-decl-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-parser.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 3 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 6 | ||||
| -rw-r--r-- | tests/compute/global-generic-value-param.slang | 59 | ||||
| -rw-r--r-- | tests/compute/global-generic-value-param.slang.expected.txt | 16 | ||||
| -rw-r--r-- | tools/gfx/render.h | 6 | ||||
| -rw-r--r-- | tools/render-test/shader-input-layout.cpp | 34 | ||||
| -rw-r--r-- | tools/render-test/shader-input-layout.h | 6 | ||||
| -rw-r--r-- | tools/render-test/slang-support.cpp | 49 |
20 files changed, 294 insertions, 105 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 879f1c5fa..3e0a1c618 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -49,6 +49,11 @@ namespace Slang checkVarDeclCommon(varDecl); } + void visitGlobalGenericValueParamDecl(GlobalGenericValueParamDecl* decl) + { + checkVarDeclCommon(decl); + } + void visitImportDecl(ImportDecl* decl); void visitGenericTypeParamDecl(GenericTypeParamDecl* decl); @@ -114,6 +119,11 @@ namespace Slang checkVarDeclCommon(varDecl); } + void visitGlobalGenericValueParamDecl(GlobalGenericValueParamDecl* decl) + { + checkVarDeclCommon(decl); + } + void visitEnumCaseDecl(EnumCaseDecl* decl); void visitEnumDecl(EnumDecl* decl); diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 333690b36..c229e8f96 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -52,7 +52,8 @@ namespace Slang /// Recursively walk `type` and add any existential/interface specialization parameters to `ioSpecializationParams`. static void _collectExistentialSpecializationParamsRec( SpecializationParams& ioSpecializationParams, - Type* type) + Type* type, + SourceLoc loc) { // Whether or not something is an array does not affect // the number of existential slots it introduces. @@ -66,7 +67,8 @@ namespace Slang { _collectExistentialSpecializationParamsRec( ioSpecializationParams, - parameterGroupType->getElementType()); + parameterGroupType->getElementType(), + loc); return; } @@ -81,6 +83,7 @@ namespace Slang // SpecializationParam specializationParam; specializationParam.flavor = SpecializationParam::Flavor::ExistentialType; + specializationParam.loc = loc; specializationParam.object = type; ioSpecializationParams.add(specializationParam); } @@ -112,7 +115,8 @@ namespace Slang { _collectExistentialSpecializationParamsRec( ioSpecializationParams, - GetType(paramDeclRef)); + GetType(paramDeclRef), + paramDeclRef.getLoc()); } @@ -147,6 +151,7 @@ namespace Slang { SpecializationParam param; param.flavor = SpecializationParam::Flavor::GenericType; + param.loc = genericTypeParam->loc; param.object = genericTypeParam; m_genericSpecializationParams.add(param); } @@ -154,6 +159,7 @@ namespace Slang { SpecializationParam param; param.flavor = SpecializationParam::Flavor::GenericValue; + param.loc = genericValParam->loc; param.object = genericValParam; m_genericSpecializationParams.add(param); } @@ -1020,9 +1026,21 @@ static bool doesParameterMatch( // SpecializationParam specializationParam; specializationParam.flavor = SpecializationParam::Flavor::GenericType; + specializationParam.loc = globalGenericParam->loc; specializationParam.object = globalGenericParam; m_specializationParams.add(specializationParam); } + else if( auto globalGenericValueParam = as<GlobalGenericValueParamDecl>(globalDecl) ) + { + // A global generic type parameter declaration introduces + // a suitable specialization parameter. + // + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::GenericValue; + specializationParam.loc = globalGenericValueParam->loc; + specializationParam.object = globalGenericValueParam; + m_specializationParams.add(specializationParam); + } else if( auto importDecl = as<ImportDecl>(globalDecl) ) { // An `import` declaration creates a requirement dependency @@ -1425,9 +1443,6 @@ static bool doesParameterMatch( auto& arg = args[ii]; auto& param = m_specializationParams[ii]; - auto argType = arg.val.as<Type>(); - SLANG_ASSERT(argType); - switch( param.flavor ) { case SpecializationParam::Flavor::GenericType: @@ -1435,6 +1450,13 @@ static bool doesParameterMatch( auto genericTypeParamDecl = param.object.as<GlobalGenericParamDecl>(); SLANG_ASSERT(genericTypeParamDecl); + RefPtr<Type> argType = as<Type>(arg.val); + if(!argType) + { + sink->diagnose(param.loc, Diagnostics::expectedTypeForSpecializationArg, genericTypeParamDecl); + argType = getLinkage()->getSessionImpl()->getErrorType(); + } + // TODO: There is a serious flaw to this checking logic if we ever have cases where // the constraints on one `type_param` can depend on another `type_param`, e.g.: // @@ -1520,6 +1542,13 @@ static bool doesParameterMatch( auto interfaceType = param.object.as<Type>(); SLANG_ASSERT(interfaceType); + RefPtr<Type> argType = as<Type>(arg.val); + if(!argType) + { + sink->diagnose(param.loc, Diagnostics::expectedTypeForSpecializationArg, interfaceType); + argType = getLinkage()->getSessionImpl()->getErrorType(); + } + auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType); if (!witness) { @@ -1539,6 +1568,29 @@ static bool doesParameterMatch( } break; + case SpecializationParam::Flavor::GenericValue: + { + auto paramDecl = param.object.as<GlobalGenericValueParamDecl>(); + SLANG_ASSERT(paramDecl); + + // Now we need to check that the argument `Val` has the + // appropriate type expected by the parameter. + + RefPtr<IntVal> intVal = as<IntVal>(arg.val); + if(!intVal) + { + sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl); + intVal = new ConstantIntVal(0); + } + + ModuleSpecializationInfo::GenericArgInfo expandedArg; + expandedArg.paramDecl = paramDecl; + expandedArg.argVal = intVal; + + specializationInfo->genericArgs.add(expandedArg); + } + break; + default: SLANG_UNEXPECTED("unhandled specialization parameter flavor"); } @@ -1556,27 +1608,17 @@ static bool doesParameterMatch( { auto linkage = componentType->getLinkage(); + SharedSemanticsContext semanticsContext(linkage, sink); + SemanticsVisitor semanticsVisitor(&semanticsContext); + auto argCount = argExprs.getCount(); for(Index ii = 0; ii < argCount; ++ii ) { auto argExpr = argExprs[ii]; auto paramInfo = componentType->getSpecializationParam(ii); - // TODO: We should support non-type arguments here - - auto argType = checkProperType(linkage, TypeExp(argExpr), sink); - if( !argType ) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(argExpr, - Diagnostics::expectedAType, - argExpr->type); - continue; - } - SpecializationArg arg; - arg.val = argType; + arg.val = semanticsVisitor.ExtractGenericArgVal(argExpr); outArgs.add(arg); } } @@ -1757,7 +1799,7 @@ static bool doesParameterMatch( RefPtr<Expr> argExpr; for (auto & s : scopesToTry) { - argExpr = linkage->parseTypeString(name, s); + argExpr = linkage->parseTermString(name, s); argExpr = semantics.CheckTerm(argExpr); if( argExpr ) { @@ -1790,7 +1832,7 @@ static bool doesParameterMatch( SemanticsVisitor visitor(&sharedSemanticsContext); SpecializationParams specializationParams; - _collectExistentialSpecializationParamsRec(specializationParams, unspecializedType); + _collectExistentialSpecializationParamsRec(specializationParams, unspecializedType, SourceLoc()); assert(specializationParams.getCount() == argCount); diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 2ab376181..3cfb845d7 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2382,6 +2382,37 @@ SlangResult dissassembleDXILUsingDXC( BackEndCompileRequest* compileRequest, EndToEndCompileRequest* endToEndReq) { + // If we are about to generate output code, but we still + // have unspecialized generic/existential parameters, + // then there is a problem. + // + auto program = compileRequest->getProgram(); + auto specializationParamCount = program->getSpecializationParamCount(); + if( specializationParamCount != 0 ) + { + auto sink = compileRequest->getSink(); + + for( Index ii = 0; ii < specializationParamCount; ++ii ) + { + auto specializationParam = program->getSpecializationParam(ii); + if( auto decl = as<Decl>(specializationParam.object) ) + { + sink->diagnose(specializationParam.loc, Diagnostics::specializationParameterOfNameNotSpecialized, decl); + } + else if( auto type = as<Type>(specializationParam.object) ) + { + sink->diagnose(specializationParam.loc, Diagnostics::specializationParameterOfNameNotSpecialized, type); + } + else + { + sink->diagnose(specializationParam.loc, Diagnostics::specializationParameterNotSpecialized); + } + } + + return; + } + + // Go through the code-generation targets that the user // has specified, and generate code for each of them. // diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 69513ada6..a628e86a3 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1206,7 +1206,7 @@ namespace Slang /// SlangResult loadFile(String const& path, PathInfo& outPathInfo, ISlangBlob** outBlob); - RefPtr<Expr> parseTypeString(String typeStr, RefPtr<Scope> scope); + RefPtr<Expr> parseTermString(String str, RefPtr<Scope> scope); Type* specializeType( Type* unspecializedType, diff --git a/source/slang/slang-decl-defs.h b/source/slang/slang-decl-defs.h index 04c733aac..0e3159910 100644 --- a/source/slang/slang-decl-defs.h +++ b/source/slang/slang-decl-defs.h @@ -180,6 +180,11 @@ END_SYNTAX_CLASS() SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl) END_SYNTAX_CLASS() +// A `__generic_value_param` declaration, which defines an existential +// value parameter (not a type parameter. +SYNTAX_CLASS(GlobalGenericValueParamDecl, VarDeclBase) +END_SYNTAX_CLASS() + // A scope for local declarations (e.g., as part of a statement) SIMPLE_SYNTAX_CLASS(ScopeDecl, ContainerDecl) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 485f463e8..4e48cc95b 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -363,11 +363,16 @@ DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entr DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function") DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'") -DIAGNOSTIC(38005, Error, globalGenericArgumentNotAType, "argument for global generic parameter '$0' must be a type") +DIAGNOSTIC(38005, Error, expectedTypeForSpecializationArg, "expected a type as argument for specialization parameter '$0'") DIAGNOSTIC(38006, Warning, specifiedStageDoesntMatchAttribute, "entry point '$0' being compiled for the '$1' stage has a '[shader(...)]' attribute that specifies the '$2' stage") DIAGNOSTIC(38007, Error, entryPointHasNoStage, "no stage specified for entry point '$0'; use either a '[shader(\"name\")]' function attribute or the '-stage <name>' command-line option to specify a stage") +DIAGNOSTIC(38008, Error, specializationParameterOfNameNotSpecialized, "no specialization argument was provided for specialization parameter '$0'") +DIAGNOSTIC(38008, Error, specializationParameterNotSpecialized, "no specialization argument was provided for specialization parameter") + +DIAGNOSTIC(38009, Error, expectedValueOfTypeForSpecializationArg, "expected a constant value of type '$0' as argument for specialization parameter '$1'") + DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'") DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type") DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration") @@ -388,7 +393,6 @@ DIAGNOSTIC(38025, Error, mismatchSpecializationArguments, "expected $0 specializ DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") -DIAGNOSTIC(38028, Error, existentialSlotArgNotAType, "existential slot argument $0 was not a type") DIAGNOSTIC(38029, Error,typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index b6a3c3e93..79f8e2b80 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1926,7 +1926,18 @@ struct IRBuilder UInt caseArgCount, IRInst* const* caseArgs); - IRGlobalGenericParam* emitGlobalGenericParam(); + IRGlobalGenericParam* emitGlobalGenericParam( + IRType* type); + + IRGlobalGenericParam* emitGlobalGenericTypeParam() + { + return emitGlobalGenericParam(getTypeKind()); + } + + IRGlobalGenericParam* emitGlobalGenericWitnessTableParam() + { + return emitGlobalGenericParam(getWitnessTableType()); + } IRBindGlobalGenericParam* emitBindGlobalGenericParam( IRInst* param, diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 80b0cd39e..59cc625f5 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -552,7 +552,7 @@ IRGlobalGenericParam* cloneGlobalGenericParamImpl( IRGlobalGenericParam* originalVal, IROriginalValuesForClone const& originalValues) { - auto clonedVal = builder->emitGlobalGenericParam(); + auto clonedVal = builder->emitGlobalGenericParam(originalVal->getFullType()); cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); return clonedVal; } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 31224fde2..a30445682 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3328,12 +3328,13 @@ namespace Slang return inst; } - IRGlobalGenericParam* IRBuilder::emitGlobalGenericParam() + IRGlobalGenericParam* IRBuilder::emitGlobalGenericParam( + IRType* type) { IRGlobalGenericParam* irGenericParam = createInst<IRGlobalGenericParam>( this, kIROp_GlobalGenericParam, - nullptr); + type); addGlobalValue(this, irGenericParam); return irGenericParam; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e2689ffd1..ac02e1dfd 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4211,7 +4211,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // This is a constraint on a global generic type parameters, // and so it should lower as a parameter of its own. - auto inst = getBuilder()->emitGlobalGenericParam(); + auto inst = getBuilder()->emitGlobalGenericWitnessTableParam(); addLinkageDecoration(context, inst, decl); return LoweredValInfo::simple(inst); } @@ -4227,11 +4227,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> LoweredValInfo visitGlobalGenericParamDecl(GlobalGenericParamDecl* decl) { - auto inst = getBuilder()->emitGlobalGenericParam(); + auto inst = getBuilder()->emitGlobalGenericTypeParam(); addLinkageDecoration(context, inst, decl); return LoweredValInfo::simple(inst); } + LoweredValInfo visitGlobalGenericValueParamDecl(GlobalGenericValueParamDecl* decl) + { + auto builder = getBuilder(); + auto type = lowerType(context, decl->type); + auto inst = builder->emitGlobalGenericParam(type); + addLinkageDecoration(context, inst, decl); + return LoweredValInfo::simple(inst); + } + + void lowerWitnessTable( IRGenContext* subContext, WitnessTable* astWitnessTable, diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index dad84a4b6..9be2e49e5 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1788,6 +1788,11 @@ namespace Slang return taggedUnionType; } + static RefPtr<RefObject> parseTaggedUnionType(Parser* parser, void* /*unused*/) + { + return parseTaggedUnionType(parser); + } + static TypeSpec parseTypeSpec(Parser* parser) { TypeSpec typeSpec; @@ -1831,6 +1836,12 @@ namespace Slang typeSpec.expr = createDeclRefType(parser, decl); return typeSpec; } + // TODO: This case would not be needed if we had the + // code below dispatch into `parseAtomicExpr`, which + // already includes logic for keyword lookup. + // + // Leaving this case here for now to avoid breaking anything. + // else if(AdvanceIf(parser, "__TaggedUnion")) { typeSpec.expr = parseTaggedUnionType(parser); @@ -2386,7 +2397,7 @@ namespace Slang return assocTypeDecl; } - RefPtr<RefObject> parseGlobalGenericParamDecl(Parser * parser, void *) + RefPtr<RefObject> parseGlobalGenericTypeParamDecl(Parser * parser, void *) { RefPtr<GlobalGenericParamDecl> genParamDecl = new GlobalGenericParamDecl(); auto nameToken = parser->ReadToken(TokenType::Identifier); @@ -2397,6 +2408,27 @@ namespace Slang return genParamDecl; } + RefPtr<RefObject> parseGlobalGenericValueParamDecl(Parser * parser, void *) + { + RefPtr<GlobalGenericValueParamDecl> genericParamDecl = new GlobalGenericValueParamDecl(); + auto nameToken = parser->ReadToken(TokenType::Identifier); + genericParamDecl->nameAndLoc = NameLoc(nameToken); + genericParamDecl->loc = nameToken.loc; + + if(AdvanceIf(parser, TokenType::Colon)) + { + genericParamDecl->type = parser->ParseTypeExp(); + } + + if(AdvanceIf(parser, TokenType::OpAssign)) + { + genericParamDecl->initExpr = parser->ParseInitExpr(); + } + + parser->ReadToken(TokenType::Semicolon); + return genericParamDecl; + } + static RefPtr<RefObject> parseInterfaceDecl(Parser* parser, void* /*userData*/) { RefPtr<InterfaceDecl> decl = new InterfaceDecl(); @@ -4270,7 +4302,7 @@ namespace Slang return parsePrefixExpr(this); } - RefPtr<Expr> parseTypeFromSourceFile( + RefPtr<Expr> parseTermFromSourceFile( Session* session, TokenSpan const& tokens, DiagnosticSink* sink, @@ -4282,7 +4314,7 @@ namespace Slang parser.currentScope = outerScope; parser.namePool = namePool; parser.sourceLanguage = sourceLanguage; - return parser.ParseType(); + return parser.ParseExpression(); } // Parse a source file into an existing translation unit @@ -4649,7 +4681,7 @@ namespace Slang addBuiltinSyntax<Decl>(session, scope, #KEYWORD, &CALLBACK) DECL(typedef, ParseTypeDef); DECL(associatedtype, parseAssocType); - DECL(type_param, parseGlobalGenericParamDecl); + DECL(type_param, parseGlobalGenericTypeParamDecl); DECL(cbuffer, parseHLSLCBufferDecl); DECL(tbuffer, parseHLSLTBufferDecl); DECL(__generic, ParseGenericDecl); @@ -4666,6 +4698,7 @@ namespace Slang DECL(var, parseVarDecl); DECL(func, parseFuncDecl); DECL(typealias, parseTypeAliasDecl); + DECL(__generic_value_param, parseGlobalGenericValueParamDecl); #undef DECL @@ -4753,6 +4786,7 @@ namespace Slang EXPR(this, parseThisExpr); EXPR(true, parseTrueExpr); EXPR(false, parseFalseExpr); + EXPR(__TaggedUnion, parseTaggedUnionType); #undef EXPR diff --git a/source/slang/slang-parser.h b/source/slang/slang-parser.h index 98fd9ed65..1c21b9474 100644 --- a/source/slang/slang-parser.h +++ b/source/slang/slang-parser.h @@ -14,7 +14,7 @@ namespace Slang DiagnosticSink* sink, RefPtr<Scope> const& outerScope); - RefPtr<Expr> parseTypeFromSourceFile( + RefPtr<Expr> parseTermFromSourceFile( Session* session, TokenSpan const& tokens, DiagnosticSink* sink, diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 8c07855e4..0dd70b1f8 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -1241,7 +1241,8 @@ namespace Slang ExistentialValue, }; Flavor flavor; - RefPtr<RefObject> object; + SourceLoc loc; + RefPtr<NodeBase> object; }; typedef List<SpecializationParam> SpecializationParams; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 6632f2fa3..98ef7400e 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -813,7 +813,7 @@ SlangResult Linkage::loadFile(String const& path, PathInfo& outPathInfo, ISlangB return SLANG_OK; } -RefPtr<Expr> Linkage::parseTypeString(String typeStr, RefPtr<Scope> scope) +RefPtr<Expr> Linkage::parseTermString(String typeStr, RefPtr<Scope> scope) { // Create a SourceManager on the stack, so any allocations for 'SourceFile'/'SourceView' etc will be cleaned up SourceManager localSourceManager; @@ -856,7 +856,7 @@ RefPtr<Expr> Linkage::parseTypeString(String typeStr, RefPtr<Scope> scope) this, nullptr); - return parseTypeFromSourceFile( + return parseTermFromSourceFile( getSessionImpl(), tokens, &sink, scope, getNamePool(), SourceLanguage::Slang); } @@ -893,7 +893,7 @@ Type* ComponentType::getTypeFromString( auto linkage = getLinkage(); for(auto& s : scopesToTry) { - RefPtr<Expr> typeExpr = linkage->parseTypeString( + RefPtr<Expr> typeExpr = linkage->parseTermString( typeStr, s); type = checkProperType(linkage, TypeExp(typeExpr), sink); if (type && !type.as<ErrorType>()) diff --git a/tests/compute/global-generic-value-param.slang b/tests/compute/global-generic-value-param.slang new file mode 100644 index 000000000..c20e32784 --- /dev/null +++ b/tests/compute/global-generic-value-param.slang @@ -0,0 +1,59 @@ +// global-generic-value-param.slang + +//TEST(compute):COMPARE_COMPUTE: + +// This is a basic test of support for global generic +// value parameters: explicit named parameters at global +// scope that can be used to generate specialized kernel +// code based on different values. + +// We start by declaring a global generic value parameter: +// +// Note: only `int` parameters are expected to work for now. +// Note: the default `= 0` intializer isn't used right now. +// +__generic_value_param kOffset : uint = 0; + +// For the test framework, we also need to specify what +// value we want to specialize to. +// +// Note: this value (7) will be fed in to the compiler API +// as a specialization argument, and will not be visible +// to the compiler when it initially compiles the code +// to IR. +// +//TEST_INPUT: globalSpecializationArg 7 + +// Next we will declare a buffer of data just so that we +// can index into something and make the shader logic a +// bit less trivial. +// +RWStructuredBuffer<uint> vals; +//TEST_INPUT: ubuffer(data=[0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15], stride=4):dxbinding(0),glbinding(0) + +// The core test function will use the `kOffset` value +// we declared above along with the input value (the +// thread ID) to index into our buffer of values and +// compute a result. All of the math here is just to +// make the result easy to validate by eye. +// +uint test(uint value) +{ + return value * 16 + vals[(value + kOffset) & 0xF]; +} + +// And finally we have the boilerplate cruft that almost +// all of our compute tests use. + +//TEST_INPUT: ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):dxbinding(0),glbinding(1),out +RWStructuredBuffer<uint> outputBuffer; + +[numthreads(16, 1, 1)] +void computeMain( + uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + uint inVal = tid; + uint outVal = test(inVal); + outputBuffer[tid] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/global-generic-value-param.slang.expected.txt b/tests/compute/global-generic-value-param.slang.expected.txt new file mode 100644 index 000000000..d2f0c54e3 --- /dev/null +++ b/tests/compute/global-generic-value-param.slang.expected.txt @@ -0,0 +1,16 @@ +7 +18 +29 +3A +4B +5C +6D +7E +8F +90 +A1 +B2 +C3 +D4 +E5 +F6 diff --git a/tools/gfx/render.h b/tools/gfx/render.h index 30373b356..65f3c00c0 100644 --- a/tools/gfx/render.h +++ b/tools/gfx/render.h @@ -146,10 +146,8 @@ struct ShaderCompileRequest EntryPoint vertexShader; EntryPoint fragmentShader; EntryPoint computeShader; - Slang::List<Slang::String> globalGenericTypeArguments; - Slang::List<Slang::String> entryPointGenericTypeArguments; - Slang::List<Slang::String> entryPointExistentialTypeArguments; - Slang::List<Slang::String> globalExistentialTypeArguments; + Slang::List<Slang::String> globalSpecializationArgs; + Slang::List<Slang::String> entryPointSpecializationArgs; Slang::List<Slang::CommandLine::Arg> compileArgs; }; diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 37aa5dab6..3ba20c2cb 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -89,10 +89,8 @@ namespace renderer_test void ShaderInputLayout::parse(RandomGenerator* rand, const char * source) { entries.clear(); - globalGenericTypeArguments.clear(); - entryPointGenericTypeArguments.clear(); - globalExistentialTypeArguments.clear(); - entryPointExistentialTypeArguments.clear(); + globalSpecializationArgs.clear(); + entryPointSpecializationArgs.clear(); auto lines = Split(source, '\n'); for (auto & line : lines) { @@ -102,37 +100,25 @@ namespace renderer_test TokenReader parser(lineContent); try { - if (parser.LookAhead("type")) + if (parser.LookAhead("entryPointSpecializationArg") + || parser.LookAhead("type") + || parser.LookAhead("entryPointExistentialType")) { parser.ReadToken(); StringBuilder typeExp; while (!parser.IsEnd()) typeExp << parser.ReadToken().Content; - entryPointGenericTypeArguments.add(typeExp); + entryPointSpecializationArgs.add(typeExp); } - else if (parser.LookAhead("global_type")) + else if (parser.LookAhead("globalSpecializationArg") + || parser.LookAhead("global_type") + || parser.LookAhead("globalExistentialType")) { parser.ReadToken(); StringBuilder typeExp; while (!parser.IsEnd()) typeExp << parser.ReadToken().Content; - globalGenericTypeArguments.add(typeExp); - } - else if (parser.LookAhead("globalExistentialType")) - { - parser.ReadToken(); - StringBuilder typeExp; - while (!parser.IsEnd()) - typeExp << parser.ReadToken().Content; - globalExistentialTypeArguments.add(typeExp); - } - else if (parser.LookAhead("entryPointExistentialType")) - { - parser.ReadToken(); - StringBuilder typeExp; - while (!parser.IsEnd()) - typeExp << parser.ReadToken().Content; - entryPointExistentialTypeArguments.add(typeExp); + globalSpecializationArgs.add(typeExp); } else { diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h index 8a00980d9..e57e1ed8b 100644 --- a/tools/render-test/shader-input-layout.h +++ b/tools/render-test/shader-input-layout.h @@ -85,10 +85,8 @@ class ShaderInputLayout { public: Slang::List<ShaderInputLayoutEntry> entries; - Slang::List<Slang::String> globalGenericTypeArguments; - Slang::List<Slang::String> entryPointGenericTypeArguments; - Slang::List<Slang::String> globalExistentialTypeArguments; - Slang::List<Slang::String> entryPointExistentialTypeArguments; + Slang::List<Slang::String> globalSpecializationArgs; + Slang::List<Slang::String> entryPointSpecializationArgs; int numRenderTargets = 1; Slang::Index findEntryIndexByName(const Slang::String& name) const; diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index 5b038d208..9a7c13d7a 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -119,31 +119,18 @@ static const char computeEntryPointName[] = "computeMain"; computeTranslationUnit = translationUnit; } - - Slang::List<const char*> rawGlobalTypeNames; - for (auto typeName : request.globalGenericTypeArguments) - rawGlobalTypeNames.add(typeName.getBuffer()); - spSetGlobalGenericArgs( - slangRequest, - (int)rawGlobalTypeNames.getCount(), - rawGlobalTypeNames.getBuffer()); - - Slang::List<const char*> rawEntryPointTypeNames; - for (auto typeName : request.entryPointGenericTypeArguments) - rawEntryPointTypeNames.add(typeName.getBuffer()); - - const int globalExistentialTypeCount = int(request.globalExistentialTypeArguments.getCount()); - for(int ii = 0; ii < globalExistentialTypeCount; ++ii ) + const int globalSpecializationArgCount = int(request.globalSpecializationArgs.getCount()); + for(int ii = 0; ii < globalSpecializationArgCount; ++ii ) { - spSetTypeNameForGlobalExistentialTypeParam(slangRequest, ii, request.globalExistentialTypeArguments[ii].getBuffer()); + spSetTypeNameForGlobalExistentialTypeParam(slangRequest, ii, request.globalSpecializationArgs[ii].getBuffer()); } - const int entryPointExistentialTypeCount = int(request.entryPointExistentialTypeArguments.getCount()); - auto setEntryPointExistentialTypeArgs = [&](int entryPoint) + const int entryPointSpecializationArgCount = int(request.entryPointSpecializationArgs.getCount()); + auto setEntryPointSpecializationArgs = [&](int entryPoint) { - for( int ii = 0; ii < entryPointExistentialTypeCount; ++ii ) + for( int ii = 0; ii < entryPointSpecializationArgCount; ++ii ) { - spSetTypeNameForEntryPointExistentialTypeParam(slangRequest, entryPoint, ii, request.entryPointExistentialTypeArguments[ii].getBuffer()); + spSetTypeNameForEntryPointExistentialTypeParam(slangRequest, entryPoint, ii, request.entryPointSpecializationArgs[ii].getBuffer()); } }; @@ -152,13 +139,11 @@ static const char computeEntryPointName[] = "computeMain"; int computeEntryPointIndex = 0; if(!gOptions.dontAddDefaultEntryPoints) { - computeEntryPointIndex = spAddEntryPointEx(slangRequest, computeTranslationUnit, + computeEntryPointIndex = spAddEntryPoint(slangRequest, computeTranslationUnit, computeEntryPointName, - SLANG_STAGE_COMPUTE, - (int)rawEntryPointTypeNames.getCount(), - rawEntryPointTypeNames.getBuffer()); + SLANG_STAGE_COMPUTE); - setEntryPointExistentialTypeArgs(computeEntryPointIndex); + setEntryPointSpecializationArgs(computeEntryPointIndex); } spSetLineDirectiveMode(slangRequest, SLANG_LINE_DIRECTIVE_MODE_NONE); @@ -207,11 +192,11 @@ static const char computeEntryPointName[] = "computeMain"; int fragmentEntryPoint = 1; if( !gOptions.dontAddDefaultEntryPoints ) { - vertexEntryPoint = spAddEntryPointEx(slangRequest, vertexTranslationUnit, vertexEntryPointName, SLANG_STAGE_VERTEX, (int)rawEntryPointTypeNames.getCount(), rawEntryPointTypeNames.getBuffer()); - fragmentEntryPoint = spAddEntryPointEx(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, SLANG_STAGE_FRAGMENT, (int)rawEntryPointTypeNames.getCount(), rawEntryPointTypeNames.getBuffer()); + vertexEntryPoint = spAddEntryPoint(slangRequest, vertexTranslationUnit, vertexEntryPointName, SLANG_STAGE_VERTEX); + fragmentEntryPoint = spAddEntryPoint(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, SLANG_STAGE_FRAGMENT); - setEntryPointExistentialTypeArgs(vertexEntryPoint); - setEntryPointExistentialTypeArgs(fragmentEntryPoint); + setEntryPointSpecializationArgs(vertexEntryPoint); + setEntryPointSpecializationArgs(fragmentEntryPoint); } const SlangResult res = spCompile(slangRequest); @@ -340,10 +325,8 @@ static const char computeEntryPointName[] = "computeMain"; compileRequest.computeShader.source = sourceInfo; compileRequest.computeShader.name = computeEntryPointName; } - compileRequest.globalGenericTypeArguments = layout.globalGenericTypeArguments; - compileRequest.entryPointGenericTypeArguments = layout.entryPointGenericTypeArguments; - compileRequest.globalExistentialTypeArguments = layout.globalExistentialTypeArguments; - compileRequest.entryPointExistentialTypeArguments = layout.entryPointExistentialTypeArguments; + compileRequest.globalSpecializationArgs = layout.globalSpecializationArgs; + compileRequest.entryPointSpecializationArgs = layout.entryPointSpecializationArgs; return ShaderCompilerUtil::compileProgram(session, input, compileRequest, output.output); } |
