summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2019-10-04 09:46:03 -0400
committerGitHub <noreply@github.com>2019-10-04 09:46:03 -0400
commit7c8527d20e433c3a10736136d31e4cd882a3baaa (patch)
tree44032051a4d76c8773b8a503dae14d9c8c9e786d /source
parent0bc7d9b0aeb77c40befeb3618240a065374216a1 (diff)
IR types for subset of Attributes (#1067)
* IROutputControlPointsDecoration * IROutputTopologyDecoration * IRPartitioningDecoration * IRDomainDecoration * Use IRPatchConstantDecoration alone for hlsl output. * IRMaxVertexCountDecoration * IRInstanceDecoration * Removed _emitHLSLAttributeSingleString and _emitHLSLAttributeSingleInt Removed GLSLBindingAttribute and just use NumThreadsAttribute * Added IRNumThreadsDecoration. * Added IRNumThreadsDecoration * Fix build problem on x86. Improve diagnostic text based on review.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check.cpp367
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-emit-c-like.cpp11
-rw-r--r--source/slang/slang-emit-c-like.h8
-rw-r--r--source/slang/slang-emit-cpp.cpp59
-rw-r--r--source/slang/slang-emit-cpp.h8
-rw-r--r--source/slang/slang-emit-glsl.cpp25
-rw-r--r--source/slang/slang-emit-hlsl.cpp127
-rw-r--r--source/slang/slang-emit-hlsl.h9
-rw-r--r--source/slang/slang-ir-inst-defs.h8
-rw-r--r--source/slang/slang-ir-insts.h59
-rw-r--r--source/slang/slang-lower-to-ir.cpp73
-rw-r--r--source/slang/slang-modifier-defs.h5
-rw-r--r--source/slang/slang-parser.cpp46
-rw-r--r--source/slang/slang-reflection.cpp21
15 files changed, 473 insertions, 354 deletions
diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp
index cec6b02a2..b5069a02f 100644
--- a/source/slang/slang-check.cpp
+++ b/source/slang/slang-check.cpp
@@ -2781,236 +2781,253 @@ namespace Slang
bool validateAttribute(RefPtr<Attribute> attr, AttributeDecl* attribClassDecl)
{
- if(auto numThreadsAttr = as<NumThreadsAttribute>(attr))
- {
- SLANG_ASSERT(attr->args.getCount() == 3);
- auto xVal = checkConstantIntVal(attr->args[0]);
- auto yVal = checkConstantIntVal(attr->args[1]);
- auto zVal = checkConstantIntVal(attr->args[2]);
+ if(auto numThreadsAttr = as<NumThreadsAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 3);
- if(!xVal) return false;
- if(!yVal) return false;
- if(!zVal) return false;
+ int32_t values[3];
- numThreadsAttr->x = (int32_t) xVal->value;
- numThreadsAttr->y = (int32_t) yVal->value;
- numThreadsAttr->z = (int32_t) zVal->value;
- }
- else if (auto bindingAttr = as<GLSLBindingAttribute>(attr))
+ for (int i = 0; i < 3; ++i)
{
- // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding)
- // Must have 2 int parameters. Ideally this would all be checked from the specification
- // in core.meta.slang, but that's not completely implemented. So for now we check here.
- if (attr->args.getCount() != 2)
- {
- return false;
- }
+ int32_t value = 1;
- // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here
- // to make sure they both are
- auto binding = checkConstantIntVal(attr->args[0]);
- auto set = checkConstantIntVal(attr->args[1]);
-
- if (binding == nullptr || set == nullptr)
+ auto arg = attr->args[i];
+ if (arg)
{
- return false;
+ auto intValue = checkConstantIntVal(arg);
+ if (!intValue)
+ {
+ return false;
+ }
+ if (intValue->value < 1)
+ {
+ getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->value);
+ return false;
+ }
+ value = int32_t(intValue->value);
}
-
- bindingAttr->binding = int32_t(binding->value);
- bindingAttr->set = int32_t(set->value);
+ values[i] = value;
}
- else if (auto maxVertexCountAttr = as<MaxVertexCountAttribute>(attr))
+
+ numThreadsAttr->x = values[0];
+ numThreadsAttr->y = values[1];
+ numThreadsAttr->z = values[2];
+ }
+ else if (auto bindingAttr = as<GLSLBindingAttribute>(attr))
+ {
+ // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding)
+ // Must have 2 int parameters. Ideally this would all be checked from the specification
+ // in core.meta.slang, but that's not completely implemented. So for now we check here.
+ if (attr->args.getCount() != 2)
{
- SLANG_ASSERT(attr->args.getCount() == 1);
- auto val = checkConstantIntVal(attr->args[0]);
+ return false;
+ }
- if(!val) return false;
+ // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here
+ // to make sure they both are
+ auto binding = checkConstantIntVal(attr->args[0]);
+ auto set = checkConstantIntVal(attr->args[1]);
- maxVertexCountAttr->value = (int32_t)val->value;
- }
- else if(auto instanceAttr = as<InstanceAttribute>(attr))
+ if (binding == nullptr || set == nullptr)
{
- SLANG_ASSERT(attr->args.getCount() == 1);
- auto val = checkConstantIntVal(attr->args[0]);
+ return false;
+ }
+
+ bindingAttr->binding = int32_t(binding->value);
+ bindingAttr->set = int32_t(set->value);
+ }
+ else if (auto maxVertexCountAttr = as<MaxVertexCountAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ auto val = checkConstantIntVal(attr->args[0]);
- if(!val) return false;
+ if(!val) return false;
- instanceAttr->value = (int32_t)val->value;
- }
- else if(auto entryPointAttr = as<EntryPointAttribute>(attr))
- {
- SLANG_ASSERT(attr->args.getCount() == 1);
+ maxVertexCountAttr->value = (int32_t)val->value;
+ }
+ else if(auto instanceAttr = as<InstanceAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ auto val = checkConstantIntVal(attr->args[0]);
- String stageName;
- if(!checkLiteralStringVal(attr->args[0], &stageName))
- {
- return false;
- }
+ if(!val) return false;
- auto stage = findStageByName(stageName);
- if(stage == Stage::Unknown)
- {
- getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName);
- }
+ instanceAttr->value = (int32_t)val->value;
+ }
+ else if(auto entryPointAttr = as<EntryPointAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
- entryPointAttr->stage = stage;
- }
- else if ((as<DomainAttribute>(attr)) ||
- (as<MaxTessFactorAttribute>(attr)) ||
- (as<OutputTopologyAttribute>(attr)) ||
- (as<PartitioningAttribute>(attr)) ||
- (as<PatchConstantFuncAttribute>(attr)))
- {
- // Let it go thru iff single string attribute
- if (!hasStringArgs(attr, 1))
- {
- getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->name);
- }
- }
- else if (as<OutputControlPointsAttribute>(attr))
+ String stageName;
+ if(!checkLiteralStringVal(attr->args[0], &stageName))
{
- // Let it go thru iff single integral attribute
- if (!hasIntArgs(attr, 1))
- {
- getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name);
- }
+ return false;
}
- else if (as<PushConstantAttribute>(attr))
+
+ auto stage = findStageByName(stageName);
+ if(stage == Stage::Unknown)
{
- // Has no args
- SLANG_ASSERT(attr->args.getCount() == 0);
+ getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName);
}
- else if (as<ShaderRecordAttribute>(attr))
+
+ entryPointAttr->stage = stage;
+ }
+ else if ((as<DomainAttribute>(attr)) ||
+ (as<MaxTessFactorAttribute>(attr)) ||
+ (as<OutputTopologyAttribute>(attr)) ||
+ (as<PartitioningAttribute>(attr)) ||
+ (as<PatchConstantFuncAttribute>(attr)))
+ {
+ // Let it go thru iff single string attribute
+ if (!hasStringArgs(attr, 1))
{
- // Has no args
- SLANG_ASSERT(attr->args.getCount() == 0);
+ getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->name);
}
- else if (as<EarlyDepthStencilAttribute>(attr))
+ }
+ else if (as<OutputControlPointsAttribute>(attr))
+ {
+ // Let it go thru iff single integral attribute
+ if (!hasIntArgs(attr, 1))
{
- // Has no args
- SLANG_ASSERT(attr->args.getCount() == 0);
+ getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name);
}
- else if (auto attrUsageAttr = as<AttributeUsageAttribute>(attr))
+ }
+ else if (as<PushConstantAttribute>(attr))
+ {
+ // Has no args
+ SLANG_ASSERT(attr->args.getCount() == 0);
+ }
+ else if (as<ShaderRecordAttribute>(attr))
+ {
+ // Has no args
+ SLANG_ASSERT(attr->args.getCount() == 0);
+ }
+ else if (as<EarlyDepthStencilAttribute>(attr))
+ {
+ // Has no args
+ SLANG_ASSERT(attr->args.getCount() == 0);
+ }
+ else if (auto attrUsageAttr = as<AttributeUsageAttribute>(attr))
+ {
+ uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None;
+ if (attr->args.getCount() == 1)
{
- uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None;
- if (attr->args.getCount() == 1)
+ RefPtr<IntVal> outIntVal;
+ if (auto cInt = checkConstantEnumVal(attr->args[0]))
{
- RefPtr<IntVal> outIntVal;
- if (auto cInt = checkConstantEnumVal(attr->args[0]))
- {
- targetClassId = (uint32_t)(cInt->value);
- }
- else
- {
- getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name);
- return false;
- }
+ targetClassId = (uint32_t)(cInt->value);
}
- if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId))
+ else
{
- getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget);
+ getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name);
return false;
}
}
- else if (auto unrollAttr = as<UnrollAttribute>(attr))
+ if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId))
{
- // Check has an argument. We need this because default behavior is to give an error
- // if an attribute has arguments, but not handled explicitly (and the default param will come through
- // as 1 arg if nothing is specified)
- SLANG_ASSERT(attr->args.getCount() == 1);
+ getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget);
+ return false;
}
- else if (auto userDefAttr = as<UserDefinedAttribute>(attr))
+ }
+ else if (auto unrollAttr = as<UnrollAttribute>(attr))
+ {
+ // Check has an argument. We need this because default behavior is to give an error
+ // if an attribute has arguments, but not handled explicitly (and the default param will come through
+ // as 1 arg if nothing is specified)
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ }
+ else if (auto userDefAttr = as<UserDefinedAttribute>(attr))
+ {
+ // check arguments against attribute parameters defined in attribClassDecl
+ Index paramIndex = 0;
+ auto params = attribClassDecl->getMembersOfType<ParamDecl>();
+ for (auto paramDecl : params)
{
- // check arguments against attribute parameters defined in attribClassDecl
- Index paramIndex = 0;
- auto params = attribClassDecl->getMembersOfType<ParamDecl>();
- for (auto paramDecl : params)
+ if (paramIndex < attr->args.getCount())
{
- if (paramIndex < attr->args.getCount())
+ auto & arg = attr->args[paramIndex];
+ bool typeChecked = false;
+ if (auto basicType = as<BasicExpressionType>(paramDecl->getType()))
{
- auto & arg = attr->args[paramIndex];
- bool typeChecked = false;
- if (auto basicType = as<BasicExpressionType>(paramDecl->getType()))
+ if (basicType->baseType == BaseType::Int)
{
- if (basicType->baseType == BaseType::Int)
+ if (auto cint = checkConstantIntVal(arg))
{
- if (auto cint = checkConstantIntVal(arg))
- {
- attr->intArgVals[(uint32_t)paramIndex] = cint;
- }
- typeChecked = true;
+ attr->intArgVals[(uint32_t)paramIndex] = cint;
}
- }
- if (!typeChecked)
- {
- arg = CheckExpr(arg);
- arg = coerce(paramDecl->getType(), arg);
+ typeChecked = true;
}
}
- paramIndex++;
- }
- if (params.getCount() < attr->args.getCount())
- {
- getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), params.getCount());
- }
- else if (params.getCount() > attr->args.getCount())
- {
- getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount());
+ if (!typeChecked)
+ {
+ arg = CheckExpr(arg);
+ arg = coerce(paramDecl->getType(), arg);
+ }
}
+ paramIndex++;
}
- else if (auto formatAttr = as<FormatAttribute>(attr))
+ if (params.getCount() < attr->args.getCount())
{
- SLANG_ASSERT(attr->args.getCount() == 1);
-
- String formatName;
- if(!checkLiteralStringVal(attr->args[0], &formatName))
- {
- return false;
- }
-
- ImageFormat format = ImageFormat::unknown;
- if(!findImageFormatByName(formatName.getBuffer(), &format))
- {
- getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName);
- }
+ getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), params.getCount());
+ }
+ else if (params.getCount() > attr->args.getCount())
+ {
+ getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount());
+ }
+ }
+ else if (auto formatAttr = as<FormatAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
- formatAttr->format = format;
+ String formatName;
+ if(!checkLiteralStringVal(attr->args[0], &formatName))
+ {
+ return false;
}
- else if (auto allowAttr = as<AllowAttribute>(attr))
+
+ ImageFormat format = ImageFormat::unknown;
+ if(!findImageFormatByName(formatName.getBuffer(), &format))
{
- SLANG_ASSERT(attr->args.getCount() == 1);
+ getSink()->diagnose(attr->args[0], Diagnostics::unknownImageFormatName, formatName);
+ }
- String diagnosticName;
- if(!checkLiteralStringVal(attr->args[0], &diagnosticName))
- {
- return false;
- }
+ formatAttr->format = format;
+ }
+ else if (auto allowAttr = as<AllowAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
- auto diagnosticInfo = findDiagnosticByName(diagnosticName.getUnownedSlice());
- if(!diagnosticInfo)
- {
- getSink()->diagnose(attr->args[0], Diagnostics::unknownDiagnosticName, diagnosticName);
- }
+ String diagnosticName;
+ if(!checkLiteralStringVal(attr->args[0], &diagnosticName))
+ {
+ return false;
+ }
- allowAttr->diagnostic = diagnosticInfo;
+ auto diagnosticInfo = findDiagnosticByName(diagnosticName.getUnownedSlice());
+ if(!diagnosticInfo)
+ {
+ getSink()->diagnose(attr->args[0], Diagnostics::unknownDiagnosticName, diagnosticName);
+ }
+
+ allowAttr->diagnostic = diagnosticInfo;
+ }
+ else
+ {
+ if(attr->args.getCount() == 0)
+ {
+ // If the attribute took no arguments, then we will
+ // assume it is valid as written.
}
else
{
- if(attr->args.getCount() == 0)
- {
- // If the attribute took no arguments, then we will
- // assume it is valid as written.
- }
- else
- {
- // We should be special-casing the checking of any attribute
- // with a non-zero number of arguments.
- SLANG_DIAGNOSE_UNEXPECTED(getSink(), attr, "unhandled attribute");
- return false;
- }
+ // We should be special-casing the checking of any attribute
+ // with a non-zero number of arguments.
+ SLANG_DIAGNOSE_UNEXPECTED(getSink(), attr, "unhandled attribute");
+ return false;
}
+ }
- return true;
+ return true;
}
RefPtr<AttributeBase> checkAttribute(
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index c1c207fd7..067a65159 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -274,6 +274,7 @@ DIAGNOSTIC(31006, Error, attributeFunctionNotFound, "Could not find function '$0
DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'")
DIAGNOSTIC(31101, Error, unknownImageFormatName, "unknown image format '$0'")
DIAGNOSTIC(31101, Error, unknownDiagnosticName, "unknown diagnostic '$0'")
+DIAGNOSTIC(31102, Error, nonPositiveNumThreads, "expected a positive integer in 'numthreads' attribute, got '$0'")
DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute")
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index a19823055..281da3ef9 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -203,6 +203,17 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
}
}
+
+/* static */IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(IRFunc* func, Int outNumThreads[kThreadGroupAxisCount])
+{
+ IRNumThreadsDecoration* decor = func->findDecoration<IRNumThreadsDecoration>();
+ for (int i = 0; i < 3; ++i)
+ {
+ outNumThreads[i] = decor ? Int(GetIntVal(decor->getOperand(i))) : 1;
+ }
+ return decor;
+}
+
void CLikeSourceEmitter::_emitArrayType(IRArrayType* arrayType, EDeclarator* declarator)
{
EDeclarator arrayDeclarator;
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index ee906010b..9c62855aa 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -43,6 +43,11 @@ public:
// explicitly in the layout data.
StructTypeLayout* globalStructLayout = nullptr;
};
+
+ enum
+ {
+ kThreadGroupAxisCount = 3,
+ };
/// To simplify cases
enum class SourceStyle
@@ -318,6 +323,9 @@ public:
/// Returns an empty slice if not a built in type
static UnownedStringSlice getDefaultBuiltinTypeName(IROp op);
+ /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1
+ static IRNumThreadsDecoration* getComputeThreadGroupSize(IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]);
+
protected:
virtual void emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling = "register") { SLANG_UNUSED(inst); SLANG_UNUSED(uniformSemanticSpelling); }
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 1628c6770..1f38512b4 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1916,23 +1916,15 @@ void CPPSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointLa
{
case Stage::Compute:
{
- static const UInt kAxisCount = 3;
- UInt sizeAlongAxis[kAxisCount];
-
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize(
- (SlangReflectionEntryPoint*)entryPointLayout,
- kAxisCount,
- &sizeAlongAxis[0]);
-
+ Int numThreads[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(irFunc, numThreads);
+
// TODO(JS): We might want to store this information such that it can be used to execute
m_writer->emit("// [numthreads(");
- for (int ii = 0; ii < 3; ++ii)
+ for (int ii = 0; ii < kThreadGroupAxisCount; ++ii)
{
if (ii != 0) m_writer->emit(", ");
- m_writer->emit(sizeAlongAxis[ii]);
+ m_writer->emit(numThreads[ii]);
}
m_writer->emit(")]\n");
break;
@@ -2504,16 +2496,16 @@ struct AxisWithSize
bool operator<(const ThisType& rhs) const { return size < rhs.size || (size == rhs.size && axis < rhs.axis); }
int axis;
- UInt size;
+ Int size;
};
} // anonymous
-static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List<AxisWithSize>& out)
+static void _calcAxisOrder(const Int sizeAlongAxis[CLikeSourceEmitter::kThreadGroupAxisCount], bool allowSingle, List<AxisWithSize>& out)
{
out.clear();
// Add in order z,y,x, so by default (if we don't sort), x will be the inner loop
- for (int i = 3 - 1; i >= 0; --i)
+ for (int i = CLikeSourceEmitter::kThreadGroupAxisCount - 1; i >= 0; --i)
{
if (allowSingle || sizeAlongAxis[i] > 1)
{
@@ -2529,7 +2521,7 @@ static void _calcAxisOrder(const UInt sizeAlongAxis[3], bool allowSingle, List<A
// axes.sort();
}
-void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName)
+void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName)
{
List<AxisWithSize> axes;
_calcAxisOrder(sizeAlongAxis, false, axes);
@@ -2563,7 +2555,7 @@ void CPPSourceEmitter::_emitEntryPointGroup(const UInt sizeAlongAxis[3], const S
}
}
-void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName)
+void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName)
{
List<AxisWithSize> axes;
_calcAxisOrder(sizeAlongAxis, true, axes);
@@ -2635,13 +2627,13 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const UInt sizeAlongAxis[3], co
m_writer->emit("}\n");
}
}
-void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName)
+void CPPSourceEmitter::_emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName)
{
StringBuilder builder;
m_writer->emit("{\n");
m_writer->indent();
- for (int i = 0; i < 3; ++i)
+ for (int i = 0; i < kThreadGroupAxisCount; ++i)
{
builder.Clear();
const char elem[2] = { s_elemNames[i], 0 };
@@ -2650,7 +2642,7 @@ void CPPSourceEmitter::_emitInitAxisValues(const UInt sizeAlongAxis[3], const Un
{
builder << " + " << addName << "." << elem;
}
- if (i < 3 - 1)
+ if (i < kThreadGroupAxisCount - 1)
{
builder << ",";
}
@@ -2821,14 +2813,9 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid
// SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID.
- static const UInt kAxisCount = 3;
- UInt sizeAlongAxis[kAxisCount];
-
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*)entryPointLayout, kAxisCount, &sizeAlongAxis[0]);
-
+ Int groupThreadSize[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(func, groupThreadSize);
+
String funcName = getFuncName(func);
{
@@ -2842,7 +2829,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
{
m_writer->emit("context.groupDispatchThreadID = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice());
}
if (m_semanticUsedFlags & SemanticUsedFlag::GroupID)
{
@@ -2851,7 +2838,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
// Emit dispatchThreadID
m_writer->emit("context.dispatchThreadID = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->groupID"), UnownedStringSlice::fromLiteral("varyingInput->groupThreadID"));
m_writer->emit("context._");
m_writer->emit(funcName);
@@ -2871,7 +2858,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
_emitEntryPointDefinitionStart(func, entryPointGlobalParams, groupFuncName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
m_writer->emit("const uint3 start = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
if (m_semanticUsedFlags & SemanticUsedFlag::GroupThreadID)
{
@@ -2884,7 +2871,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
}
m_writer->emit("context.dispatchThreadID = start;\n");
- _emitEntryPointGroup(sizeAlongAxis, funcName);
+ _emitEntryPointGroup(groupThreadSize, funcName);
_emitEntryPointDefinitionEnd(func);
}
@@ -2893,11 +2880,11 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
_emitEntryPointDefinitionStart(func, entryPointGlobalParams, funcName, UnownedStringSlice::fromLiteral("ComputeVaryingInput"));
m_writer->emit("const uint3 start = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->startGroupID"), UnownedStringSlice());
m_writer->emit("const uint3 end = ");
- _emitInitAxisValues(sizeAlongAxis, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
+ _emitInitAxisValues(groupThreadSize, UnownedStringSlice::fromLiteral("varyingInput->endGroupID"), UnownedStringSlice());
- _emitEntryPointGroupRange(sizeAlongAxis, funcName);
+ _emitEntryPointGroupRange(groupThreadSize, funcName);
_emitEntryPointDefinitionEnd(func);
}
}
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index 0e182818d..af7aef271 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -271,15 +271,15 @@ protected:
void _emitEntryPointDefinitionStart(IRFunc* func, IRGlobalParam* entryPointGlobalParams, const String& funcName, const UnownedStringSlice& varyingTypeName);
void _emitEntryPointDefinitionEnd(IRFunc* func);
- void _emitEntryPointGroup(const UInt sizeAlongAxis[3], const String& funcName);
- void _emitEntryPointGroupRange(const UInt sizeAlongAxis[3], const String& funcName);
+ void _emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName);
+ void _emitEntryPointGroupRange(const Int sizeAlongAxis[kThreadGroupAxisCount], const String& funcName);
- void _emitInitAxisValues(const UInt sizeAlongAxis[3], const UnownedStringSlice& mulName, const UnownedStringSlice& addName);
+ void _emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName);
Dictionary<SpecializedIntrinsic, StringSlicePool::Handle> m_intrinsicNameMap;
Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap;
- /* This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsice.
+ /* This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsic.
That we want to have a pointer to a type be unique, and slang supports this through the m_sharedIRBuilder. BUT for this to
work all work on the module must use the same sharedIRBuilder, and that appears to not be the case in terms
of other passes.
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index a1c7b9170..e43d41d5c 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -661,20 +661,12 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointL
{
case Stage::Compute:
{
- static const UInt kAxisCount = 3;
- UInt sizeAlongAxis[kAxisCount];
-
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize(
- (SlangReflectionEntryPoint*)entryPointLayout,
- kAxisCount,
- &sizeAlongAxis[0]);
+ Int sizeAlongAxis[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(irFunc, sizeAlongAxis);
m_writer->emit("layout(");
char const* axes[] = { "x", "y", "z" };
- for (int ii = 0; ii < 3; ++ii)
+ for (int ii = 0; ii < kThreadGroupAxisCount; ++ii)
{
if (ii != 0) m_writer->emit(", ");
m_writer->emit("local_size_");
@@ -687,16 +679,19 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, EntryPointL
break;
case Stage::Geometry:
{
- if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<MaxVertexCountAttribute>())
+ if (auto decor = irFunc->findDecoration<IRMaxVertexCountDecoration>())
{
+ auto count = GetIntVal(decor->getCount());
m_writer->emit("layout(max_vertices = ");
- m_writer->emit(attrib->value);
+ m_writer->emit(Int(count));
m_writer->emit(") out;\n");
}
- if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<InstanceAttribute>())
+
+ if (auto decor = irFunc->findDecoration<IRInstanceDecoration>())
{
+ auto count = GetIntVal(decor->getCount());
m_writer->emit("layout(invocations = ");
- m_writer->emit(attrib->value);
+ m_writer->emit(Int(count));
m_writer->emit(") in;\n");
}
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index 4abd692f8..a5c4ae088 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -10,58 +10,29 @@
namespace Slang {
-
-void HLSLSourceEmitter::_emitHLSLAttributeSingleString(const char* name, FuncDecl* entryPoint, Attribute* attrib)
+void HLSLSourceEmitter::_emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val)
{
- assert(attrib);
-
- attrib->args.getCount();
- if (attrib->args.getCount() != 1)
- {
- SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute expects single parameter");
- return;
- }
-
- Expr* expr = attrib->args[0];
-
- auto stringLitExpr = as<StringLiteralExpr>(expr);
- if (!stringLitExpr)
- {
- SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute parameter expecting to be a string ");
- return;
- }
+ SLANG_UNUSED(entryPoint);
+ assert(val);
m_writer->emit("[");
m_writer->emit(name);
m_writer->emit("(\"");
- m_writer->emit(stringLitExpr->value);
+ m_writer->emit(val->getStringSlice());
m_writer->emit("\")]\n");
}
-void HLSLSourceEmitter::_emitHLSLAttributeSingleInt(const char* name, FuncDecl* entryPoint, Attribute* attrib)
+void HLSLSourceEmitter::_emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val)
{
- assert(attrib);
+ SLANG_UNUSED(entryPoint);
+ SLANG_ASSERT(val);
- attrib->args.getCount();
- if (attrib->args.getCount() != 1)
- {
- SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute expects single parameter");
- return;
- }
-
- Expr* expr = attrib->args[0];
-
- auto intLitExpr = as<IntegerLiteralExpr>(expr);
- if (!intLitExpr)
- {
- SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute expects an int");
- return;
- }
+ auto intVal = GetIntVal(val);
m_writer->emit("[");
m_writer->emit(name);
m_writer->emit("(");
- m_writer->emit(intLitExpr->value);
+ m_writer->emit(intVal);
m_writer->emit(")]\n");
}
@@ -272,19 +243,11 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint
{
case Stage::Compute:
{
- static const UInt kAxisCount = 3;
- UInt sizeAlongAxis[kAxisCount];
-
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize(
- (SlangReflectionEntryPoint*)entryPointLayout,
- kAxisCount,
- &sizeAlongAxis[0]);
+ Int sizeAlongAxis[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(irFunc, sizeAlongAxis);
m_writer->emit("[numthreads(");
- for (int ii = 0; ii < 3; ++ii)
+ for (int ii = 0; ii < kThreadGroupAxisCount; ++ii)
{
if (ii != 0) m_writer->emit(", ");
m_writer->emit(sizeAlongAxis[ii]);
@@ -294,29 +257,30 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint
break;
case Stage::Geometry:
{
- if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<MaxVertexCountAttribute>())
+ if (auto decor = irFunc->findDecoration<IRMaxVertexCountDecoration>())
{
+ auto count = GetIntVal(decor->getCount());
m_writer->emit("[maxvertexcount(");
- m_writer->emit(attrib->value);
+ m_writer->emit(Int(count));
m_writer->emit(")]\n");
}
- if (auto attrib = entryPointLayout->getFuncDecl()->FindModifier<InstanceAttribute>())
+
+ if (auto decor = irFunc->findDecoration<IRInstanceDecoration>())
{
+ auto count = GetIntVal(decor->getCount());
m_writer->emit("[instance(");
- m_writer->emit(attrib->value);
+ m_writer->emit(Int(count));
m_writer->emit(")]\n");
}
break;
}
case Stage::Domain:
{
- FuncDecl* entryPoint = entryPointLayout->entryPoint;
/* [domain("isoline")] */
- if (auto attrib = entryPoint->FindModifier<DomainAttribute>())
+ if (auto decor = irFunc->findDecoration<IRDomainDecoration>())
{
- _emitHLSLAttributeSingleString("domain", entryPoint, attrib);
+ _emitHLSLDecorationSingleString("domain", irFunc, decor->getDomain());
}
-
break;
}
case Stage::Hull:
@@ -324,32 +288,38 @@ void HLSLSourceEmitter::_emitHLSLEntryPointAttributes(IRFunc* irFunc, EntryPoint
// Lists these are only attributes for hull shader
// https://docs.microsoft.com/en-us/windows/desktop/direct3d11/direct3d-11-advanced-stages-hull-shader-design
- FuncDecl* entryPoint = entryPointLayout->entryPoint;
-
/* [domain("isoline")] */
- if (auto attrib = entryPoint->FindModifier<DomainAttribute>())
+ if (auto decor = irFunc->findDecoration<IRDomainDecoration>())
{
- _emitHLSLAttributeSingleString("domain", entryPoint, attrib);
+ _emitHLSLDecorationSingleString("domain", irFunc, decor->getDomain());
}
+
/* [domain("partitioning")] */
- if (auto attrib = entryPoint->FindModifier<PartitioningAttribute>())
+ if (auto decor = irFunc->findDecoration<IRPartitioningDecoration>())
{
- _emitHLSLAttributeSingleString("partitioning", entryPoint, attrib);
+ _emitHLSLDecorationSingleString("partitioning", irFunc, decor->getPartitioning());
}
+
/* [outputtopology("line")] */
- if (auto attrib = entryPoint->FindModifier<OutputTopologyAttribute>())
+ if (auto decor = irFunc->findDecoration<IROutputTopologyDecoration>())
{
- _emitHLSLAttributeSingleString("outputtopology", entryPoint, attrib);
+ _emitHLSLDecorationSingleString("outputtopology", irFunc, decor->getTopology());
}
+
/* [outputcontrolpoints(4)] */
- if (auto attrib = entryPoint->FindModifier<OutputControlPointsAttribute>())
+ if (auto decor = irFunc->findDecoration<IROutputControlPointsDecoration>())
{
- _emitHLSLAttributeSingleInt("outputcontrolpoints", entryPoint, attrib);
+ _emitHLSLDecorationSingleInt("outputcontrolpoints", irFunc, decor->getControlPointCount());
}
+
/* [patchconstantfunc("HSConst")] */
- if (auto attrib = entryPoint->FindModifier<PatchConstantFuncAttribute>())
+ if (auto decor = irFunc->findDecoration<IRPatchConstantFuncDecoration>())
{
- _emitHLSLFuncDeclPatchConstantFuncAttribute(irFunc, entryPoint, attrib);
+ const String irName = getName(decor->getFunc());
+
+ m_writer->emit("[patchconstantfunc(\"");
+ m_writer->emit(irName);
+ m_writer->emit("\")]\n");
}
break;
@@ -422,25 +392,6 @@ void HLSLSourceEmitter::_emitHLSLTextureType(IRTextureTypeBase* texType)
m_writer->emit(" >");
}
-void HLSLSourceEmitter::_emitHLSLFuncDeclPatchConstantFuncAttribute(IRFunc* irFunc, FuncDecl* entryPoint, PatchConstantFuncAttribute* attrib)
-{
- SLANG_UNUSED(attrib);
-
- auto irPatchFunc = irFunc->findDecoration<IRPatchConstantFuncDecoration>();
- assert(irPatchFunc);
- if (!irPatchFunc)
- {
- SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Unable to find [patchConstantFunc(...)] decoration");
- return;
- }
-
- const String irName = getName(irPatchFunc->getFunc());
-
- m_writer->emit("[patchconstantfunc(\"");
- m_writer->emit(irName);
- m_writer->emit("\")]\n");
-}
-
void HLSLSourceEmitter::emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling)
{
auto layout = getVarLayout(inst);
diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h
index 6b4dfaa5c..b6b5fa740 100644
--- a/source/slang/slang-emit-hlsl.h
+++ b/source/slang/slang-emit-hlsl.h
@@ -50,12 +50,9 @@ protected:
void _emitHLSLTextureType(IRTextureTypeBase* texType);
- void _emitHLSLFuncDeclPatchConstantFuncAttribute(IRFunc* irFunc, FuncDecl* entryPoint, PatchConstantFuncAttribute* attrib);
-
- void _emitHLSLAttributeSingleString(const char* name, FuncDecl* entryPoint, Attribute* attrib);
-
- void _emitHLSLAttributeSingleInt(const char* name, FuncDecl* entryPoint, Attribute* attrib);
-
+ void _emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val);
+ void _emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val);
+
};
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 49b6138ed..b586aa5bc 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -414,6 +414,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(PreciseDecoration, precise, 0, 0)
INST(PatchConstantFuncDecoration, patchConstantFunc, 1, 0)
+ INST(OutputControlPointsDecoration, outputControlPoints, 1, 0)
+ INST(OutputTopologyDecoration, outputTopology, 1, 0)
+ INST(PartitioningDecoration, partioning, 1, 0)
+ INST(DomainDecoration, domain, 1, 0)
+ INST(MaxVertexCountDecoration, maxVertexCount, 1, 0)
+ INST(InstanceDecoration, instance, 1, 0)
+ INST(NumThreadsDecoration, numThreads, 3, 0)
+
/// An `[entryPoint]` decoration marks a function that represents a shader entry point.
/// Also used in some scenarios mark parameters that are moved from entry point parameters to global params as coming from the entry point.
INST(EntryPointDecoration, entryPoint, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 2e2168b1a..9dc80beac 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -228,6 +228,65 @@ IR_SIMPLE_DECORATION(EarlyDepthStencilDecoration)
IR_SIMPLE_DECORATION(GloballyCoherentDecoration)
IR_SIMPLE_DECORATION(PreciseDecoration)
+
+struct IROutputControlPointsDecoration : IRDecoration
+{
+ enum { kOp = kIROp_OutputControlPointsDecoration };
+ IR_LEAF_ISA(OutputControlPointsDecoration)
+
+ IRIntLit* getControlPointCount() { return cast<IRIntLit>(getOperand(0)); }
+};
+
+struct IROutputTopologyDecoration : IRDecoration
+{
+ enum { kOp = kIROp_OutputTopologyDecoration };
+ IR_LEAF_ISA(OutputTopologyDecoration)
+
+ IRStringLit* getTopology() { return cast<IRStringLit>(getOperand(0)); }
+};
+
+struct IRPartitioningDecoration : IRDecoration
+{
+ enum { kOp = kIROp_PartitioningDecoration };
+ IR_LEAF_ISA(PartitioningDecoration)
+
+ IRStringLit* getPartitioning() { return cast<IRStringLit>(getOperand(0)); }
+};
+
+struct IRDomainDecoration : IRDecoration
+{
+ enum { kOp = kIROp_DomainDecoration };
+ IR_LEAF_ISA(DomainDecoration)
+
+ IRStringLit* getDomain() { return cast<IRStringLit>(getOperand(0)); }
+};
+
+struct IRMaxVertexCountDecoration : IRDecoration
+{
+ enum { kOp = kIROp_MaxVertexCountDecoration };
+ IR_LEAF_ISA(MaxVertexCountDecoration)
+
+ IRIntLit* getCount() { return cast<IRIntLit>(getOperand(0)); }
+};
+
+struct IRInstanceDecoration : IRDecoration
+{
+ enum { kOp = kIROp_InstanceDecoration };
+ IR_LEAF_ISA(InstanceDecoration)
+
+ IRIntLit* getCount() { return cast<IRIntLit>(getOperand(0)); }
+};
+
+struct IRNumThreadsDecoration : IRDecoration
+{
+ enum { kOp = kIROp_NumThreadsDecoration };
+ IR_LEAF_ISA(NumThreadsDecoration)
+
+ IRIntLit* getX() { return cast<IRIntLit>(getOperand(0)); }
+ IRIntLit* getY() { return cast<IRIntLit>(getOperand(1)); }
+ IRIntLit* getZ() { return cast<IRIntLit>(getOperand(2)); }
+};
+
/// A decoration that marks a value as having linkage.
///
/// A value with linkage is either exported from its module,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 1edd7d331..a622c9802 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4075,6 +4075,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
IRGenContext* context;
+ DiagnosticSink* getSink() { return context->getSink(); }
+
IRBuilder* getBuilder()
{
return context->irBuilder;
@@ -5573,6 +5575,27 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
+ IRIntLit* _getIntLitFromAttribute(IRBuilder* builder, Attribute* attrib)
+ {
+ attrib->args.getCount();
+ SLANG_ASSERT(attrib->args.getCount() ==1);
+ Expr* expr = attrib->args[0];
+ auto intLitExpr = as<IntegerLiteralExpr>(expr);
+ SLANG_ASSERT(intLitExpr);
+ return as<IRIntLit>(builder->getIntValue(builder->getIntType(), intLitExpr->value));
+ }
+
+ IRStringLit* _getStringLitFromAttribute(IRBuilder* builder, Attribute* attrib)
+ {
+ attrib->args.getCount();
+ SLANG_ASSERT(attrib->args.getCount() == 1);
+ Expr* expr = attrib->args[0];
+
+ auto stringLitExpr = as<StringLiteralExpr>(expr);
+ SLANG_ASSERT(stringLitExpr);
+ return as<IRStringLit>(builder->getStringValue(stringLitExpr->value.getUnownedSlice()));
+ }
+
LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl)
{
// We are going to use a nested builder, because we will
@@ -5912,6 +5935,32 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addRequireGLSLVersionDecoration(irFunc, Int(getIntegerLiteralValue(versionMod->versionNumberToken)));
}
+ if (auto attr = decl->FindModifier<InstanceAttribute>())
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
+ getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit);
+ }
+
+ if(auto attr = decl->FindModifier<MaxVertexCountAttribute>())
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
+ getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit);
+ }
+
+ if(auto attr = decl->FindModifier<NumThreadsAttribute>())
+ {
+ auto builder = getBuilder();
+ IRType* intType = builder->getIntType();
+
+ IRInst* operands[3] = {
+ builder->getIntValue(intType, attr->x),
+ builder->getIntValue(intType, attr->y),
+ builder->getIntValue(intType, attr->z)
+ };
+
+ builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3);
+ }
+
if(decl->FindModifier<ReadNoneAttribute>())
{
getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc);
@@ -5922,6 +5971,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc);
}
+ if (auto attr = decl->FindModifier<DomainAttribute>())
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
+ getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit);
+ }
+
+ if (auto attr = decl->FindModifier<PartitioningAttribute>())
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
+ getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit);
+ }
+
+ if (auto attr = decl->FindModifier<OutputTopologyAttribute>())
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
+ getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit);
+ }
+
+ if (auto attr = decl->FindModifier<OutputControlPointsAttribute>())
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
+ getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit);
+ }
+
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list
diff --git a/source/slang/slang-modifier-defs.h b/source/slang/slang-modifier-defs.h
index 754e7e44a..5ac61b991 100644
--- a/source/slang/slang-modifier-defs.h
+++ b/source/slang/slang-modifier-defs.h
@@ -131,11 +131,6 @@ SIMPLE_SYNTAX_CLASS(GLSLUnparsedLayoutModifier , GLSLLayoutModifier)
SIMPLE_SYNTAX_CLASS(GLSLConstantIDLayoutModifier , GLSLParsedLayoutModifier)
SIMPLE_SYNTAX_CLASS(GLSLLocationLayoutModifier , GLSLParsedLayoutModifier)
-SIMPLE_SYNTAX_CLASS(GLSLLocalSizeLayoutModifier, GLSLUnparsedLayoutModifier)
-SIMPLE_SYNTAX_CLASS(GLSLLocalSizeXLayoutModifier, GLSLLocalSizeLayoutModifier)
-SIMPLE_SYNTAX_CLASS(GLSLLocalSizeYLayoutModifier, GLSLLocalSizeLayoutModifier)
-SIMPLE_SYNTAX_CLASS(GLSLLocalSizeZLayoutModifier, GLSLLocalSizeLayoutModifier)
-
// A catch-all for single-keyword modifiers
SIMPLE_SYNTAX_CLASS(SimpleModifier, Modifier)
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index fffbf2c62..dad84a4b6 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -4450,6 +4450,8 @@ namespace Slang
{
ModifierListBuilder listBuilder;
+ RefPtr<UncheckedAttribute> numThreadsAttrib;
+
listBuilder.add(new GLSLLayoutModifierGroupBegin());
parser->ReadToken(TokenType::LParent);
@@ -4458,7 +4460,41 @@ namespace Slang
auto nameAndLoc = expectIdentifier(parser);
const String& nameText = nameAndLoc.name->text;
- if (nameText == "binding" ||
+ const char localSizePrefix[] = "local_size_";
+
+ int localSizeIndex = -1;
+ if (nameText.startsWith(localSizePrefix) && nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1)
+ {
+ char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1];
+ localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1;
+ }
+
+ if (localSizeIndex >= 0)
+ {
+ if (!numThreadsAttrib)
+ {
+ numThreadsAttrib = new UncheckedAttribute;
+ numThreadsAttrib->args.setCount(3);
+
+ // Just mark the loc and name from the first in the list
+ numThreadsAttrib->name = getName(parser, "numthreads");
+ numThreadsAttrib->loc = nameAndLoc.loc;
+ numThreadsAttrib->scope = parser->currentScope;
+ }
+
+ if (AdvanceIf(parser, TokenType::OpAssign))
+ {
+ auto expr = parseAtomicExpr(parser);
+ //SLANG_ASSERT(expr);
+ if (!expr)
+ {
+ return nullptr;
+ }
+
+ numThreadsAttrib->args[localSizeIndex] = expr;
+ }
+ }
+ else if (nameText == "binding" ||
nameText == "set")
{
GLSLBindingAttribute* attr = listBuilder.find<GLSLBindingAttribute>();
@@ -4499,9 +4535,6 @@ namespace Slang
CASE(shaderRecordNV, ShaderRecordAttribute)
CASE(constant_id, GLSLConstantIDLayoutModifier)
CASE(location, GLSLLocationLayoutModifier)
- CASE(local_size_x, GLSLLocalSizeXLayoutModifier)
- CASE(local_size_y, GLSLLocalSizeYLayoutModifier)
- CASE(local_size_z, GLSLLocalSizeZLayoutModifier)
{
modifier = new GLSLUnparsedLayoutModifier();
}
@@ -4528,6 +4561,11 @@ namespace Slang
parser->ReadToken(TokenType::Comma);
}
+ if (numThreadsAttrib)
+ {
+ listBuilder.add(numThreadsAttrib);
+ }
+
listBuilder.add(new GLSLLayoutModifierGroupEnd());
return listBuilder.getFirst();
diff --git a/source/slang/slang-reflection.cpp b/source/slang/slang-reflection.cpp
index a6015b31c..f2c4973ed 100644
--- a/source/slang/slang-reflection.cpp
+++ b/source/slang/slang-reflection.cpp
@@ -1267,27 +1267,6 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
sizeAlongAxis[1] = numThreadsAttribute->y;
sizeAlongAxis[2] = numThreadsAttribute->z;
}
- else
- {
- // Fall back to the GLSL case, which requires a search over global-scope declarations
- // to look for as with the `local_size_*` qualifier
- auto module = as<ModuleDecl>(entryPointFunc.getDecl()->ParentDecl);
- if (module)
- {
- for (auto dd : module->Members)
- {
- for (auto mod : dd->GetModifiersOfType<GLSLLocalSizeLayoutModifier>())
- {
- if (auto xMod = as<GLSLLocalSizeXLayoutModifier>(mod))
- sizeAlongAxis[0] = (SlangUInt) getIntegerLiteralValue(xMod->valToken);
- else if (auto yMod = as<GLSLLocalSizeYLayoutModifier>(mod))
- sizeAlongAxis[1] = (SlangUInt) getIntegerLiteralValue(yMod->valToken);
- else if (auto zMod = as<GLSLLocalSizeZLayoutModifier>(mod))
- sizeAlongAxis[2] = (SlangUInt) getIntegerLiteralValue(zMod->valToken);
- }
- }
- }
- }
//