diff options
| author | Yong He <yonghe@outlook.com> | 2024-07-26 19:42:15 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-26 19:42:15 -0700 |
| commit | 7e2bc8e06f61d554bae9bbebc1db0302eb3f1d8a (patch) | |
| tree | 0f10e4a45cb81af2908da61743a4518de27748e2 /source | |
| parent | c0bff66541302309ff4833e8d4ae2eba1561498a (diff) | |
Allow passing sized array to unsized array parameter. (#4744)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-unsupported-inst.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-arrays.cpp | 47 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-function-call.cpp | 75 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-function-call.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 1 |
13 files changed, 195 insertions, 16 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index c75d4735b..7d4f303ef 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1040,8 +1040,8 @@ __generic<T, let N:int> __magic_type(ArrayExpressionType) struct Array : IArray<T> { - [ForceInline] - int getCount() { return N; } + __intrinsic_op($(kIROp_GetArrayLength)) + int getCount(); __subscript(int index) -> T { @@ -1049,7 +1049,6 @@ struct Array : IArray<T> get; } } - /// An `N` component vector with elements of type `T`. __generic<T = float, let N : int = 4> __magic_type(VectorExpressionType) diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index f709345d5..e94174570 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -96,6 +96,7 @@ namespace Slang kConversionCost_GenericParamUpcast = 1, kConversionCost_UnconstraintGenericParam = 20, + kConversionCost_SizedArrayToUnsizedArray = 30, // Convert between matrices of different layout kConversionCost_MatrixLayout = 5, diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 00625d5f4..fafefa9dd 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -750,6 +750,31 @@ namespace Slang return true; } + // Allow implicit conversion from sized array to unsized array when + // calling a function. + // Note: we implement the logic here instead of an implicit_conversion + // intrinsic in the stdlib because we only want to allow this conversion + // when calling a function. + // + if (site == CoercionSite::Argument) + { + if (auto fromArrayType = as<ArrayExpressionType>(fromType)) + { + if (auto toArrayType = as<ArrayExpressionType>(toType)) + { + if (fromArrayType->getElementType()->equals(toArrayType->getElementType()) + && toArrayType->isUnsized()) + { + if (outToExpr) + *outToExpr = fromExpr; + if (outCost) + *outCost = kConversionCost_SizedArrayToUnsizedArray; + return true; + } + } + } + } + // Another important case is when either the "to" or "from" type // represents an error. In such a case we must have already // reporeted the error, so it is better to allow the conversion diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index fd7bd4a65..fe1a25b39 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -857,6 +857,7 @@ DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$ DIAGNOSTIC(55202, Error, systemValueAttributeNotSupported, "system value semantic '$0' is not supported for the current target.") DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or be convertible to type '$1'.") DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'") +DIAGNOSTIC(56002, Error, attemptToQuerySizeOfUnsizedArray, "cannot obtain the size of an unsized array.") DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0") DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 03d1b932c..8f7b5f66f 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -947,13 +947,13 @@ Result linkAndOptimizeIR( specializeResourceUsage(codeGenContext, irModule); specializeFuncsForBufferLoadArgs(codeGenContext, irModule); - // For GLSL targets, we also want to specialize calls to functions that - // takes array parameters if possible, to avoid performance issues on - // those platforms. - if (isKhronosTarget(targetRequest)) - { - specializeArrayParameters(codeGenContext, irModule); - } + // We also want to specialize calls to functions that + // takes unsized array parameters if possible. + // Moreover, for Khronos targets, we also want to specialize calls to functions + // that takes arrays/structs containing arrays as parameters with the actual + // global array object to avoid loading big arrays into SSA registers, which seems + // to cause performance issues. + specializeArrayParameters(codeGenContext, irModule); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "AFTER RESOURCE SPECIALIZATION"); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 492ff0ca1..0a9b2d691 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1943,6 +1943,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_DebugLine: case kIROp_DebugVar: case kIROp_DebugValue: + case kIROp_GetArrayLength: + case kIROp_SizeOf: + case kIROp_AlignOf: return transcribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, diff --git a/source/slang/slang-ir-check-unsupported-inst.cpp b/source/slang/slang-ir-check-unsupported-inst.cpp index c89928af5..2a1c0e325 100644 --- a/source/slang/slang-ir-check-unsupported-inst.cpp +++ b/source/slang/slang-ir-check-unsupported-inst.cpp @@ -42,6 +42,23 @@ namespace Slang } } + void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* sink) + { + SLANG_UNUSED(target); + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + switch (inst->getOp()) + { + case kIROp_GetArrayLength: + sink->diagnose(inst, Diagnostics::attemptToQuerySizeOfUnsizedArray); + break; + } + } + } + } + void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink) { HashSet<IRFunc*> checkedFuncsForRecursionDetection; @@ -62,6 +79,16 @@ namespace Slang case kIROp_Func: if (!isCPUTarget(target)) checkRecursion(checkedFuncsForRecursionDetection, as<IRFunc>(globalInst), sink); + checkUnsupportedInst(target, as<IRFunc>(globalInst), sink); + break; + case kIROp_Generic: + { + auto generic = as<IRGeneric>(globalInst); + auto innerFunc = as<IRFunc>(findGenericReturnVal(generic)); + if (innerFunc) + checkUnsupportedInst(target, innerFunc, sink); + break; + } default: break; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 74cb534f9..07694b066 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1108,7 +1108,7 @@ INST(TreatAsDynamicUniform, TreatAsDynamicUniform, 1, 0) INST(SizeOf, sizeOf, 1, 0) INST(AlignOf, alignOf, 1, 0) - +INST(GetArrayLength, GetArrayLength, 1, 0) INST(IsType, IsType, 3, 0) INST(TypeEquals, TypeEquals, 2, 0) INST(IsInt, IsInt, 1, 0) diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index db867fe7d..232633d69 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -283,6 +283,14 @@ struct PeepholeContext : InstPassBase changed = true; } break; + case kIROp_GetArrayLength: + if (auto arrayType = as<IRArrayType>(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(arrayType->getElementCount()); + maybeRemoveOldInst(inst); + changed = true; + } + break; case kIROp_GetResultError: if (inst->getOperand(0)->getOp() == kIROp_MakeResultError) { diff --git a/source/slang/slang-ir-specialize-arrays.cpp b/source/slang/slang-ir-specialize-arrays.cpp index 3f42fb4b0..9bc0042db 100644 --- a/source/slang/slang-ir-specialize-arrays.cpp +++ b/source/slang/slang-ir-specialize-arrays.cpp @@ -13,8 +13,11 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition // This pass is intended to specialize functions // with struct parameters that has array fields // to avoid performance problems for GLSL targets. - // Returns true if `type` is an `IRStructType` with array-typed fields. + // It will also specialize functions with unsized array parameters into + // sized arrays, if the function is called with an argument that has a + // sized array type. + // bool isStructTypeWithArray(IRType* type) { if (auto structType = as<IRStructType>(type)) @@ -38,8 +41,47 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition bool doesParamWantSpecialization(IRParam* param, IRInst* arg) { SLANG_UNUSED(arg); - return isStructTypeWithArray(param->getDataType()); + if (isKhronosTarget(codeGenContext->getTargetReq())) + return isStructTypeWithArray(param->getDataType()); + return false; } + + bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg) + { + auto paramType = param->getDataType(); + auto argType = arg->getDataType(); + if (auto outTypeBase = as<IROutTypeBase>(paramType)) + { + paramType = outTypeBase->getValueType(); + SLANG_ASSERT(as<IRPtrTypeBase>(argType)); + argType = as<IRPtrTypeBase>(argType)->getValueType(); + } + else if (auto refType = as<IRRefType>(paramType)) + { + paramType = refType->getValueType(); + SLANG_ASSERT(as<IRPtrTypeBase>(argType)); + argType = as<IRPtrTypeBase>(argType)->getValueType(); + } + else if (auto constRefType = as<IRConstRefType>(paramType)) + { + paramType = constRefType->getValueType(); + SLANG_ASSERT(as<IRPtrTypeBase>(argType)); + argType = as<IRPtrTypeBase>(argType)->getValueType(); + } + auto arrayType = as<IRUnsizedArrayType>(paramType); + if (!arrayType) + return false; + auto argArrayType = as<IRArrayType>(argType); + if (!argArrayType) + return false; + if (as<IRIntLit>(argArrayType->getElementCount())) + { + return true; + } + return false; + } + + CodeGenContext* codeGenContext = nullptr; }; void specializeArrayParameters( @@ -47,6 +89,7 @@ void specializeArrayParameters( IRModule* module) { ArrayParameterSpecializationCondition condition; + condition.codeGenContext = codeGenContext; specializeFunctionCalls(codeGenContext, module, &condition); } diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index c4928a230..a41ca1e99 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -67,6 +67,13 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam* } } +bool FunctionCallSpecializeCondition::doesParamTypeWantSpecialization(IRParam* param, IRInst* arg) +{ + SLANG_UNUSED(param); + SLANG_UNUSED(arg); + return false; +} + struct FunctionParameterSpecializationContext { // This type implements a pass to specialize functions @@ -209,7 +216,9 @@ struct FunctionParameterSpecializationContext // If neither the parameter nor the argument wants specialization, // then we need to keep looking. // - if(!doesParamWantSpecialization(param, arg)) + auto paramWantSpecialization = doesParamWantSpecialization(param, arg); + auto paramTypeWantSpecialization = doesParamTypeWantSpecialization(param, arg); + if(!paramWantSpecialization && !paramTypeWantSpecialization) continue; // If we have run into a `param` or `arg` that wants specialization, @@ -222,7 +231,7 @@ struct FunctionParameterSpecializationContext // can bail out immediately because our second condition // cannot be met. // - if(!isParamSuitableForSpecialization(param, arg)) + if(paramWantSpecialization && !isParamSuitableForSpecialization(param, arg)) return false; } @@ -242,6 +251,11 @@ struct FunctionParameterSpecializationContext return condition->doesParamWantSpecialization(param, arg); } + bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg) + { + return condition->doesParamTypeWantSpecialization(param, arg); + } + bool isParamSuitableForSpecialization(IRParam* param, IRInst* arg) { return condition->isParamSuitableForSpecialization(param, arg); @@ -474,6 +488,11 @@ struct FunctionParameterSpecializationContext // specialized callee based on this paramter. // ioInfo.newArgs.add(oldArg); + + if (doesParamTypeWantSpecialization(oldParam, oldArg)) + { + ioInfo.key.vals.add(oldArg->getDataType()); + } } else { @@ -587,6 +606,30 @@ struct FunctionParameterSpecializationContext } } + // Wrap `argType` with a parameter direction type if `oldParam` has such a parameter direction type. + IRType* maybeWrapParameterDirectionType(IRParam* oldParam, IRType* argType) + { + IRType* paramType = oldParam->getDataType(); + IRType* resultType = argType; + switch (paramType->getOp()) + { + case kIROp_InOutType: + case kIROp_OutType: + case kIROp_RefType: + case kIROp_ConstRefType: + argType = as<IRPtrTypeBase>(argType)->getValueType(); + resultType = getBuilder()->getPtrType(paramType->getOp(), argType, as<IRPtrTypeBase>(paramType)->getAddressSpace()); + break; + } + if (auto rate = paramType->getRate()) + { + IRBuilder builder(oldParam); + builder.setInsertAfter(resultType); + resultType = builder.getRateQualifiedType(rate, resultType); + } + return resultType; + } + IRInst* getSpecializedValueForParam( FuncSpecializationInfo& ioInfo, IRParam* oldParam, @@ -601,7 +644,16 @@ struct FunctionParameterSpecializationContext // that fills the same role as the old one, so we // create it here. // - auto newParam = getBuilder()->createParam(oldParam->getFullType()); + IRType* paramType = nullptr; + if (doesParamTypeWantSpecialization(oldParam, oldArg)) + { + paramType = maybeWrapParameterDirectionType(oldParam, oldArg->getDataType()); + } + else + { + paramType = oldParam->getFullType(); + } + auto newParam = getBuilder()->createParam(paramType); ioInfo.newParams.add(newParam); // The new parameter will be used as the replacement @@ -891,6 +943,23 @@ struct FunctionParameterSpecializationContext // addCallsToWorkListRec(newFunc); + // If one of the new parameters has a more specialized type, + // we need to update the type of load instructions from that + // parameter, if there are any. + for (auto newParam : funcInfo.newParams) + { + if (!as<IRParam>(newParam)) + continue; + for (auto use = newParam->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (auto load = as<IRLoad>(user)) + { + load->setFullType(as<IRPtrTypeBase>(newParam->getDataType())->getValueType()); + } + } + } + simplifyFunc(codeGenContext->getTargetProgram(), newFunc, IRSimplificationOptions::getFast(codeGenContext->getTargetProgram())); return newFunc; diff --git a/source/slang/slang-ir-specialize-function-call.h b/source/slang/slang-ir-specialize-function-call.h index 8fe113a98..092a5158d 100644 --- a/source/slang/slang-ir-specialize-function-call.h +++ b/source/slang/slang-ir-specialize-function-call.h @@ -15,6 +15,8 @@ namespace Slang virtual bool doesParamWantSpecialization(IRParam* param, IRInst* arg) = 0; virtual bool isParamSuitableForSpecialization(IRParam* param, IRInst* arg); + + virtual bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg); }; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index bb05a1c29..88065cedc 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8250,6 +8250,7 @@ namespace Slang case kIROp_TorchTensorGetView: case kIROp_GetStringHash: case kIROp_AllocateOpaqueHandle: + case kIROp_GetArrayLength: return false; case kIROp_ForwardDifferentiate: |
