From 7c8527d20e433c3a10736136d31e4cd882a3baaa Mon Sep 17 00:00:00 2001 From: jsmall-nvidia Date: Fri, 4 Oct 2019 09:46:03 -0400 Subject: IR types for subset of Attributes (#1067) * IROutputControlPointsDecoration * IROutputTopologyDecoration * IRPartitioningDecoration * IRDomainDecoration * Use IRPatchConstantDecoration alone for hlsl output. * IRMaxVertexCountDecoration * IRInstanceDecoration * Removed _emitHLSLAttributeSingleString and _emitHLSLAttributeSingleInt Removed GLSLBindingAttribute and just use NumThreadsAttribute * Added IRNumThreadsDecoration. * Added IRNumThreadsDecoration * Fix build problem on x86. Improve diagnostic text based on review. --- source/slang/slang-check.cpp | 367 ++++++++++++++++++----------------- source/slang/slang-diagnostic-defs.h | 1 + source/slang/slang-emit-c-like.cpp | 11 ++ source/slang/slang-emit-c-like.h | 8 + source/slang/slang-emit-cpp.cpp | 59 +++--- source/slang/slang-emit-cpp.h | 8 +- source/slang/slang-emit-glsl.cpp | 25 +-- source/slang/slang-emit-hlsl.cpp | 127 ++++-------- source/slang/slang-emit-hlsl.h | 9 +- source/slang/slang-ir-inst-defs.h | 8 + source/slang/slang-ir-insts.h | 59 ++++++ source/slang/slang-lower-to-ir.cpp | 73 +++++++ source/slang/slang-modifier-defs.h | 5 - source/slang/slang-parser.cpp | 46 ++++- source/slang/slang-reflection.cpp | 21 -- 15 files changed, 473 insertions(+), 354 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index cec6b02a2..b5069a02f 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -2781,236 +2781,253 @@ namespace Slang bool validateAttribute(RefPtr attr, AttributeDecl* attribClassDecl) { - if(auto numThreadsAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 3); - auto xVal = checkConstantIntVal(attr->args[0]); - auto yVal = checkConstantIntVal(attr->args[1]); - auto zVal = checkConstantIntVal(attr->args[2]); + if(auto numThreadsAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 3); - if(!xVal) return false; - if(!yVal) return false; - if(!zVal) return false; + int32_t values[3]; - numThreadsAttr->x = (int32_t) xVal->value; - numThreadsAttr->y = (int32_t) yVal->value; - numThreadsAttr->z = (int32_t) zVal->value; - } - else if (auto bindingAttr = as(attr)) + for (int i = 0; i < 3; ++i) { - // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding) - // Must have 2 int parameters. Ideally this would all be checked from the specification - // in core.meta.slang, but that's not completely implemented. So for now we check here. - if (attr->args.getCount() != 2) - { - return false; - } + int32_t value = 1; - // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here - // to make sure they both are - auto binding = checkConstantIntVal(attr->args[0]); - auto set = checkConstantIntVal(attr->args[1]); - - if (binding == nullptr || set == nullptr) + auto arg = attr->args[i]; + if (arg) { - return false; + auto intValue = checkConstantIntVal(arg); + if (!intValue) + { + return false; + } + if (intValue->value < 1) + { + getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->value); + return false; + } + value = int32_t(intValue->value); } - - bindingAttr->binding = int32_t(binding->value); - bindingAttr->set = int32_t(set->value); + values[i] = value; } - else if (auto maxVertexCountAttr = as(attr)) + + numThreadsAttr->x = values[0]; + numThreadsAttr->y = values[1]; + numThreadsAttr->z = values[2]; + } + else if (auto bindingAttr = as(attr)) + { + // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding) + // Must have 2 int parameters. Ideally this would all be checked from the specification + // in core.meta.slang, but that's not completely implemented. So for now we check here. + if (attr->args.getCount() != 2) { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); + return false; + } - if(!val) return false; + // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here + // to make sure they both are + auto binding = checkConstantIntVal(attr->args[0]); + auto set = checkConstantIntVal(attr->args[1]); - maxVertexCountAttr->value = (int32_t)val->value; - } - else if(auto instanceAttr = as(attr)) + if (binding == nullptr || set == nullptr) { - SLANG_ASSERT(attr->args.getCount() == 1); - auto val = checkConstantIntVal(attr->args[0]); + return false; + } + + bindingAttr->binding = int32_t(binding->value); + bindingAttr->set = int32_t(set->value); + } + else if (auto maxVertexCountAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); - if(!val) return false; + if(!val) return false; - instanceAttr->value = (int32_t)val->value; - } - else if(auto entryPointAttr = as(attr)) - { - SLANG_ASSERT(attr->args.getCount() == 1); + maxVertexCountAttr->value = (int32_t)val->value; + } + else if(auto instanceAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); - String stageName; - if(!checkLiteralStringVal(attr->args[0], &stageName)) - { - return false; - } + if(!val) return false; - auto stage = findStageByName(stageName); - if(stage == Stage::Unknown) - { - getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName); - } + instanceAttr->value = (int32_t)val->value; + } + else if(auto entryPointAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); - entryPointAttr->stage = stage; - } - else if ((as(attr)) || - (as(attr)) || - (as(attr)) || - (as(attr)) || - (as(attr))) - { - // Let it go thru iff single string attribute - if (!hasStringArgs(attr, 1)) - { - getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->name); - } - } - else if (as(attr)) + String stageName; + if(!checkLiteralStringVal(attr->args[0], &stageName)) { - // Let it go thru iff single integral attribute - if (!hasIntArgs(attr, 1)) - { - getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); - } + return false; } - else if (as(attr)) + + auto stage = findStageByName(stageName); + if(stage == Stage::Unknown) { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); + getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName); } - else if (as(attr)) + + entryPointAttr->stage = stage; + } + else if ((as(attr)) || + (as(attr)) || + (as(attr)) || + (as(attr)) || + (as(attr))) + { + // Let it go thru iff single string attribute + if (!hasStringArgs(attr, 1)) { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); + getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->name); } - else if (as(attr)) + } + else if (as(attr)) + { + // Let it go thru iff single integral attribute + if (!hasIntArgs(attr, 1)) { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); } - else if (auto attrUsageAttr = as(attr)) + } + else if (as(attr)) + { + // Has no args + SLANG_ASSERT(attr->args.getCount() == 0); + } + else if (as(attr)) + { + // Has no args + SLANG_ASSERT(attr->args.getCount() == 0); + } + else if (as(attr)) + { + // Has no args + SLANG_ASSERT(attr->args.getCount() == 0); + } + else if (auto attrUsageAttr = as(attr)) + { + uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; + if (attr->args.getCount() == 1) { - uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; - if (attr->args.getCount() == 1) + RefPtr outIntVal; + if (auto cInt = checkConstantEnumVal(attr->args[0])) { - RefPtr outIntVal; - if (auto cInt = checkConstantEnumVal(attr->args[0])) - { - targetClassId = (uint32_t)(cInt->value); - } - else - { - getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); - return false; - } + targetClassId = (uint32_t)(cInt->value); } - if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId)) + else { - getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget); + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); return false; } } - else if (auto unrollAttr = as(attr)) + if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId)) { - // Check has an argument. We need this because default behavior is to give an error - // if an attribute has arguments, but not handled explicitly (and the default param will come through - // as 1 arg if nothing is specified) - SLANG_ASSERT(attr->args.getCount() == 1); + getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget); + return false; } - else if (auto userDefAttr = as(attr)) + } + else if (auto unrollAttr = as(attr)) + { + // Check has an argument. We need this because default behavior is to give an error + // if an attribute has arguments, but not handled explicitly (and the default param will come through + // as 1 arg if nothing is specified) + SLANG_ASSERT(attr->args.getCount() == 1); + } + else if (auto userDefAttr = as(attr)) + { + // check arguments against attribute parameters defined in attribClassDecl + Index paramIndex = 0; + auto params = attribClassDecl->getMembersOfType(); + for (auto paramDecl : params) { - // check arguments against attribute parameters defined in attribClassDecl - Index paramIndex = 0; - auto params = attribClassDecl->getMembersOfType(); - for (auto paramDecl : params) + if (paramIndex < attr->args.getCount()) { - if (paramIndex < attr->args.getCount()) + auto & arg = attr->args[paramIndex]; + bool typeChecked = false; + if (auto basicType = as(paramDecl->getType())) { - auto & arg = attr->args[paramIndex]; - bool typeChecked = false; - if (auto basicType = as(paramDecl->getType())) + if (basicType->baseType == BaseType::Int) { - if (basicType->baseType == BaseType::Int) + if (auto cint = checkConstantIntVal(arg)) { - if (auto cint = checkConstantIntVal(arg)) - { - attr->intArgVals[(uint32_t)paramIndex] = cint; - } - typeChecked = true; + attr->intArgVals[(uint32_t)paramIndex] = cint; } - } - if (!typeChecked) - { - arg = CheckExpr(arg); - arg = coerce(paramDecl->getType(), arg); + typeChecked = true; } } - paramIndex++; - } - if (params.getCount() < attr->args.getCount()) - { - getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), params.getCount()); - } - else if (params.getCount() > attr->args.getCount()) - { - getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount()); + if (!typeChecked) + { + arg = CheckExpr(arg); + arg = coerce(paramDecl->getType(), arg); + } } + paramIndex++; } - else if (auto formatAttr = as(attr)) + if (params.getCount() < attr->args.getCount()) { - SLANG_ASSERT(attr->args.getCount() == 1); - - String formatName; - if(!checkLiteralStringVal(attr->args[0], &formatName)) - { - return false; - } - - ImageFormat format = ImageFormat::unknown; - if(!findImageFormatByName(formatName.getBuffer(), &format)) - { - getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName); - } + getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), params.getCount()); + } + else if (params.getCount() > attr->args.getCount()) + { + getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount()); + } + } + else if (auto formatAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); - formatAttr->format = format; + String formatName; + if(!checkLiteralStringVal(attr->args[0], &formatName)) + { + return false; } - else if (auto allowAttr = as(attr)) + + ImageFormat format = ImageFormat::unknown; + if(!findImageFormatByName(formatName.getBuffer(), &format)) { - SLANG_ASSERT(attr->args.getCount() == 1); + getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName); + } - String diagnosticName; - if(!checkLiteralStringVal(attr->args[0], &diagnosticName)) - { - return false; - } + formatAttr->format = format; + } + else if (auto allowAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); - auto diagnosticInfo = findDiagnosticByName(diagnosticName.getUnownedSlice()); - if(!diagnosticInfo) - { - getSink()->diagnose(attr->args[0], Diagnostics::unknownDiagnosticName, diagnosticName); - } + String diagnosticName; + if(!checkLiteralStringVal(attr->args[0], &diagnosticName)) + { + return false; + } - allowAttr->diagnostic = diagnosticInfo; + auto diagnosticInfo = findDiagnosticByName(diagnosticName.getUnownedSlice()); + if(!diagnosticInfo) + { + getSink()->diagnose(attr->args[0], Diagnostics::unknownDiagnosticName, diagnosticName); + } + + allowAttr->diagnostic = diagnosticInfo; + } + else + { + if(attr->args.getCount() == 0) + { + // If the attribute took no arguments, then we will + // assume it is valid as written. } else { - if(attr->args.getCount() == 0) - { - // If the attribute took no arguments, then we will - // assume it is valid as written. - } - else - { - // We should be special-casing the checking of any attribute - // with a non-zero number of arguments. - SLANG_DIAGNOSE_UNEXPECTED(getSink(), attr, "unhandled attribute"); - return false; - } + // We should be special-casing the checking of any attribute + // with a non-zero number of arguments. + SLANG_DIAGNOSE_UNEXPECTED(getSink(), attr, "unhandled attribute"); + return false; } + } - return true; + return true; } RefPtr checkAttribute( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index c1c207fd7..067a65159 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -274,6 +274,7 @@ DIAGNOSTIC(31006, Error, attributeFunctionNotFound, "Could not find function '$0 DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'") DIAGNOSTIC(31101, Error, unknownImageFormatName, "unknown image format '$0'") DIAGNOSTIC(31101, Error, unknownDiagnosticName, "unknown diagnostic '$0'") +DIAGNOSTIC(31102, Error, nonPositiveNumThreads, "expected a positive integer in 'numthreads' attribute, got '$0'") DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index a19823055..281da3ef9 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -203,6 +203,17 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) } } + +/* static */IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]) +{ + IRNumThreadsDecoration* decor = func->findDecoration(); + for (int i = 0; i < 3; ++i) + { + outNumThreads[i] = decor ? Int(GetIntVal(decor->getOperand(i))) : 1; + } + return decor; +} + void CLikeSourceEmitter::_emitArrayType(IRArrayType* arrayType, EDeclarator* declarator) { EDeclarator arrayDeclarator; diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index ee906010b..9c62855aa 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -43,6 +43,11 @@ public: // explicitly in the layout data. StructTypeLayout* globalStructLayout = nullptr; }; + + enum + { + kThreadGroupAxisCount = 3, + }; /// To simplify cases enum class SourceStyle @@ -318,6 +323,9 @@ public: /// Returns an empty slice if not a built in type static UnownedStringSlice getDefaultBuiltinTypeName(IROp op); + /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1 + static IRNumThreadsDecoration* getComputeThreadGroupSize(IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]); + protected: virtual void emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling = "register") { SLANG_UNUSED(inst); SLANG_UNUSED(uniformSemanticSpelling); } diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 1628c6770..1f38512b4 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1916,23 +1916,15 @@ void CPPSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointLa { case Stage::Compute: { - static const UInt kAxisCount = 3; - UInt sizeAlongAxis[kAxisCount]; - - // TODO: this is kind of gross because we are using a public - // reflection API function, rather than some kind of internal - // utility it forwards to... - spReflectionEntryPoint_getComputeThreadGroupSize( - (SlangReflectionEntryPoint*)entryPointLayout, - kAxisCount, - &sizeAlongAxis[0]); - + Int numThreads[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, numThreads); + // TODO(JS): We might want to store this information such that it can be used to execute m_writer->emit("// [numthreads("); - for (int ii = 0; ii < 3; ++ii) + for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) { if (ii != 0) m_writer->emit(", "); - m_writer->emit(sizeAlongAxis[ii]); + m_writer->emit(numThreads[ii]); } m_writer->emit(")]\n"); break; @@ -2504,16 +2496,16 @@ struct AxisWithSize bool operator<(const ThisType& rhs) const { return size < rhs.size || (size == rhs.size && axis < rhs.axis); } int axis; - UInt size; + Int size; }; } // anonymous -static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List& out) +static void _calcAxisOrder(const Int sizeAlongAxis[CLikeSourceEmitter::kThreadGroupAxisCount], bool allowSingle, List& out) { out.clear(); // Add in order z,y,x, so by default (if we don't sort), x will be the inner loop - for (int i = 3 - 1; i >= 0; --i) + for (int i = CLikeSourceEmitter::kThreadGroupAxisCount - 1; i >= 0; --i) { if (allowSingle || sizeAlongAxis[i] > 1) { @@ -2529,7 +2521,7 @@ static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List axes; _calcAxisOrder(sizeAlongAxis, false, axes); @@ -2563,7 +2555,7 @@ void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const S } } -void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName) +void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName) { List axes; _calcAxisOrder(sizeAlongAxis, true, axes); @@ -2635,13 +2627,13 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], co m_writer->emit("}\n"); } } -void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName) +void CPPSourceEmitter::_emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName) { StringBuilder builder; m_writer->emit("{\n"); m_writer->indent(); - for (int i = 0; i < 3; ++i) + for (int i = 0; i < kThreadGroupAxisCount; ++i) { builder.Clear(); const char elem[2] = { s_elemNames[i], 0 }; @@ -2650,7 +2642,7 @@ void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const Un { builder << " + " << addName << "." << elem; } - if (i < 3 - 1) + if (i < kThreadGroupAxisCount - 1) { builder << ","; } @@ -2821,14 +2813,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID. - static const UInt kAxisCount = 3; - UInt sizeAlongAxis[kAxisCount]; - - // TODO: this is kind of gross because we are using a public - // reflection API function, rather than some kind of internal - // utility it forwards to... - spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]); - + Int groupThreadSize[kThreadGroupAxisCount]; + getComputeThreadGroupSize(func, groupThreadSize); + String funcName = getFuncName(func); { @@ -2842,7 +2829,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) { m_writer->emit("context.groupDispatchThreadID = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice()); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice()); } if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) { @@ -2851,7 +2838,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) // Emit dispatchThreadID m_writer->emit("context.dispatchThreadID = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID")); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID")); m_writer->emit("context._"); m_writer->emit(funcName); @@ -2871,7 +2858,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) _emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput")); m_writer->emit("const uint3 start = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID) { @@ -2884,7 +2871,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) } m_writer->emit("context.dispatchThreadID = start;\n"); - _emitEntryPointGroup(sizeAlongAxis, funcName); + _emitEntryPointGroup(groupThreadSize, funcName); _emitEntryPointDefinitionEnd(func); } @@ -2893,11 +2880,11 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) _emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput")); m_writer->emit("const uint3 start = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice()); m_writer->emit("const uint3 end = "); - _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice()); + _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice()); - _emitEntryPointGroupRange(sizeAlongAxis, funcName); + _emitEntryPointGroupRange(groupThreadSize, funcName); _emitEntryPointDefinitionEnd(func); } } diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index 0e182818d..af7aef271 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -271,15 +271,15 @@ protected: void _emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName, const UnownedStringSlice& varyingTypeName); void _emitEntryPointDefinitionEnd(IRFunc* func); - void _emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName); - void _emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName); + void _emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName); + void _emitEntryPointGroupRange(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName); - void _emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName); + void _emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName); Dictionary m_intrinsicNameMap; Dictionary m_typeNameMap; - /* This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsice. + /* This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsic. That we want to have a pointer to a type be unique, and slang supports this through the m_sharedIRBuilder. BUT for this to work all work on the module must use the same sharedIRBuilder, and that appears to not be the case in terms of other passes. diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index a1c7b9170..e43d41d5c 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -661,20 +661,12 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointL { case Stage::Compute: { - static const UInt kAxisCount = 3; - UInt sizeAlongAxis[kAxisCount]; - - // TODO: this is kind of gross because we are using a public - // reflection API function, rather than some kind of internal - // utility it forwards to... - spReflectionEntryPoint_getComputeThreadGroupSize( - (SlangReflectionEntryPoint*)entryPointLayout, - kAxisCount, - &sizeAlongAxis[0]); + Int sizeAlongAxis[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis); m_writer->emit("layout("); char const* axes[] = { "x", "y", "z" }; - for (int ii = 0; ii < 3; ++ii) + for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) { if (ii != 0) m_writer->emit(", "); m_writer->emit("local_size_"); @@ -687,16 +679,19 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointL break; case Stage::Geometry: { - if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier()) + if (auto decor = irFunc->findDecoration()) { + auto count = GetIntVal(decor->getCount()); m_writer->emit("layout(max_vertices = "); - m_writer->emit(attrib->value); + m_writer->emit(Int(count)); m_writer->emit(") out;\n"); } - if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier()) + + if (auto decor = irFunc->findDecoration()) { + auto count = GetIntVal(decor->getCount()); m_writer->emit("layout(invocations = "); - m_writer->emit(attrib->value); + m_writer->emit(Int(count)); m_writer->emit(") in;\n"); } diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 4abd692f8..a5c4ae088 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -10,58 +10,29 @@ namespace Slang { - -void HLSLSourceEmitter::_emitHLSLAttributeSingleString(const char* name, FuncDecl* entryPoint, Attribute* attrib) +void HLSLSourceEmitter::_emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val) { - assert(attrib); - - attrib->args.getCount(); - if (attrib->args.getCount() != 1) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute expects single parameter"); - return; - } - - Expr* expr = attrib->args[0]; - - auto stringLitExpr = as(expr); - if (!stringLitExpr) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute parameter expecting to be a string "); - return; - } + SLANG_UNUSED(entryPoint); + assert(val); m_writer->emit("["); m_writer->emit(name); m_writer->emit("(\""); - m_writer->emit(stringLitExpr->value); + m_writer->emit(val->getStringSlice()); m_writer->emit("\")]\n"); } -void HLSLSourceEmitter::_emitHLSLAttributeSingleInt(const char* name, FuncDecl* entryPoint, Attribute* attrib) +void HLSLSourceEmitter::_emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val) { - assert(attrib); + SLANG_UNUSED(entryPoint); + SLANG_ASSERT(val); - attrib->args.getCount(); - if (attrib->args.getCount() != 1) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute expects single parameter"); - return; - } - - Expr* expr = attrib->args[0]; - - auto intLitExpr = as(expr); - if (!intLitExpr) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute expects an int"); - return; - } + auto intVal = GetIntVal(val); m_writer->emit("["); m_writer->emit(name); m_writer->emit("("); - m_writer->emit(intLitExpr->value); + m_writer->emit(intVal); m_writer->emit(")]\n"); } @@ -272,19 +243,11 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint { case Stage::Compute: { - static const UInt kAxisCount = 3; - UInt sizeAlongAxis[kAxisCount]; - - // TODO: this is kind of gross because we are using a public - // reflection API function, rather than some kind of internal - // utility it forwards to... - spReflectionEntryPoint_getComputeThreadGroupSize( - (SlangReflectionEntryPoint*)entryPointLayout, - kAxisCount, - &sizeAlongAxis[0]); + Int sizeAlongAxis[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis); m_writer->emit("[numthreads("); - for (int ii = 0; ii < 3; ++ii) + for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) { if (ii != 0) m_writer->emit(", "); m_writer->emit(sizeAlongAxis[ii]); @@ -294,29 +257,30 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint break; case Stage::Geometry: { - if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier()) + if (auto decor = irFunc->findDecoration()) { + auto count = GetIntVal(decor->getCount()); m_writer->emit("[maxvertexcount("); - m_writer->emit(attrib->value); + m_writer->emit(Int(count)); m_writer->emit(")]\n"); } - if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier()) + + if (auto decor = irFunc->findDecoration()) { + auto count = GetIntVal(decor->getCount()); m_writer->emit("[instance("); - m_writer->emit(attrib->value); + m_writer->emit(Int(count)); m_writer->emit(")]\n"); } break; } case Stage::Domain: { - FuncDecl* entryPoint = entryPointLayout->entryPoint; /* [domain("isoline")] */ - if (auto attrib = entryPoint->FindModifier()) + if (auto decor = irFunc->findDecoration()) { - _emitHLSLAttributeSingleString("domain", entryPoint, attrib); + _emitHLSLDecorationSingleString("domain", irFunc, decor->getDomain()); } - break; } case Stage::Hull: @@ -324,32 +288,38 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint // Lists these are only attributes for hull shader // https://docs.microsoft.com/en-us/windows/desktop/direct3d11/direct3d-11-advanced-stages-hull-shader-design - FuncDecl* entryPoint = entryPointLayout->entryPoint; - /* [domain("isoline")] */ - if (auto attrib = entryPoint->FindModifier()) + if (auto decor = irFunc->findDecoration()) { - _emitHLSLAttributeSingleString("domain", entryPoint, attrib); + _emitHLSLDecorationSingleString("domain", irFunc, decor->getDomain()); } + /* [domain("partitioning")] */ - if (auto attrib = entryPoint->FindModifier()) + if (auto decor = irFunc->findDecoration()) { - _emitHLSLAttributeSingleString("partitioning", entryPoint, attrib); + _emitHLSLDecorationSingleString("partitioning", irFunc, decor->getPartitioning()); } + /* [outputtopology("line")] */ - if (auto attrib = entryPoint->FindModifier()) + if (auto decor = irFunc->findDecoration()) { - _emitHLSLAttributeSingleString("outputtopology", entryPoint, attrib); + _emitHLSLDecorationSingleString("outputtopology", irFunc, decor->getTopology()); } + /* [outputcontrolpoints(4)] */ - if (auto attrib = entryPoint->FindModifier()) + if (auto decor = irFunc->findDecoration()) { - _emitHLSLAttributeSingleInt("outputcontrolpoints", entryPoint, attrib); + _emitHLSLDecorationSingleInt("outputcontrolpoints", irFunc, decor->getControlPointCount()); } + /* [patchconstantfunc("HSConst")] */ - if (auto attrib = entryPoint->FindModifier()) + if (auto decor = irFunc->findDecoration()) { - _emitHLSLFuncDeclPatchConstantFuncAttribute(irFunc, entryPoint, attrib); + const String irName = getName(decor->getFunc()); + + m_writer->emit("[patchconstantfunc(\""); + m_writer->emit(irName); + m_writer->emit("\")]\n"); } break; @@ -422,25 +392,6 @@ void HLSLSourceEmitter::_emitHLSLTextureType(IRTextureTypeBase* texType) m_writer->emit(" >"); } -void HLSLSourceEmitter::_emitHLSLFuncDeclPatchConstantFuncAttribute(IRFunc* irFunc, FuncDecl* entryPoint, PatchConstantFuncAttribute* attrib) -{ - SLANG_UNUSED(attrib); - - auto irPatchFunc = irFunc->findDecoration(); - assert(irPatchFunc); - if (!irPatchFunc) - { - SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Unable to find [patchConstantFunc(...)] decoration"); - return; - } - - const String irName = getName(irPatchFunc->getFunc()); - - m_writer->emit("[patchconstantfunc(\""); - m_writer->emit(irName); - m_writer->emit("\")]\n"); -} - void HLSLSourceEmitter::emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling) { auto layout = getVarLayout(inst); diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index 6b4dfaa5c..b6b5fa740 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -50,12 +50,9 @@ protected: void _emitHLSLTextureType(IRTextureTypeBase* texType); - void _emitHLSLFuncDeclPatchConstantFuncAttribute(IRFunc* irFunc, FuncDecl* entryPoint, PatchConstantFuncAttribute* attrib); - - void _emitHLSLAttributeSingleString(const char* name, FuncDecl* entryPoint, Attribute* attrib); - - void _emitHLSLAttributeSingleInt(const char* name, FuncDecl* entryPoint, Attribute* attrib); - + void _emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val); + void _emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val); + }; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 49b6138ed..b586aa5bc 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -414,6 +414,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(PreciseDecoration, precise, 0, 0) INST(PatchConstantFuncDecoration, patchConstantFunc, 1, 0) + INST(OutputControlPointsDecoration, outputControlPoints, 1, 0) + INST(OutputTopologyDecoration, outputTopology, 1, 0) + INST(PartitioningDecoration, partioning, 1, 0) + INST(DomainDecoration, domain, 1, 0) + INST(MaxVertexCountDecoration, maxVertexCount, 1, 0) + INST(InstanceDecoration, instance, 1, 0) + INST(NumThreadsDecoration, numThreads, 3, 0) + /// An `[entryPoint]` decoration marks a function that represents a shader entry point. /// Also used in some scenarios mark parameters that are moved from entry point parameters to global params as coming from the entry point. INST(EntryPointDecoration, entryPoint, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 2e2168b1a..9dc80beac 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -228,6 +228,65 @@ IR_SIMPLE_DECORATION(EarlyDepthStencilDecoration) IR_SIMPLE_DECORATION(GloballyCoherentDecoration) IR_SIMPLE_DECORATION(PreciseDecoration) + +struct IROutputControlPointsDecoration : IRDecoration +{ + enum { kOp = kIROp_OutputControlPointsDecoration }; + IR_LEAF_ISA(OutputControlPointsDecoration) + + IRIntLit* getControlPointCount() { return cast(getOperand(0)); } +}; + +struct IROutputTopologyDecoration : IRDecoration +{ + enum { kOp = kIROp_OutputTopologyDecoration }; + IR_LEAF_ISA(OutputTopologyDecoration) + + IRStringLit* getTopology() { return cast(getOperand(0)); } +}; + +struct IRPartitioningDecoration : IRDecoration +{ + enum { kOp = kIROp_PartitioningDecoration }; + IR_LEAF_ISA(PartitioningDecoration) + + IRStringLit* getPartitioning() { return cast(getOperand(0)); } +}; + +struct IRDomainDecoration : IRDecoration +{ + enum { kOp = kIROp_DomainDecoration }; + IR_LEAF_ISA(DomainDecoration) + + IRStringLit* getDomain() { return cast(getOperand(0)); } +}; + +struct IRMaxVertexCountDecoration : IRDecoration +{ + enum { kOp = kIROp_MaxVertexCountDecoration }; + IR_LEAF_ISA(MaxVertexCountDecoration) + + IRIntLit* getCount() { return cast(getOperand(0)); } +}; + +struct IRInstanceDecoration : IRDecoration +{ + enum { kOp = kIROp_InstanceDecoration }; + IR_LEAF_ISA(InstanceDecoration) + + IRIntLit* getCount() { return cast(getOperand(0)); } +}; + +struct IRNumThreadsDecoration : IRDecoration +{ + enum { kOp = kIROp_NumThreadsDecoration }; + IR_LEAF_ISA(NumThreadsDecoration) + + IRIntLit* getX() { return cast(getOperand(0)); } + IRIntLit* getY() { return cast(getOperand(1)); } + IRIntLit* getZ() { return cast(getOperand(2)); } +}; + /// A decoration that marks a value as having linkage. /// /// A value with linkage is either exported from its module, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 1edd7d331..a622c9802 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4075,6 +4075,8 @@ struct DeclLoweringVisitor : DeclVisitor { IRGenContext* context; + DiagnosticSink* getSink() { return context->getSink(); } + IRBuilder* getBuilder() { return context->irBuilder; @@ -5573,6 +5575,27 @@ struct DeclLoweringVisitor : DeclVisitor } } + IRIntLit* _getIntLitFromAttribute(IRBuilder* builder, Attribute* attrib) + { + attrib->args.getCount(); + SLANG_ASSERT(attrib->args.getCount() ==1); + Expr* expr = attrib->args[0]; + auto intLitExpr = as(expr); + SLANG_ASSERT(intLitExpr); + return as(builder->getIntValue(builder->getIntType(), intLitExpr->value)); + } + + IRStringLit* _getStringLitFromAttribute(IRBuilder* builder, Attribute* attrib) + { + attrib->args.getCount(); + SLANG_ASSERT(attrib->args.getCount() == 1); + Expr* expr = attrib->args[0]; + + auto stringLitExpr = as(expr); + SLANG_ASSERT(stringLitExpr); + return as(builder->getStringValue(stringLitExpr->value.getUnownedSlice())); + } + LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) { // We are going to use a nested builder, because we will @@ -5912,6 +5935,32 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addRequireGLSLVersionDecoration(irFunc, Int(getIntegerLiteralValue(versionMod->versionNumberToken))); } + if (auto attr = decl->FindModifier()) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); + getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); + } + + if(auto attr = decl->FindModifier()) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); + getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); + } + + if(auto attr = decl->FindModifier()) + { + auto builder = getBuilder(); + IRType* intType = builder->getIntType(); + + IRInst* operands[3] = { + builder->getIntValue(intType, attr->x), + builder->getIntValue(intType, attr->y), + builder->getIntValue(intType, attr->z) + }; + + builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); + } + if(decl->FindModifier()) { getBuilder()->addSimpleDecoration(irFunc); @@ -5922,6 +5971,30 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addSimpleDecoration(irFunc); } + if (auto attr = decl->FindModifier()) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); + getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); + } + + if (auto attr = decl->FindModifier()) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); + getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); + } + + if (auto attr = decl->FindModifier()) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); + getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); + } + + if (auto attr = decl->FindModifier()) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); + getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); + } + // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list diff --git a/source/slang/slang-modifier-defs.h b/source/slang/slang-modifier-defs.h index 754e7e44a..5ac61b991 100644 --- a/source/slang/slang-modifier-defs.h +++ b/source/slang/slang-modifier-defs.h @@ -131,11 +131,6 @@ SIMPLE_SYNTAX_CLASS(GLSLUnparsedLayoutModifier , GLSLLayoutModifier) SIMPLE_SYNTAX_CLASS(GLSLConstantIDLayoutModifier , GLSLParsedLayoutModifier) SIMPLE_SYNTAX_CLASS(GLSLLocationLayoutModifier , GLSLParsedLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLLocalSizeLayoutModifier, GLSLUnparsedLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLLocalSizeXLayoutModifier, GLSLLocalSizeLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLLocalSizeYLayoutModifier, GLSLLocalSizeLayoutModifier) -SIMPLE_SYNTAX_CLASS(GLSLLocalSizeZLayoutModifier, GLSLLocalSizeLayoutModifier) - // A catch-all for single-keyword modifiers SIMPLE_SYNTAX_CLASS(SimpleModifier, Modifier) diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index fffbf2c62..dad84a4b6 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -4450,6 +4450,8 @@ namespace Slang { ModifierListBuilder listBuilder; + RefPtr numThreadsAttrib; + listBuilder.add(new GLSLLayoutModifierGroupBegin()); parser->ReadToken(TokenType::LParent); @@ -4458,7 +4460,41 @@ namespace Slang auto nameAndLoc = expectIdentifier(parser); const String& nameText = nameAndLoc.name->text; - if (nameText == "binding" || + const char localSizePrefix[] = "local_size_"; + + int localSizeIndex = -1; + if (nameText.startsWith(localSizePrefix) && nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1) + { + char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; + localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; + } + + if (localSizeIndex >= 0) + { + if (!numThreadsAttrib) + { + numThreadsAttrib = new UncheckedAttribute; + numThreadsAttrib->args.setCount(3); + + // Just mark the loc and name from the first in the list + numThreadsAttrib->name = getName(parser, "numthreads"); + numThreadsAttrib->loc = nameAndLoc.loc; + numThreadsAttrib->scope = parser->currentScope; + } + + if (AdvanceIf(parser, TokenType::OpAssign)) + { + auto expr = parseAtomicExpr(parser); + //SLANG_ASSERT(expr); + if (!expr) + { + return nullptr; + } + + numThreadsAttrib->args[localSizeIndex] = expr; + } + } + else if (nameText == "binding" || nameText == "set") { GLSLBindingAttribute* attr = listBuilder.find(); @@ -4499,9 +4535,6 @@ namespace Slang CASE(shaderRecordNV, ShaderRecordAttribute) CASE(constant_id, GLSLConstantIDLayoutModifier) CASE(location, GLSLLocationLayoutModifier) - CASE(local_size_x, GLSLLocalSizeXLayoutModifier) - CASE(local_size_y, GLSLLocalSizeYLayoutModifier) - CASE(local_size_z, GLSLLocalSizeZLayoutModifier) { modifier = new GLSLUnparsedLayoutModifier(); } @@ -4528,6 +4561,11 @@ namespace Slang parser->ReadToken(TokenType::Comma); } + if (numThreadsAttrib) + { + listBuilder.add(numThreadsAttrib); + } + listBuilder.add(new GLSLLayoutModifierGroupEnd()); return listBuilder.getFirst(); diff --git a/source/slang/slang-reflection.cpp b/source/slang/slang-reflection.cpp index a6015b31c..f2c4973ed 100644 --- a/source/slang/slang-reflection.cpp +++ b/source/slang/slang-reflection.cpp @@ -1267,27 +1267,6 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( sizeAlongAxis[1] = numThreadsAttribute->y; sizeAlongAxis[2] = numThreadsAttribute->z; } - else - { - // Fall back to the GLSL case, which requires a search over global-scope declarations - // to look for as with the `local_size_*` qualifier - auto module = as(entryPointFunc.getDecl()->ParentDecl); - if (module) - { - for (auto dd : module->Members) - { - for (auto mod : dd->GetModifiersOfType()) - { - if (auto xMod = as(mod)) - sizeAlongAxis[0] = (SlangUInt) getIntegerLiteralValue(xMod->valToken); - else if (auto yMod = as(mod)) - sizeAlongAxis[1] = (SlangUInt) getIntegerLiteralValue(yMod->valToken); - else if (auto zMod = as(mod)) - sizeAlongAxis[2] = (SlangUInt) getIntegerLiteralValue(zMod->valToken); - } - } - } - } // -- cgit v1.2.3