summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-specialize.cpp92
-rw-r--r--tests/bugs/gh-6589.slang63
2 files changed, 116 insertions, 39 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index a200a907b..2f51b28a2 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -1787,32 +1787,38 @@ struct SpecializationContext
// created.
//
auto valType = val->getFullType();
- if (auto extractExistentialType = as<IRExtractExistentialType>(valType))
+ if (isCompileTimeConstantType(valType) &&
+ isCompileTimeConstantType(oldParam->getFullType()))
{
- valType = extractExistentialType->getOperand(0)->getDataType();
- auto newParam = builder->createParam(valType);
- newParams.add(newParam);
- replacementVal = newParam;
- }
- else
- {
- auto newParam = builder->createParam(valType);
- newParams.add(newParam);
+ if (auto extractExistentialType = as<IRExtractExistentialType>(valType))
+ {
+ valType = extractExistentialType->getOperand(0)->getDataType();
+ auto newParam = builder->createParam(valType);
+ newParams.add(newParam);
+ replacementVal = newParam;
+ }
+ else
+ {
+ auto newParam = builder->createParam(valType);
+ newParams.add(newParam);
- // Within the body of the function we cannot just use `val`
- // directly, because the existing code expects an existential
- // value, including its witness table.
- //
- // Therefore we will create a `makeExistential(newParam, witnessTable)`
- // in the body of the new function and use *that* as the replacement
- // value for the original parameter (since it will have the
- // correct existential type, and stores the right witness table).
- //
- auto newMakeExistential = builder->emitMakeExistential(
- oldParam->getFullType(),
- newParam,
- witnessTable);
- replacementVal = newMakeExistential;
+ // Within the body of the function we cannot just use `val`
+ // directly, because the existing code expects an existential
+ // value, including its witness table.
+ //
+ // Therefore we will create a `makeExistential(newParam, witnessTable)`
+ // in the body of the new function and use *that* as the replacement
+ // value for the original parameter (since it will have the
+ // correct existential type, and stores the right witness table).
+ //
+ auto newMakeExistential = builder->emitMakeExistential(
+ oldParam->getFullType(),
+ newParam,
+ witnessTable);
+ replacementVal = newMakeExistential;
+ }
+ cloneEnv.mapOldValToNew.add(oldParam, replacementVal);
+ continue;
}
}
else if (auto oldWrapExistential = as<IRWrapExistential>(arg))
@@ -1837,24 +1843,32 @@ struct SpecializationContext
newParam,
oldWrapExistential->getSlotOperandCount(),
oldWrapExistential->getSlotOperands());
- replacementVal = newWrapExistential;
- }
- else
- {
- // For parameters that don't have an existential type,
- // there is nothing interesting to do. The new function
- // will also have a parameter of the exact same type,
- // and we'll use that instead of the original parameter.
- //
- auto newParam = builder->createParam(oldParam->getFullType());
- newParams.add(newParam);
- replacementVal = newParam;
+ cloneEnv.mapOldValToNew.add(oldParam, newWrapExistential);
+ continue;
}
- // Whatever replacement value was constructed, we need to
- // register it as the replacement for the original parameter.
+ // If we get here, then either the parameter is not of an existential type,
+ // or the argument/parameter is not specialized yet.
+ //
+ // For the first case, there is nothing interesting to do. The new function
+ // will also have a parameter of the exact same type, and we'll use that
+ // instead of the original parameter.
+ //
+ //
+ // For the second case, if the argument/parameter is not specialized yet, we don't
+ // aggressively specialize the parameter.
+ //
+ // If we specialize the parameter type too early, we will lose the opportunity
+ // to specialize the callee later. The principle is to always let the
+ // specialization happen at the same time for both the argument and the parameter.
//
- cloneEnv.mapOldValToNew.add(oldParam, replacementVal);
+ // Note that we should not use `createParam` here, because that call won't assign a
+ // parent to the new parameter; as a result, during the cloning process, some
+ // existential-related IR insts could be hoisted to the global scope, which is
+ // unexpected. Instead, we should use `cloneInst` here, so that the new
+ // parameter will be inserted into the function scope.
+ auto newParam = (IRParam*)cloneInst(&cloneEnv, builder, oldParam);
+ newParams.add(newParam);
}
// The above steps have accomplished the "first phase"
diff --git a/tests/bugs/gh-6589.slang b/tests/bugs/gh-6589.slang
new file mode 100644
index 000000000..9433f510e
--- /dev/null
+++ b/tests/bugs/gh-6589.slang
@@ -0,0 +1,63 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+
+// This test checks that when we apply partial specialization to a function,
+// we don't specialize the function parameters too aggressively. Instead, we specialize
+// the parameters at the same time as we specialize the arguments. Otherwise, we could lose
+// the chance to specialize the argument.
+//
+// In this test, `matrix_vector_interfaces` will be fully specialized; otherwise compilation
+// will fail because we don't allow opaque types in an existential type. So as long as the target
+// SPIR-V code can be generated, we are good.
+
+// CHECK: %main
+public interface ITensor<T : IDifferentiable, let D : int>
+{
+ public T get(int idx);
+
+}
+
+public interface IRWTensor<T : IDifferentiable, let D : int> : ITensor<T, D>
+{
+}
+
+
+public struct RWTensor<T : IDifferentiable, let D : int> : IRWTensor<T, D>
+{
+ public RWStructuredBuffer<T> buffer;
+ public T get(int idx) { return buffer[idx]; }
+}
+
+public struct GradInOutTensor<T : IDifferentiable, let D : int> : IRWTensor<T, D>
+{
+ public RWTensor<T, D> primal;
+ public T get(int idx) { return primal.get(idx); }
+}
+
+struct CallData
+{
+ GradInOutTensor<float, 3> weights;
+ GradInOutTensor<float, 2> biases;
+ RWStructuredBuffer<float> _result;
+}
+ParameterBlock<CallData> call_data;
+
+float matrix_vector_interfaces(ITensor<float, 2> weights, ITensor<float, 1> biases)
+{
+ return weights.get(0);
+}
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void main(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ float _result;
+ GradInOutTensor<float, 2> weights;
+ GradInOutTensor<float, 1> biases;
+
+ weights.primal.buffer = call_data.weights.primal.buffer;
+ biases.primal.buffer = call_data.biases.primal.buffer;
+
+ _result = matrix_vector_interfaces(weights, biases);
+
+ call_data._result[0] = _result;
+}