diff options
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 92 | ||||
| -rw-r--r-- | tests/bugs/gh-6589.slang | 63 |
2 files changed, 116 insertions, 39 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index a200a907b..2f51b28a2 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1787,32 +1787,38 @@ struct SpecializationContext // created. // auto valType = val->getFullType(); - if (auto extractExistentialType = as<IRExtractExistentialType>(valType)) + if (isCompileTimeConstantType(valType) && + isCompileTimeConstantType(oldParam->getFullType())) { - valType = extractExistentialType->getOperand(0)->getDataType(); - auto newParam = builder->createParam(valType); - newParams.add(newParam); - replacementVal = newParam; - } - else - { - auto newParam = builder->createParam(valType); - newParams.add(newParam); + if (auto extractExistentialType = as<IRExtractExistentialType>(valType)) + { + valType = extractExistentialType->getOperand(0)->getDataType(); + auto newParam = builder->createParam(valType); + newParams.add(newParam); + replacementVal = newParam; + } + else + { + auto newParam = builder->createParam(valType); + newParams.add(newParam); - // Within the body of the function we cannot just use `val` - // directly, because the existing code expects an existential - // value, including its witness table. - // - // Therefore we will create a `makeExistential(newParam, witnessTable)` - // in the body of the new function and use *that* as the replacement - // value for the original parameter (since it will have the - // correct existential type, and stores the right witness table). - // - auto newMakeExistential = builder->emitMakeExistential( - oldParam->getFullType(), - newParam, - witnessTable); - replacementVal = newMakeExistential; + // Within the body of the function we cannot just use `val` + // directly, because the existing code expects an existential + // value, including its witness table. 
+ // + // Therefore we will create a `makeExistential(newParam, witnessTable)` + // in the body of the new function and use *that* as the replacement + // value for the original parameter (since it will have the + // correct existential type, and stores the right witness table). + // + auto newMakeExistential = builder->emitMakeExistential( + oldParam->getFullType(), + newParam, + witnessTable); + replacementVal = newMakeExistential; + } + cloneEnv.mapOldValToNew.add(oldParam, replacementVal); + continue; } } else if (auto oldWrapExistential = as<IRWrapExistential>(arg)) @@ -1837,24 +1843,32 @@ struct SpecializationContext newParam, oldWrapExistential->getSlotOperandCount(), oldWrapExistential->getSlotOperands()); - replacementVal = newWrapExistential; - } - else - { - // For parameters that don't have an existential type, - // there is nothing interesting to do. The new function - // will also have a parameter of the exact same type, - // and we'll use that instead of the original parameter. - // - auto newParam = builder->createParam(oldParam->getFullType()); - newParams.add(newParam); - replacementVal = newParam; + cloneEnv.mapOldValToNew.add(oldParam, newWrapExistential); + continue; } - // Whatever replacement value was constructed, we need to - // register it as the replacement for the original parameter. + // If we get here, then the parameter is either not an existential type, + // or the argument/parameter is not specialized yet. + // + // For the first case, there is nothing interesting to do. The new function + // will also have a parameter of the exact same type, and we'll use that + // instead of the original parameter. + // + // + // For the second case, if the argument/parameter is not specialized yet, don't + // aggressively specialize the parameter. + // + // If we specialize the parameter type too early, we will lose the opportunity + // to specialize the callee later. 
The principle is to always let the + // specialization happen at the same time for both the argument and the parameter. // - cloneEnv.mapOldValToNew.add(oldParam, replacementVal); + // Note we should not use `createParam` here, because this call won't assign the + // parent to the new parameter, therefore during the cloning process, some + // existential related IR inst could be hoisted to the global scope, which is + // unexpected. Instead, we should use cloneInst here, such that the new + // parameter will be inserted into the function scope. + auto newParam = (IRParam*)cloneInst(&cloneEnv, builder, oldParam); + newParams.add(newParam); } // The above steps have accomplished the "first phase" diff --git a/tests/bugs/gh-6589.slang b/tests/bugs/gh-6589.slang new file mode 100644 index 000000000..9433f510e --- /dev/null +++ b/tests/bugs/gh-6589.slang @@ -0,0 +1,63 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv + +// This is a test that checks that we can apply partial specialization to a function +// without specializing the function parameters too aggressively. Instead, we will specialize +// the parameters at the same time as specializing the arguments. Otherwise, we could lose +// the chance to specialize the argument. +// +// In this test, `matrix_vector_interfaces` will be fully specialized, otherwise compilation +// will fail because we don't allow opaque types in an existential type. So as long as the target +// spirv code can be generated, we are good. 
+ +// CHECK: %main +public interface ITensor<T : IDifferentiable, let D : int> +{ + public T get(int idx); + +} + +public interface IRWTensor<T : IDifferentiable, let D : int> : ITensor<T, D> +{ +} + + +public struct RWTensor<T : IDifferentiable, let D : int> : IRWTensor<T, D> +{ + public RWStructuredBuffer<T> buffer; + public T get(int idx) { return buffer[idx]; } +} + +public struct GradInOutTensor<T : IDifferentiable, let D : int> : IRWTensor<T, D> +{ + public RWTensor<T, D> primal; + public T get(int idx) { return primal.get(idx); } +} + +struct CallData +{ + GradInOutTensor<float, 3> weights; + GradInOutTensor<float, 2> biases; + RWStructuredBuffer<float> _result; +} +ParameterBlock<CallData> call_data; + +float matrix_vector_interfaces(ITensor<float, 2> weights, ITensor<float, 1> biases) +{ + return weights.get(0); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void main(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + float _result; + GradInOutTensor<float, 2> weights; + GradInOutTensor<float, 1> biases; + + weights.primal.buffer = call_data.weights.primal.buffer; + biases.primal.buffer = call_data.biases.primal.buffer; + + _result = matrix_vector_interfaces(weights, biases); + + call_data._result[0] = _result; +} |
