summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-26 19:42:15 -0700
committerGitHub <noreply@github.com>2024-07-26 19:42:15 -0700
commit7e2bc8e06f61d554bae9bbebc1db0302eb3f1d8a (patch)
tree0f10e4a45cb81af2908da61743a4518de27748e2 /source
parentc0bff66541302309ff4833e8d4ae2eba1561498a (diff)
Allow passing sized array to unsized array parameter. (#4744)
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang5
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-check-conversion.cpp25
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-emit.cpp14
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp3
-rw-r--r--source/slang/slang-ir-check-unsupported-inst.cpp27
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-peephole.cpp8
-rw-r--r--source/slang/slang-ir-specialize-arrays.cpp47
-rw-r--r--source/slang/slang-ir-specialize-function-call.cpp75
-rw-r--r--source/slang/slang-ir-specialize-function-call.h2
-rw-r--r--source/slang/slang-ir.cpp1
13 files changed, 195 insertions, 16 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index c75d4735b..7d4f303ef 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1040,8 +1040,8 @@ __generic<T, let N:int>
__magic_type(ArrayExpressionType)
struct Array : IArray<T>
{
- [ForceInline]
- int getCount() { return N; }
+ __intrinsic_op($(kIROp_GetArrayLength))
+ int getCount();
__subscript(int index) -> T
{
@@ -1049,7 +1049,6 @@ struct Array : IArray<T>
get;
}
}
-
/// An `N` component vector with elements of type `T`.
__generic<T = float, let N : int = 4>
__magic_type(VectorExpressionType)
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index f709345d5..e94174570 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -96,6 +96,7 @@ namespace Slang
kConversionCost_GenericParamUpcast = 1,
kConversionCost_UnconstraintGenericParam = 20,
+ kConversionCost_SizedArrayToUnsizedArray = 30,
// Convert between matrices of different layout
kConversionCost_MatrixLayout = 5,
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 00625d5f4..fafefa9dd 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -750,6 +750,31 @@ namespace Slang
return true;
}
+ // Allow implicit conversion from sized array to unsized array when
+ // calling a function.
+ // Note: we implement the logic here instead of an implicit_conversion
+ // intrinsic in the stdlib because we only want to allow this conversion
+ // when calling a function.
+ //
+ if (site == CoercionSite::Argument)
+ {
+ if (auto fromArrayType = as<ArrayExpressionType>(fromType))
+ {
+ if (auto toArrayType = as<ArrayExpressionType>(toType))
+ {
+ if (fromArrayType->getElementType()->equals(toArrayType->getElementType())
+ && toArrayType->isUnsized())
+ {
+ if (outToExpr)
+ *outToExpr = fromExpr;
+ if (outCost)
+ *outCost = kConversionCost_SizedArrayToUnsizedArray;
+ return true;
+ }
+ }
+ }
+ }
+
// Another important case is when either the "to" or "from" type
// represents an error. In such a case we must have already
// reporeted the error, so it is better to allow the conversion
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index fd7bd4a65..fe1a25b39 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -857,6 +857,7 @@ DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$
DIAGNOSTIC(55202, Error, systemValueAttributeNotSupported, "system value semantic '$0' is not supported for the current target.")
DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or be convertible to type '$1'.")
DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'")
+DIAGNOSTIC(56002, Error, attemptToQuerySizeOfUnsizedArray, "cannot obtain the size of an unsized array.")
DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 03d1b932c..8f7b5f66f 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -947,13 +947,13 @@ Result linkAndOptimizeIR(
specializeResourceUsage(codeGenContext, irModule);
specializeFuncsForBufferLoadArgs(codeGenContext, irModule);
- // For GLSL targets, we also want to specialize calls to functions that
- // takes array parameters if possible, to avoid performance issues on
- // those platforms.
- if (isKhronosTarget(targetRequest))
- {
- specializeArrayParameters(codeGenContext, irModule);
- }
+ // We also want to specialize calls to functions that
+ // takes unsized array parameters if possible.
+ // Moreover, for Khronos targets, we also want to specialize calls to functions
+ // that takes arrays/structs containing arrays as parameters with the actual
+ // global array object to avoid loading big arrays into SSA registers, which seems
+ // to cause performance issues.
+ specializeArrayParameters(codeGenContext, irModule);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "AFTER RESOURCE SPECIALIZATION");
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 492ff0ca1..0a9b2d691 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1943,6 +1943,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_DebugLine:
case kIROp_DebugVar:
case kIROp_DebugValue:
+ case kIROp_GetArrayLength:
+ case kIROp_SizeOf:
+ case kIROp_AlignOf:
return transcribeNonDiffInst(builder, origInst);
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
diff --git a/source/slang/slang-ir-check-unsupported-inst.cpp b/source/slang/slang-ir-check-unsupported-inst.cpp
index c89928af5..2a1c0e325 100644
--- a/source/slang/slang-ir-check-unsupported-inst.cpp
+++ b/source/slang/slang-ir-check-unsupported-inst.cpp
@@ -42,6 +42,23 @@ namespace Slang
}
}
+ void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* sink)
+ {
+ SLANG_UNUSED(target);
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_GetArrayLength:
+ sink->diagnose(inst, Diagnostics::attemptToQuerySizeOfUnsizedArray);
+ break;
+ }
+ }
+ }
+ }
+
void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
{
HashSet<IRFunc*> checkedFuncsForRecursionDetection;
@@ -62,6 +79,16 @@ namespace Slang
case kIROp_Func:
if (!isCPUTarget(target))
checkRecursion(checkedFuncsForRecursionDetection, as<IRFunc>(globalInst), sink);
+ checkUnsupportedInst(target, as<IRFunc>(globalInst), sink);
+ break;
+ case kIROp_Generic:
+ {
+ auto generic = as<IRGeneric>(globalInst);
+ auto innerFunc = as<IRFunc>(findGenericReturnVal(generic));
+ if (innerFunc)
+ checkUnsupportedInst(target, innerFunc, sink);
+ break;
+ }
default:
break;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 74cb534f9..07694b066 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1108,7 +1108,7 @@ INST(TreatAsDynamicUniform, TreatAsDynamicUniform, 1, 0)
INST(SizeOf, sizeOf, 1, 0)
INST(AlignOf, alignOf, 1, 0)
-
+INST(GetArrayLength, GetArrayLength, 1, 0)
INST(IsType, IsType, 3, 0)
INST(TypeEquals, TypeEquals, 2, 0)
INST(IsInt, IsInt, 1, 0)
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index db867fe7d..232633d69 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -283,6 +283,14 @@ struct PeepholeContext : InstPassBase
changed = true;
}
break;
+ case kIROp_GetArrayLength:
+ if (auto arrayType = as<IRArrayType>(inst->getOperand(0)->getDataType()))
+ {
+ inst->replaceUsesWith(arrayType->getElementCount());
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ break;
case kIROp_GetResultError:
if (inst->getOperand(0)->getOp() == kIROp_MakeResultError)
{
diff --git a/source/slang/slang-ir-specialize-arrays.cpp b/source/slang/slang-ir-specialize-arrays.cpp
index 3f42fb4b0..9bc0042db 100644
--- a/source/slang/slang-ir-specialize-arrays.cpp
+++ b/source/slang/slang-ir-specialize-arrays.cpp
@@ -13,8 +13,11 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
// This pass is intended to specialize functions
// with struct parameters that has array fields
// to avoid performance problems for GLSL targets.
-
// Returns true if `type` is an `IRStructType` with array-typed fields.
+ // It will also specialize functions with unsized array parameters into
+ // sized arrays, if the function is called with an argument that has a
+ // sized array type.
+ //
bool isStructTypeWithArray(IRType* type)
{
if (auto structType = as<IRStructType>(type))
@@ -38,8 +41,47 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
bool doesParamWantSpecialization(IRParam* param, IRInst* arg)
{
SLANG_UNUSED(arg);
- return isStructTypeWithArray(param->getDataType());
+ if (isKhronosTarget(codeGenContext->getTargetReq()))
+ return isStructTypeWithArray(param->getDataType());
+ return false;
}
+
+ bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg)
+ {
+ auto paramType = param->getDataType();
+ auto argType = arg->getDataType();
+ if (auto outTypeBase = as<IROutTypeBase>(paramType))
+ {
+ paramType = outTypeBase->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ else if (auto refType = as<IRRefType>(paramType))
+ {
+ paramType = refType->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ else if (auto constRefType = as<IRConstRefType>(paramType))
+ {
+ paramType = constRefType->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ auto arrayType = as<IRUnsizedArrayType>(paramType);
+ if (!arrayType)
+ return false;
+ auto argArrayType = as<IRArrayType>(argType);
+ if (!argArrayType)
+ return false;
+ if (as<IRIntLit>(argArrayType->getElementCount()))
+ {
+ return true;
+ }
+ return false;
+ }
+
+ CodeGenContext* codeGenContext = nullptr;
};
void specializeArrayParameters(
@@ -47,6 +89,7 @@ void specializeArrayParameters(
IRModule* module)
{
ArrayParameterSpecializationCondition condition;
+ condition.codeGenContext = codeGenContext;
specializeFunctionCalls(codeGenContext, module, &condition);
}
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp
index c4928a230..a41ca1e99 100644
--- a/source/slang/slang-ir-specialize-function-call.cpp
+++ b/source/slang/slang-ir-specialize-function-call.cpp
@@ -67,6 +67,13 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam*
}
}
+bool FunctionCallSpecializeCondition::doesParamTypeWantSpecialization(IRParam* param, IRInst* arg)
+{
+ SLANG_UNUSED(param);
+ SLANG_UNUSED(arg);
+ return false;
+}
+
struct FunctionParameterSpecializationContext
{
// This type implements a pass to specialize functions
@@ -209,7 +216,9 @@ struct FunctionParameterSpecializationContext
// If neither the parameter nor the argument wants specialization,
// then we need to keep looking.
//
- if(!doesParamWantSpecialization(param, arg))
+ auto paramWantSpecialization = doesParamWantSpecialization(param, arg);
+ auto paramTypeWantSpecialization = doesParamTypeWantSpecialization(param, arg);
+ if(!paramWantSpecialization && !paramTypeWantSpecialization)
continue;
// If we have run into a `param` or `arg` that wants specialization,
@@ -222,7 +231,7 @@ struct FunctionParameterSpecializationContext
// can bail out immediately because our second condition
// cannot be met.
//
- if(!isParamSuitableForSpecialization(param, arg))
+ if(paramWantSpecialization && !isParamSuitableForSpecialization(param, arg))
return false;
}
@@ -242,6 +251,11 @@ struct FunctionParameterSpecializationContext
return condition->doesParamWantSpecialization(param, arg);
}
+ bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg)
+ {
+ return condition->doesParamTypeWantSpecialization(param, arg);
+ }
+
bool isParamSuitableForSpecialization(IRParam* param, IRInst* arg)
{
return condition->isParamSuitableForSpecialization(param, arg);
@@ -474,6 +488,11 @@ struct FunctionParameterSpecializationContext
// specialized callee based on this paramter.
//
ioInfo.newArgs.add(oldArg);
+
+ if (doesParamTypeWantSpecialization(oldParam, oldArg))
+ {
+ ioInfo.key.vals.add(oldArg->getDataType());
+ }
}
else
{
@@ -587,6 +606,30 @@ struct FunctionParameterSpecializationContext
}
}
+ // Wrap `argType` with a parameter direction type if `oldParam` has such a parameter direction type.
+ IRType* maybeWrapParameterDirectionType(IRParam* oldParam, IRType* argType)
+ {
+ IRType* paramType = oldParam->getDataType();
+ IRType* resultType = argType;
+ switch (paramType->getOp())
+ {
+ case kIROp_InOutType:
+ case kIROp_OutType:
+ case kIROp_RefType:
+ case kIROp_ConstRefType:
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ resultType = getBuilder()->getPtrType(paramType->getOp(), argType, as<IRPtrTypeBase>(paramType)->getAddressSpace());
+ break;
+ }
+ if (auto rate = paramType->getRate())
+ {
+ IRBuilder builder(oldParam);
+ builder.setInsertAfter(resultType);
+ resultType = builder.getRateQualifiedType(rate, resultType);
+ }
+ return resultType;
+ }
+
IRInst* getSpecializedValueForParam(
FuncSpecializationInfo& ioInfo,
IRParam* oldParam,
@@ -601,7 +644,16 @@ struct FunctionParameterSpecializationContext
// that fills the same role as the old one, so we
// create it here.
//
- auto newParam = getBuilder()->createParam(oldParam->getFullType());
+ IRType* paramType = nullptr;
+ if (doesParamTypeWantSpecialization(oldParam, oldArg))
+ {
+ paramType = maybeWrapParameterDirectionType(oldParam, oldArg->getDataType());
+ }
+ else
+ {
+ paramType = oldParam->getFullType();
+ }
+ auto newParam = getBuilder()->createParam(paramType);
ioInfo.newParams.add(newParam);
// The new parameter will be used as the replacement
@@ -891,6 +943,23 @@ struct FunctionParameterSpecializationContext
//
addCallsToWorkListRec(newFunc);
+ // If one of the new parameters has a more specialized type,
+ // we need to update the type of load instructions from that
+ // parameter, if there are any.
+ for (auto newParam : funcInfo.newParams)
+ {
+ if (!as<IRParam>(newParam))
+ continue;
+ for (auto use = newParam->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (auto load = as<IRLoad>(user))
+ {
+ load->setFullType(as<IRPtrTypeBase>(newParam->getDataType())->getValueType());
+ }
+ }
+ }
+
simplifyFunc(codeGenContext->getTargetProgram(), newFunc, IRSimplificationOptions::getFast(codeGenContext->getTargetProgram()));
return newFunc;
diff --git a/source/slang/slang-ir-specialize-function-call.h b/source/slang/slang-ir-specialize-function-call.h
index 8fe113a98..092a5158d 100644
--- a/source/slang/slang-ir-specialize-function-call.h
+++ b/source/slang/slang-ir-specialize-function-call.h
@@ -15,6 +15,8 @@ namespace Slang
virtual bool doesParamWantSpecialization(IRParam* param, IRInst* arg) = 0;
virtual bool isParamSuitableForSpecialization(IRParam* param, IRInst* arg);
+
+ virtual bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg);
};
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index bb05a1c29..88065cedc 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -8250,6 +8250,7 @@ namespace Slang
case kIROp_TorchTensorGetView:
case kIROp_GetStringHash:
case kIROp_AllocateOpaqueHandle:
+ case kIROp_GetArrayLength:
return false;
case kIROp_ForwardDifferentiate: