diff options
| -rw-r--r-- | slang.h | 55 | ||||
| -rw-r--r-- | source/slang/check.cpp | 83 | ||||
| -rw-r--r-- | source/slang/options.cpp | 4 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 1 | ||||
| -rw-r--r-- | source/slang/reflection.cpp | 64 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 12 | ||||
| -rw-r--r-- | source/slang/type-layout.h | 1 | ||||
| -rw-r--r-- | tests/reflection/global-type-params.slang | 35 | ||||
| -rw-r--r-- | tests/reflection/global-type-params.slang.expected | 127 | ||||
| -rw-r--r-- | tools/slang-reflection-test/main.cpp | 66 |
10 files changed, 398 insertions, 50 deletions
@@ -133,6 +133,9 @@ extern "C" /* Do as little mangling of names as possible, to try to preserve original names */ SLANG_COMPILE_FLAG_NO_MANGLING = 1 << 3, + + /* Skip code generation step, just check the code and generate layout */ + SLANG_COMPILE_FLAG_NO_CODEGEN = 1 << 4, }; /*! @@ -501,6 +504,7 @@ extern "C" typedef struct SlangReflectionTypeLayout SlangReflectionTypeLayout; typedef struct SlangReflectionVariable SlangReflectionVariable; typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout; + typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter; // get reflection data from a compilation request SLANG_API SlangReflection* spGetReflection( @@ -523,7 +527,8 @@ extern "C" SLANG_TYPE_KIND_TEXTURE_BUFFER, SLANG_TYPE_KIND_SHADER_STORAGE_BUFFER, SLANG_TYPE_KIND_PARAMETER_BLOCK, - + SLANG_TYPE_KIND_GENERIC_TYPE_PARAMETER, + SLANG_TYPE_KIND_INTERFACE, SLANG_TYPE_KIND_COUNT, }; @@ -719,11 +724,20 @@ extern "C" SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( SlangReflectionEntryPoint* entryPoint); + // SlangReflectionTypeParameter + SLANG_API char const* spReflectionTypeParameter_GetName(SlangReflectionTypeParameter* typeParam); + SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParameter* typeParam); + SLANG_API unsigned spReflectionTypeParameter_GetConstraintCount(SlangReflectionTypeParameter* typeParam); + SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(SlangReflectionTypeParameter* typeParam, unsigned int index); + // Shader Reflection SLANG_API unsigned spReflection_GetParameterCount(SlangReflection* reflection); SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflection* reflection, unsigned index); + SLANG_API unsigned int spReflection_GetTypeParameterCount(SlangReflection* reflection); + SLANG_API SlangReflectionTypeParameter* spReflection_GetTypeParameterByIndex(SlangReflection* reflection, unsigned int index); + SLANG_API SlangReflectionTypeParameter* spReflection_FindTypeParameter(SlangReflection* reflection, char const* name); SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* reflection); @@ -762,6 +776,8 @@ namespace slang TextureBuffer = SLANG_TYPE_KIND_TEXTURE_BUFFER, ShaderStorageBuffer = SLANG_TYPE_KIND_SHADER_STORAGE_BUFFER, ParameterBlock = SLANG_TYPE_KIND_PARAMETER_BLOCK, + GenericTypeParameter = SLANG_TYPE_KIND_GENERIC_TYPE_PARAMETER, + Interface = SLANG_TYPE_KIND_INTERFACE }; enum ScalarType : SlangScalarType @@ -1103,7 +1119,7 @@ namespace slang { return spReflectionEntryPoint_getParameterCount((SlangReflectionEntryPoint*) this); } - + VariableLayoutReflection* getParameterByIndex(unsigned index) { return (VariableLayoutReflection*) spReflectionEntryPoint_getParameterByIndex((SlangReflectionEntryPoint*) this, index); @@ -1127,12 +1143,47 @@ namespace slang } }; + struct TypeParameterReflection + { + char const* getName() + { + return spReflectionTypeParameter_GetName((SlangReflectionTypeParameter*) this); + } + unsigned getIndex() + { + return spReflectionTypeParameter_GetIndex((SlangReflectionTypeParameter*) this); + } + unsigned getConstraintCount() + { + return spReflectionTypeParameter_GetConstraintCount((SlangReflectionTypeParameter*) this); + } + TypeReflection* getConstraintByIndex(int index) + { + return (TypeReflection*)spReflectionTypeParameter_GetConstraintByIndex((SlangReflectionTypeParameter*) this, index); + } + }; + struct ShaderReflection { unsigned getParameterCount() { return spReflection_GetParameterCount((SlangReflection*) this); } + + unsigned getTypeParameterCount() + { + return spReflection_GetTypeParameterCount((SlangReflection*) this); + } + + TypeParameterReflection* getTypeParameterByIndex(unsigned index) + { + return (TypeParameterReflection*)spReflection_GetTypeParameterByIndex((SlangReflection*) this, index); + } + + TypeParameterReflection* findTypeParameter(char const* name) + { + return (TypeParameterReflection*)spReflection_FindTypeParameter((SlangReflection*)this, name); + } VariableLayoutReflection* getParameterByIndex(unsigned index) { diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 5141d8634..bc5d144b0 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -6798,50 +6798,55 @@ namespace Slang } entryPoint->genericParameterTypes.Add(type); } - // check that user-provioded type arguments conforms to the generic type - // parameter declaration of this translation unit - - // collect global generic parameters from all imported modules - List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; - // add current translation unit first - { - auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - // add imported modules - for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) - { - auto moduleDecl = loadedModule->moduleDecl; - auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - if (globalGenericParams.Count() != entryPoint->genericParameterTypes.Count()) + + // validate global type arguments only when we are generating code + if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) { - sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, globalGenericParams.Count(), + // check that user-provioded type arguments conforms to the generic type + // parameter declaration of this translation unit + + // collect global generic parameters from all imported modules + List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; + // add current translation unit first + { + auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); + for (auto p : globalGenParams) + globalGenericParams.Add(p); + } + // add imported modules + for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) + { + auto moduleDecl = loadedModule->moduleDecl; + auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); + for (auto p : globalGenParams) + globalGenericParams.Add(p); + } + if (globalGenericParams.Count() != entryPoint->genericParameterTypes.Count()) + { + sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, globalGenericParams.Count(), entryPoint->genericParameterTypes.Count()); - return; - } - // if number of entry-point type arguments matches parameters, try find - // SubtypeWitness for each argument - int index = 0; - for (auto & gParam : globalGenericParams) - { - for (auto constraint : gParam->getMembersOfType<GenericTypeConstraintDecl>()) + return; + } + // if number of entry-point type arguments matches parameters, try find + // SubtypeWitness for each argument + int index = 0; + for (auto & gParam : globalGenericParams) { - auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); - SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); - auto witness = visitor.tryGetSubtypeWitness(entryPoint->genericParameterTypes[index], interfaceType); - if (!witness) + for (auto constraint : gParam->getMembersOfType<GenericTypeConstraintDecl>()) { - sink->diagnose(gParam, - Diagnostics::typeArgumentDoesNotConformToInterface, gParam->nameAndLoc.name, entryPoint->genericParameterTypes[index], - interfaceType); + auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); + auto witness = visitor.tryGetSubtypeWitness(entryPoint->genericParameterTypes[index], interfaceType); + if (!witness) + { + sink->diagnose(gParam, + Diagnostics::typeArgumentDoesNotConformToInterface, gParam->nameAndLoc.name, entryPoint->genericParameterTypes[index], + interfaceType); + } + entryPoint->genericParameterWitnesses.Add(witness); } - entryPoint->genericParameterWitnesses.Add(witness); + index++; } - index++; } if (sink->errorCount != 0) return; @@ -6851,8 +6856,6 @@ namespace Slang // if they are of types that are appropriate to the stage, etc. } - - void checkTranslationUnit( TranslationUnitRequest* translationUnit) { diff --git a/source/slang/options.cpp b/source/slang/options.cpp index 97deeb544..a5ed8eca6 100644 --- a/source/slang/options.cpp +++ b/source/slang/options.cpp @@ -278,6 +278,10 @@ struct OptionsParser { flags |= SLANG_COMPILE_FLAG_NO_MANGLING; } + else if (argStr == "-no-codegen") + { + flags |= SLANG_COMPILE_FLAG_NO_CODEGEN; + } else if(argStr == "-dump-ir" ) { requestImpl->shouldDumpIR = true; diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 4ec4f6fd5..6145015f1 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -700,6 +700,7 @@ static void collectGlobalGenericParameter( layout->decl = paramDecl; layout->index = (int)context->shared->programLayout->globalGenericParams.Count(); context->shared->programLayout->globalGenericParams.Add(layout); + context->shared->programLayout->globalGenericParamsMap[layout->decl->getName()->text] = layout.Ptr(); } // Collect a single declaration into our set of parameters diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index bd95f48fa..5270df8b4 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -39,6 +39,11 @@ static inline SlangReflectionTypeLayout* convert(TypeLayout* type) return (SlangReflectionTypeLayout*) type; } +static inline GenericParamLayout* convert(SlangReflectionTypeParameter * typeParam) +{ + return (GenericParamLayout*)typeParam; +} + static inline VarDeclBase* convert(SlangReflectionVariable* var) { return (VarDeclBase*) var; @@ -126,7 +131,6 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) { return SLANG_TYPE_KIND_RESOURCE; } - // TODO: need a better way to handle this stuff... #define CASE(TYPE) \ else if(type->As<TYPE>()) do { \ @@ -153,6 +157,14 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) { return SLANG_TYPE_KIND_STRUCT; } + else if (auto genericParamType = declRef.As<GlobalGenericParamDecl>()) + { + return SLANG_TYPE_KIND_GENERIC_TYPE_PARAMETER; + } + else if (auto interfaceType = declRef.As<InterfaceDecl>()) + { + return SLANG_TYPE_KIND_INTERFACE; + } } else if (auto errorType = type->As<ErrorType>()) { @@ -848,7 +860,7 @@ namespace Slang return 0; } - + static VarLayout* getParameterByIndex(RefPtr<TypeLayout> typeLayout, unsigned index) { if(auto parameterGroupLayout = typeLayout.As<ParameterGroupTypeLayout>()) @@ -974,6 +986,33 @@ SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( return (entryPointLayout->flags & EntryPointLayout::Flag::usesAnySampleRateInput) != 0; } +// SlangReflectionTypeParameter +SLANG_API char const* spReflectionTypeParameter_GetName(SlangReflectionTypeParameter * inTypeParam) +{ + auto typeParam = convert(inTypeParam); + return typeParam->decl->getName()->text.Buffer(); +} + +SLANG_API unsigned spReflectionTypeParameter_GetIndex(SlangReflectionTypeParameter * inTypeParam) +{ + auto typeParam = convert(inTypeParam); + return (unsigned)(typeParam->index); +} + +SLANG_API unsigned int spReflectionTypeParameter_GetConstraintCount(SlangReflectionTypeParameter* inTypeParam) +{ + auto typeParam = convert(inTypeParam); + auto constraints = typeParam->decl->getMembersOfType<GenericTypeConstraintDecl>(); + return (unsigned int)constraints.Count(); +} + +SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(SlangReflectionTypeParameter * inTypeParam, unsigned index) +{ + auto typeParam = convert(inTypeParam); + auto constraints = typeParam->decl->getMembersOfType<GenericTypeConstraintDecl>(); + return (SlangReflectionType*)constraints.ToArray()[index]->sup.Ptr(); +} + // Shader Reflection namespace Slang @@ -1006,6 +1045,27 @@ SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflec return convert(globalStructLayout->fields[index].Ptr()); } +SLANG_API unsigned int spReflection_GetTypeParameterCount(SlangReflection * reflection) +{ + auto program = convert(reflection); + return (unsigned int)program->globalGenericParams.Count(); +} + +SLANG_API SlangReflectionTypeParameter* spReflection_GetTypeParameterByIndex(SlangReflection * reflection, unsigned int index) +{ + auto program = convert(reflection); + return (SlangReflectionTypeParameter*)program->globalGenericParams[index].Ptr(); +} + +SLANG_API SlangReflectionTypeParameter * spReflection_FindTypeParameter(SlangReflection * inProgram, char const * name) +{ + auto program = convert(inProgram); + if (!program) return nullptr; + GenericParamLayout * result = nullptr; + program->globalGenericParamsMap.TryGetValue(name, result); + return (SlangReflectionTypeParameter*)result; +} + SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram) { auto program = convert(inProgram); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index ea86663ea..3156e5008 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -293,9 +293,12 @@ int CompileRequest::executeActionsInner() if (mSink.GetErrorCount() != 0) return 1; - // Generate initial IR for all the translation - // units, if we are in a mode where IR is called for. - generateIR(); + if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + { + // Generate initial IR for all the translation + // units, if we are in a mode where IR is called for. + generateIR(); + } if (mSink.GetErrorCount() != 0) return 1; @@ -315,7 +318,8 @@ int CompileRequest::executeActionsInner() // If command line specifies to skip codegen, we exit here. // Note: this is a debugging option. - if (shouldSkipCodegen) + if (shouldSkipCodegen || + ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0)) return 0; // Generate output code, in whatever format was requested diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index 6874fc460..904dacd91 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -450,6 +450,7 @@ public: List<RefPtr<EntryPointLayout>> entryPoints; List<RefPtr<GenericParamLayout>> globalGenericParams; + Dictionary<String, GenericParamLayout*> globalGenericParamsMap; // HACK: binding to use when we have to create // a dummy sampler just to appease glslang diff --git a/tests/reflection/global-type-params.slang b/tests/reflection/global-type-params.slang new file mode 100644 index 000000000..bfeb7fb2e --- /dev/null +++ b/tests/reflection/global-type-params.slang @@ -0,0 +1,35 @@ +//TEST:REFLECTION:-profile ps_4_0 -target hlsl -no-codegen + +// Confirm that we handle global generic parameters + + +float4 u; + +interface IBase +{}; + +__generic_param TParam : IBase; +__generic_param TParam2 : IBase; + +struct S +{ + TParam2 p; +}; + +ParameterBlock<S> arg; +ParameterBlock<TParam> arg1; + +Texture2D t; +SamplerState s; + +cbuffer CB +{ + float4 v; +} + +float4 w; + +float4 main() : SV_Target +{ + return u + v + w + t.Sample(s, u.xy); +}
\ No newline at end of file diff --git a/tests/reflection/global-type-params.slang.expected b/tests/reflection/global-type-params.slang.expected new file mode 100644 index 000000000..1e3a6aa99 --- /dev/null +++ b/tests/reflection/global-type-params.slang.expected @@ -0,0 +1,127 @@ +result code = 0 +standard error = { +} +standard output = { +{ + "parameters": [ + { + "name": "u", + "binding": {"kind": "uniform", "offset": 0, "size": 16}, + "type": { + "kind": "vector", + "elementCount": 4, + "elementType": { + "kind": "scalar", + "scalarType": "float32" + } + } + }, + { + "name": "arg", + "binding": {"kind": "generic", "index": 0}, + "type": { + "kind": "parameterBlock", + "elementType": { + "kind": "struct", + "name": "S", + "fields": [ + { + "name": "p", + "type": { + "kind": "GenericTypeParameter", + "name": "TParam2" + }, + "binding": {"kind": "generic", "index": 0} + } + ] + } + } + }, + { + "name": "arg1", + "binding": {"kind": "generic", "index": 0}, + "type": { + "kind": "parameterBlock", + "elementType": { + "kind": "GenericTypeParameter", + "name": "TParam" + } + } + }, + { + "name": "t", + "binding": {"kind": "shaderResource", "index": 0}, + "type": { + "kind": "resource", + "baseShape": "texture2D" + } + }, + { + "name": "s", + "binding": {"kind": "samplerState", "index": 0}, + "type": { + "kind": "samplerState" + } + }, + { + "name": "CB", + "binding": {"kind": "constantBuffer", "index": 1}, + "type": { + "kind": "constantBuffer", + "elementType": { + "kind": "struct", + "fields": [ + { + "name": "v", + "type": { + "kind": "vector", + "elementCount": 4, + "elementType": { + "kind": "scalar", + "scalarType": "float32" + } + }, + "binding": {"kind": "uniform", "offset": 0, "size": 16} + } + ] + } + } + }, + { + "name": "w", + "binding": {"kind": "uniform", "offset": 16, "size": 16}, + "type": { + "kind": "vector", + "elementCount": 4, + "elementType": { + "kind": "scalar", + "scalarType": "float32" + } + } + } + ], + "typeParams": + [ + { + "name": "TParam", + constraints: + [ + { + "kind": "Interface", + "name": "IBase" + } + ] + }, + { + "name": "TParam2", + constraints: + [ + { + "kind": "Interface", + "name": "IBase" + } + ] + } + ] +} +} diff --git a/tools/slang-reflection-test/main.cpp b/tools/slang-reflection-test/main.cpp index 2b5477b4a..90be8f5c7 100644 --- a/tools/slang-reflection-test/main.cpp +++ b/tools/slang-reflection-test/main.cpp @@ -117,6 +117,7 @@ static void emitReflectionVarBindingInfoJSON( CASE(SPECIALIZATION_CONSTANT, specializationConstant); CASE(MIXED, mixed); CASE(REGISTER_SPACE, registerSpace); + CASE(GENERIC, generic); #undef CASE default: @@ -287,7 +288,8 @@ static void emitReflectionTypeInfoJSON( PrettyWriter& writer, slang::TypeReflection* type) { - switch( type->getKind() ) + auto kind = type->getKind(); + switch(kind) { case slang::TypeReflection::Kind::SamplerState: write(writer, "\"kind\": \"samplerState\""); @@ -456,6 +458,14 @@ static void emitReflectionTypeInfoJSON( } break; + case slang::TypeReflection::Kind::GenericTypeParameter: + write(writer, "\"kind\": \"GenericTypeParameter\",\n"); + emitReflectionNameInfoJSON(writer, type->getName()); + break; + case slang::TypeReflection::Kind::Interface: + write(writer, "\"kind\": \"Interface\",\n"); + emitReflectionNameInfoJSON(writer, type->getName()); + break; default: assert(!"unhandled case"); break; @@ -555,6 +565,16 @@ static void emitReflectionTypeLayoutInfoJSON( writer, typeLayout->getElementTypeLayout()); break; + case slang::TypeReflection::Kind::GenericTypeParameter: + write(writer, "\"kind\": \"GenericTypeParameter\""); + write(writer, ",\n"); + emitReflectionNameInfoJSON(writer, typeLayout->getName()); + break; + case slang::TypeReflection::Kind::Interface: + write(writer, "\"kind\": \"Interface\",\n"); + write(writer, ",\n"); + emitReflectionNameInfoJSON(writer, typeLayout->getName()); + break; } // TODO: emit size info for types @@ -662,6 +682,33 @@ Range<T> range(T end) return Range<T>(T(0), end); } +static void emitReflectionTypeParamJSON( + PrettyWriter& writer, + slang::TypeParameterReflection* typeParam) +{ + write(writer, "{\n"); + indent(writer); + emitReflectionNameInfoJSON(writer, typeParam->getName()); + write(writer, ",\n"); + write(writer, "constraints: \n"); + write(writer, "[\n"); + indent(writer); + auto constraintCount = typeParam->getConstraintCount(); + for (auto ee : range(constraintCount)) + { + if (ee != 0) write(writer, ",\n"); + write(writer, "{\n"); + indent(writer); + emitReflectionTypeInfoJSON(writer, typeParam->getConstraintByIndex(ee)); + dedent(writer); + write(writer, "\n}"); + } + dedent(writer); + write(writer, "\n]"); + dedent(writer); + write(writer, "\n}"); +} + static void emitReflectionEntryPointJSON( PrettyWriter& writer, slang::EntryPointReflection* entryPoint) @@ -700,7 +747,6 @@ static void emitReflectionEntryPointJSON( dedent(writer); write(writer, "\n]"); } - if (entryPoint->usesAnySampleRateInput()) { write(writer, ",\n\"usesAnySampleRateInput\": true"); @@ -763,6 +809,22 @@ static void emitReflectionJSON( write(writer, "\n]"); } + auto genParamCount = programReflection->getTypeParameterCount(); + if (genParamCount) + { + write(writer, ",\n\"typeParams\":\n"); + write(writer, "[\n"); + indent(writer); + for (auto ee : range(genParamCount)) + { + if (ee != 0) write(writer, ",\n"); + + auto typeParam = programReflection->getTypeParameterByIndex(ee); + emitReflectionTypeParamJSON(writer, typeParam); + } + dedent(writer); + write(writer, "\n]"); + } dedent(writer); write(writer, "\n}\n"); } |
