diff options
| -rw-r--r-- | source/slang/slang-check-type.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-clone.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 75 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 49 | ||||
| -rw-r--r-- | tests/spirv/spec-constant-generic.slang | 53 |
8 files changed, 150 insertions, 49 deletions
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 172d09ac2..d32903175 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -153,7 +153,7 @@ IntVal* SemanticsVisitor::ExtractGenericArgInteger( genericParamType ? IntegerConstantExpressionCoercionType::SpecificType : IntegerConstantExpressionCoercionType::AnyInteger, genericParamType, - ConstantFoldingKind::LinkTime, + ConstantFoldingKind::SpecializationConstant, sink); if (val) return val; diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index 5bb1c1210..1a020ec26 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -93,6 +93,9 @@ IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldIns auto newOperand = findCloneForOperand(env, oldOperand); newOperands[ii] = newOperand; + + if (isArithmeticInst(oldInst)) + newType = maybeAddRateType(builder, newOperand->getFullType(), newType); } // Finally we create the inst with the updated operands. diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 2f51b28a2..266c1aa99 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -101,7 +101,13 @@ struct SpecializationContext case kIROp_IntCast: case kIROp_FloatCast: case kIROp_Select: - return true; + { + if (isSpecConstRateType(inst->getFullType())) + { + return false; + } + return true; + } default: return false; } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 9d8773237..c8faec73b 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2261,16 +2261,75 @@ bool isSpecConstRateType(IRType* type) } return false; } -void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst) + +IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType) { - IRInst* moduleInst = builder->getModule()->getModuleInst(); - UInt operandCount = inst->getOperandCount(); - for (UInt ii = 0; ii < operandCount; ++ii) + if (as<IRRateQualifiedType>(oldType)) { - auto operand = inst->getOperand(ii); - if (operand->parent != moduleInst) - hoistInstAndOperandsToGlobal(builder, operand); + return oldType; } - inst->insertAt(IRInsertLoc::atStart(moduleInst)); + + if (isSpecConstRateType(rateQulifiedType)) + { + return builder->getRateQualifiedType(builder->getSpecConstRate(), oldType); + } + return oldType; +} + +bool isArithmeticInst(IROp op) +{ + switch (op) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Neg: + case kIROp_Not: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Leq: + case kIROp_Geq: + case kIROp_Less: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_Greater: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitNot: + case kIROp_BitCast: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_Select: + return true; + default: + return false; + } +} +bool isArithmeticInst(IRInst* inst) +{ + return isArithmeticInst(inst->getOp()); +} + +bool isInstHoistable(IROp op, IRType* type) +{ + if ((getIROpInfo(op).flags & kIROpFlag_Hoistable)) + { + return true; + } + + if (isArithmeticInst(op)) + { + if (type && isSpecConstRateType(type)) + { + return true; + } + } + return false; } } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 900e22c76..1e5a5eb2a 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -391,6 +391,10 @@ bool isFirstBlock(IRInst* inst); bool isSpecConstRateType(IRType* type); void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst); +IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType); +bool isArithmeticInst(IRInst* inst); +bool isArithmeticInst(IROp op); +bool isInstHoistable(IROp op, IRType* type); } // namespace Slang #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9c4cb98c0..c44196bc5 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1821,7 +1821,7 @@ IRInst* IRBuilder::_createInst( m_dedupContext->getInstReplacementMap().tryGetValue(type, instReplacement); type = (IRType*)instReplacement; - if (getIROpInfo(op).flags & kIROpFlag_Hoistable) + if (isInstHoistable(op, type)) { return _findOrEmitHoistableInst( type, @@ -2527,7 +2527,8 @@ static void addGlobalValue(IRBuilder* builder, IRInst* value) // if (value->parent) { - SLANG_ASSERT(getIROpInfo(value->getOp()).isHoistable()); + SLANG_ASSERT( + getIROpInfo(value->getOp()).isHoistable() || isSpecConstRateType(value->getFullType())); return; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index af285f221..f8946f5dc 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1552,17 +1552,6 @@ static bool _isTrivialLookupFromInterfaceThis(IRGenContext* context, DeclRefBase return context->thisTypeWitness == nullptr; } - -// -static void maybePropagateRate(IRBuilder* builder, IRType* rateQulifiedType, IRInst* inst) -{ - if (isSpecConstRateType(rateQulifiedType)) - { - inst->setFullType( - builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType())); - } -} - struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo> { IRGenContext* context; @@ -1596,14 +1585,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower specConstRateType = loweredArg.val->getFullType(); } auto funcType = lowerType(context, val->getFuncType()); - auto resVal = emitCallToDeclRef( - context, - as<IRFuncType>(funcType)->getResultType(), - val->getFuncDeclRef(), - funcType, - args, - tryEnv); - maybePropagateRate(getBuilder(), specConstRateType, resVal.val); + auto funcResType = maybeAddRateType( + getBuilder(), + specConstRateType, + as<IRFuncType>(funcType)->getResultType()); + auto resVal = + emitCallToDeclRef(context, funcResType, val->getFuncDeclRef(), funcType, args, tryEnv); return resVal; } @@ -1613,8 +1600,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); auto type = lowerType(context, val->getType()); + type = maybeAddRateType(getBuilder(), baseVal.val->getFullType(), type); auto resVal = LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val)); - maybePropagateRate(getBuilder(), baseVal.val->getFullType(), resVal.val); return resVal; } @@ -1641,12 +1628,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto factorVal = lowerVal(context, factor->getParam()).val; for (IntegerLiteralValue i = 0; i < factor->getPower(); i++) { - termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal); + termVal = irBuilder->emitMul(factorVal->getFullType(), termVal, factorVal); } - maybePropagateRate(getBuilder(), factorVal->getFullType(), termVal); } - resultVal = irBuilder->emitAdd(termVal->getDataType(), resultVal, termVal); - maybePropagateRate(getBuilder(), termVal->getFullType(), resultVal); + resultVal = irBuilder->emitAdd(termVal->getFullType(), resultVal, termVal); } return LoweredValInfo::simple(resultVal); } @@ -2076,19 +2061,9 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto elementType = lowerType(context, type->getElementType()); if (!type->isUnsized()) { - IRInst* elementCount = nullptr; - auto sizeVal = type->getElementCount(); - auto sharedContext = context->shared; - if (!sharedContext->mapSpecConstValToIRInst.tryGetValue(sizeVal, elementCount)) - { - elementCount = lowerSimpleVal(context, sizeVal); - if (isSpecConstRateType(elementCount->getFullType())) - { - sharedContext->mapSpecConstValToIRInst.add(sizeVal, elementCount); - hoistInstAndOperandsToGlobal(getBuilder(), elementCount); - } - } - return getBuilder()->getArrayType(elementType, elementCount); + return getBuilder()->getArrayType( + elementType, + lowerSimpleVal(context, type->getElementCount())); } else { diff --git a/tests/spirv/spec-constant-generic.slang b/tests/spirv/spec-constant-generic.slang new file mode 100644 index 000000000..65eed2810 --- /dev/null +++ b/tests/spirv/spec-constant-generic.slang @@ -0,0 +1,53 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type + +// CHECK: %[[C0:[0-9A-Za-z_]+]] = OpSpecConstant %int 32 +// CHECK: %[[C1:[0-9A-Za-z_]+]] = OpSpecConstant %int 2 +// CHECK: %[[COP0:[0-9A-Za-z_]+]] = OpSpecConstantOp %int SDiv %[[C0]] %[[C1]] +// CHECK: %[[ARR_TYPE:[0-9A-Za-z_]+]] = OpTypeArray %float %[[COP0]] +// CHECK: %[[PT_TYPE:[0-9A-Za-z_]+]] = OpTypePointer Function %[[ARR_TYPE]] + +[SpecializationConstant] +const int constValue0 = 32; + +[SpecializationConstant] +const int constValue1 = 2; + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +void func(out float buffer[constValue0 / constValue1]) +{ + for (uint i = 0; i < constValue0 / constValue1; i++) + { + buffer[i] = i; + } +} + +struct MyStruct<let N: int> +{ + float buffer[N / constValue1]; +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + // This test checks we can use spec constants for generic arguments, and also + // we can show that the array size is computed correctly. + // The function call shows that the two arrays are the same type. + MyStruct<constValue0> s; + // CHECK: OpVariable %[[PT_TYPE]] Function + + func(s.buffer); + + float temp = 0.0f; + for (uint i = 0; i < constValue0 / constValue1; i++) + { + temp += s.buffer[i] * 2; + } + + // Result will be (0 + localConst-1) * localConst = 15 * 16 = 240 + outputBuffer[0] = temp; + // BUF: 240 +} |
