summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-modifier.h2
-rw-r--r--source/slang/slang-check-modifier.cpp6
-rw-r--r--source/slang/slang-ir-insts.h9
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp16
-rw-r--r--tests/bugs/gh-5781.slang57
-rw-r--r--tests/language-feature/constants/max-iters-link-time-const.slang15
7 files changed, 93 insertions, 14 deletions
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index ee9b55334..863ffaef2 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -744,7 +744,7 @@ class MaxItersAttribute : public Attribute
{
SLANG_AST_CLASS(MaxItersAttribute)
- int32_t value = 0;
+ IntVal* value = 0;
};
// An inferred max iteration count on a loop.
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 0794279a7..a1e0f7876 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -720,11 +720,7 @@ Modifier* SemanticsVisitor::validateAttribute(
}
else
{
- auto cint = checkConstantIntVal(attr->args[0]);
- if (cint)
- {
- maxItersAttrs->value = (int32_t)cint->getValue();
- }
+ maxItersAttrs->value = checkLinkTimeConstantIntVal(attr->args[0]);
}
}
else if (const auto userDefAttr = as<UserDefinedAttribute>(attr))
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 4c7272755..829e72575 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4637,9 +4637,14 @@ public:
getIntValue(getIntType(), IRIntegerValue(mode)));
}
- void addLoopMaxItersDecoration(IRInst* value, IntegerLiteralValue iters)
+ void addLoopMaxItersDecoration(IRInst* value, IRIntegerValue iters)
{
- addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(getIntType(), iters));
+ addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(iters));
+ }
+
+ void addLoopMaxItersDecoration(IRInst* value, IRInst* iters)
+ {
+ addDecoration(value, kIROp_LoopMaxItersDecoration, iters);
}
void addLoopForceUnrollDecoration(IRInst* value, IntegerLiteralValue iters)
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f6c662a98..29fbcc3c9 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3890,6 +3890,8 @@ enum class TypeCastStyle
};
static TypeCastStyle _getTypeStyleId(IRType* type)
{
+ type = (IRType*)unwrapAttributedType(type);
+
if (auto vectorType = as<IRVectorType>(type))
{
return _getTypeStyleId(vectorType->getElementType());
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 75cf421af..ce6f8cb42 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -5938,7 +5938,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
if (auto maxItersAttr = stmt->findModifier<MaxItersAttribute>())
{
- getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value);
+ auto iters = lowerVal(context, maxItersAttr->value);
+ getBuilder()->addLoopMaxItersDecoration(inst, getSimpleVal(context, iters));
}
else if (auto inferredMaxItersAttr = stmt->findModifier<InferredMaxItersAttribute>())
{
@@ -6028,12 +6029,15 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
{
if (auto maxIters = stmt->findModifier<MaxItersAttribute>())
{
- if (inferredMaxIters->value < maxIters->value)
+ if (auto constIntVal = as<ConstantIntVal>(maxIters->value))
{
- context->getSink()->diagnose(
- maxIters,
- Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters,
- inferredMaxIters->value);
+ if (inferredMaxIters->value < constIntVal->getValue())
+ {
+ context->getSink()->diagnose(
+ maxIters,
+ Diagnostics::forLoopTerminatesInFewerIterationsThanMaxIters,
+ inferredMaxIters->value);
+ }
}
}
}
diff --git a/tests/bugs/gh-5781.slang b/tests/bugs/gh-5781.slang
new file mode 100644
index 000000000..33456f500
--- /dev/null
+++ b/tests/bugs/gh-5781.slang
@@ -0,0 +1,57 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+// CHECK: OpEntryPoint
+
+module test;
+
+public enum class MaterialID : uint { invalid = 0xffffffff };
+
+public struct Material : IDifferentiable
+{
+ float x;
+}
+
+public struct Hit
+{
+ MaterialID material;
+}
+
+public struct Scene
+{
+ StructuredBuffer<Material> materials;
+ RWStructuredBuffer<Material> grads;
+
+ [Differentiable]
+ Material load(MaterialID id) { return materials[uint(id)]; }
+
+ void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; }
+
+ [Differentiable, BackwardDerivative(_get_material_bwd)]
+ public Material get_material(MaterialID id) { return load(id); }
+
+ public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); }
+
+ [Differentiable]
+ public Material get_material(Hit hit) { return get_material(hit.material); }
+}
+
+[Differentiable]
+float trace(const Scene scene, Hit hit)
+{
+ Material m = scene.get_material(hit);
+ return m.x;
+}
+
+
+[shader("compute")]
+void main(
+ uniform Scene scene,
+ uniform StructuredBuffer<uint> input,
+ uniform RWStructuredBuffer<float> output,
+ uniform RWStructuredBuffer<float> grads
+)
+{
+ Hit hit;
+ hit.material = MaterialID(input[0]);
+ output[0] = trace(scene, hit);
+ bwd_diff(trace)(scene, hit, grads[0]);
+} \ No newline at end of file
diff --git a/tests/language-feature/constants/max-iters-link-time-const.slang b/tests/language-feature/constants/max-iters-link-time-const.slang
new file mode 100644
index 000000000..cf1ccbbd1
--- /dev/null
+++ b/tests/language-feature/constants/max-iters-link-time-const.slang
@@ -0,0 +1,15 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+// CHECK: OpEntryPoint
+
+extern static const int num = 10;
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ [MaxIters(num)]
+ for (int i = 0; i < num; i++)
+ {
+ outputBuffer[i] = i;
+ }
+}