diff options
Diffstat (limited to 'source')
23 files changed, 363 insertions, 154 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 5abef94b3..893d5e6d7 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -347,6 +347,7 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element { if (!elementCount) elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength); + if (elementCount->getType() != getIntType()) { // Canonicalize constant elementCount to int. diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp index 1881f1b3c..89fa52b09 100644 --- a/source/slang/slang-ast-decl-ref.cpp +++ b/source/slang/slang-ast-decl-ref.cpp @@ -41,7 +41,7 @@ DeclRefBase* _getDeclRefFromVal(Val* val) { if (auto declRefType = as<DeclRefType>(val)) return declRefType->getDeclRef(); - else if (auto genParamIntVal = as<GenericParamIntVal>(val)) + else if (auto genParamIntVal = as<DeclRefIntVal>(val)) return genParamIntVal->getDeclRef(); else if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(val)) return declaredSubtypeWitness->getDeclRef(); diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index efb87b831..1cdca0440 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -176,9 +176,9 @@ void ConstantIntVal::_toTextOverride(StringBuilder& out) out << getValue(); } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -void GenericParamIntVal::_toTextOverride(StringBuilder& out) +void DeclRefIntVal::_toTextOverride(StringBuilder& out) { Name* name = getDeclRef().getName(); if (name) @@ -248,7 +248,7 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet return paramVal; } -Val* GenericParamIntVal::_substituteImplOverride( +Val* DeclRefIntVal::_substituteImplOverride( ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) @@ -259,12 +259,12 @@ Val* GenericParamIntVal::_substituteImplOverride( return this; } -bool GenericParamIntVal::_isLinkTimeValOverride() +bool DeclRefIntVal::_isLinkTimeValOverride() { return getDeclRef().getDecl()->hasModifier<ExternModifier>(); } -Val* GenericParamIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map) +Val* DeclRefIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map) { auto name = getMangledName(getCurrentASTBuilder(), getDeclRef().declRefBase); IntVal* v; diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index cdfb0b51f..2b4c7ed22 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -168,7 +168,7 @@ class ConstantIntVal : public IntVal // The logical "value" of a reference to a generic value parameter FIDDLE() -class GenericParamIntVal : public IntVal +class DeclRefIntVal : public IntVal { FIDDLE(...) DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(1)); } @@ -177,10 +177,7 @@ class GenericParamIntVal : public IntVal void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef) - { - setOperands(inType, inDeclRef); - } + DeclRefIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef) { setOperands(inType, inDeclRef); } bool _isLinkTimeValOverride(); Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); @@ -319,9 +316,9 @@ public: // for sorting only. bool operator<(const PolynomialIntValFactor& other) const { - if (auto thisGenParam = as<GenericParamIntVal>(getParam())) + if (auto thisGenParam = as<DeclRefIntVal>(getParam())) { - if (auto thatGenParam = as<GenericParamIntVal>(other.getParam())) + if (auto thatGenParam = as<DeclRefIntVal>(other.getParam())) { if (thisGenParam->equals(thatGenParam)) return getPower() < other.getPower(); @@ -336,7 +333,7 @@ public: } else { - if (const auto thatGenParam = as<GenericParamIntVal>(other.getParam())) + if (const auto thatGenParam = as<DeclRefIntVal>(other.getParam())) { return false; } @@ -347,9 +344,9 @@ public: // for sorting only. bool operator==(const PolynomialIntValFactor& other) const { - if (auto thisGenParam = as<GenericParamIntVal>(getParam())) + if (auto thisGenParam = as<DeclRefIntVal>(getParam())) { - if (auto thatGenParam = as<GenericParamIntVal>(other.getParam())) + if (auto thatGenParam = as<DeclRefIntVal>(other.getParam())) { if (thisGenParam->equals(thatGenParam) && getPower() == other.getPower()) return true; diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 642a4bf6a..6f9191135 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -819,7 +819,7 @@ bool SemanticsVisitor::TryUnifyVals( { if (const auto c = as<TypeCastIntVal>(i)) i = as<IntVal>(c->getBase()); - return as<GenericParamIntVal>(i); + return as<DeclRefIntVal>(i); }; auto fstParam = paramUnderCast(fstInt); auto sndParam = paramUnderCast(sndInt); @@ -1196,7 +1196,7 @@ void SemanticsVisitor::maybeUnifyUnconstraintIntParam( { param = as<IntVal>(typeCastParam->getBase()); } - auto intParam = as<GenericParamIntVal>(param); + auto intParam = as<DeclRefIntVal>(param); if (!intParam) return; for (auto c : constraints.constraints) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 1e524e27f..dbd52ebea 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3921,7 +3921,7 @@ bool SemanticsVisitor::doesGenericSignatureMatchRequirement( auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>(); SLANG_ASSERT(satisfyingValueParamDeclRef); - auto satisfyingVal = m_astBuilder->getOrCreate<GenericParamIntVal>( + auto satisfyingVal = m_astBuilder->getOrCreate<DeclRefIntVal>( requiredValueParamDeclRef.getDecl()->getType(), satisfyingValueParamDeclRef); satisfyingVal->getDeclRef() = satisfyingValueParamDeclRef; @@ -8513,7 +8513,7 @@ List<Val*> getDefaultSubstitutionArgs( if (semantics) semantics->ensureDecl(genericValueParamDecl, DeclCheckState::ReadyForLookup); - args.add(astBuilder->getOrCreate<GenericParamIntVal>( + args.add(astBuilder->getOrCreate<DeclRefIntVal>( genericValueParamDecl->getType(), astBuilder->getDirectDeclRef(genericValueParamDecl))); } @@ -11769,7 +11769,7 @@ void checkDerivativeAttributeImpl( appExpr->arguments.add(baseTypeExpr); } - else if (auto genericValParam = as<GenericParamIntVal>(arg)) + else if (auto genericValParam = as<DeclRefIntVal>(arg)) { auto declRef = genericValParam->getDeclRef(); appExpr->arguments.add( diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 48f32952b..d151d37be 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1972,9 +1972,14 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef( // The values of specialization constants aren't known at compile time even // if they're marked `const`. - if (decl->hasModifier<SpecializationConstantAttribute>() || - decl->hasModifier<VkConstantIdAttribute>()) - return nullptr; + if ((decl->hasModifier<SpecializationConstantAttribute>() || + decl->hasModifier<VkConstantIdAttribute>()) && + kind == ConstantFoldingKind::SpecializationConstant) + { + return m_astBuilder->getOrCreate<DeclRefIntVal>( + declRef.substitute(m_astBuilder, declRef.getDecl()->getType()), + declRef); + } if (decl->hasModifier<ExternModifier>()) { @@ -1982,7 +1987,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef( if (kind == ConstantFoldingKind::CompileTime) return nullptr; // But if we are OK with link-time constants, we can still fold it into a val. - auto rs = m_astBuilder->getOrCreate<GenericParamIntVal>( + auto rs = m_astBuilder->getOrCreate<DeclRefIntVal>( declRef.substitute(m_astBuilder, declRef.getDecl()->getType()), declRef); return rs; @@ -2067,7 +2072,7 @@ IntVal* SemanticsVisitor::tryConstantFoldExpr( if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { - Val* valResult = m_astBuilder->getOrCreate<GenericParamIntVal>( + Val* valResult = m_astBuilder->getOrCreate<DeclRefIntVal>( declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()), genericValParamRef); valResult = valResult->substitute(m_astBuilder, expr.getSubsts()); @@ -2383,7 +2388,7 @@ Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr) subscriptExpr->indexExprs[0], IntegerConstantExpressionCoercionType::AnyInteger, nullptr, - ConstantFoldingKind::LinkTime); + ConstantFoldingKind::SpecializationConstant); // Validate that array size is greater than zero if (auto constElementCount = as<ConstantIntVal>(elementCount)) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index a910a3722..6c9a0409d 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2110,6 +2110,7 @@ public: { CompileTime, LinkTime, + SpecializationConstant }; Expr* checkExpressionAndExpectIntegerConstant( Expr* expr, diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index db753713b..172d09ac2 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -453,9 +453,9 @@ bool SemanticsVisitor::ValuesAreEqual(IntVal* left, IntVal* right) } } - if (auto leftVar = as<GenericParamIntVal>(left)) + if (auto leftVar = as<DeclRefIntVal>(left)) { - if (auto rightVar = as<GenericParamIntVal>(right)) + if (auto rightVar = as<DeclRefIntVal>(right)) { return leftVar->getDeclRef().equals(rightVar->getDeclRef()); } diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index d2e68ccc8..50fd739cb 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -782,7 +782,7 @@ void DocMarkdownWriter::writeExtensionConditions( { genericParamDecl = extTypeParamDecl.getDecl(); } - else if (auto extValueParamVal = as<GenericParamIntVal>(arg)) + else if (auto extValueParamVal = as<DeclRefIntVal>(arg)) { genericParamDecl = extValueParamVal->getDeclRef().getDecl(); } diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 096e7d8bc..32d3ba7c3 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -775,6 +775,147 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex m_operandStack.setCount(operandsStartIndex); } + SpvOp _specConstantOpcodeConvert(IROp irOpCode, IRBasicType* basicType) + { + SpvOp opCode = SpvOpUndef; + opCode = _arithmeticOpCodeConvert(irOpCode, basicType); + if (opCode == SpvOpUndef) + { + switch (irOpCode) + { + case kIROp_IntCast: + { + auto typeStyle = getTypeStyle(basicType->getBaseType()); + if (typeStyle == kIROp_FloatType) + { + return SpvOpConvertFToU; + } + else if (typeStyle == kIROp_IntType) + { + return SpvOpUConvert; + } + break; + } + default: + break; + } + return opCode; + } + return opCode; + } + + SpvOp _arithmeticOpCodeConvert(IROp irOpCode, IRBasicType* basicType) + { + bool isFloatingPoint = false; + bool isBool = false; + switch (basicType->getBaseType()) + { + case BaseType::Float: + case BaseType::Double: + case BaseType::Half: + isFloatingPoint = true; + break; + case BaseType::Bool: + isBool = true; + break; + default: + break; + } + bool isSigned = isSignedType(basicType); + SpvOp opCode = SpvOpUndef; + switch (irOpCode) + { + case kIROp_Add: + opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd; + break; + case kIROp_Sub: + opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub; + break; + case kIROp_Mul: + opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul; + break; + case kIROp_Div: + opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv; + break; + case kIROp_IRem: + opCode = isSigned ? SpvOpSRem : SpvOpUMod; + break; + case kIROp_FRem: + opCode = SpvOpFRem; + break; + case kIROp_Less: + opCode = isFloatingPoint ? SpvOpFOrdLessThan + : isSigned ? SpvOpSLessThan + : SpvOpULessThan; + break; + case kIROp_Leq: + opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual + : isSigned ? SpvOpSLessThanEqual + : SpvOpULessThanEqual; + break; + case kIROp_Eql: + opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual; + break; + case kIROp_Neq: + opCode = isFloatingPoint ? SpvOpFUnordNotEqual + : isBool ? SpvOpLogicalNotEqual + : SpvOpINotEqual; + break; + case kIROp_Geq: + opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual + : isSigned ? SpvOpSGreaterThanEqual + : SpvOpUGreaterThanEqual; + break; + case kIROp_Greater: + opCode = isFloatingPoint ? SpvOpFOrdGreaterThan + : isSigned ? SpvOpSGreaterThan + : SpvOpUGreaterThan; + break; + case kIROp_Neg: + opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate; + break; + case kIROp_And: + opCode = SpvOpLogicalAnd; + break; + case kIROp_Or: + opCode = SpvOpLogicalOr; + break; + case kIROp_Not: + opCode = SpvOpLogicalNot; + break; + case kIROp_BitAnd: + if (isBool) + opCode = SpvOpLogicalAnd; + else + opCode = SpvOpBitwiseAnd; + break; + case kIROp_BitOr: + if (isBool) + opCode = SpvOpLogicalOr; + else + opCode = SpvOpBitwiseOr; + break; + case kIROp_BitXor: + if (isBool) + opCode = SpvOpLogicalNotEqual; + else + opCode = SpvOpBitwiseXor; + break; + case kIROp_BitNot: + if (isBool) + opCode = SpvOpLogicalNot; + else + opCode = SpvOpNot; + break; + case kIROp_Rsh: + opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical; + break; + case kIROp_Lsh: + opCode = SpvOpShiftLeftLogical; + break; + } + return opCode; + } /// Ensure that an instruction has been emitted SpvInst* ensureInst(IRInst* irInst) { @@ -1972,8 +2113,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex as<IRDebugInlinedAt>(inst)); default: { - if (as<IRSPIRVAsmOperand>(inst)) + if (isSpecConstRateType(inst->getFullType())) + return emitSpecializationConstantOp(inst); + + else if (as<IRSPIRVAsmOperand>(inst)) return nullptr; + String e = "Unhandled global inst in spirv-emit:\n" + dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0}); SLANG_UNIMPLEMENTED_X(e.begin()); @@ -2756,6 +2901,66 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return result; } + SpvInst* emitSpecializationConstantOp(IRInst* inst) + { + SpvInst* spv = nullptr; + if (m_mapIRInstToSpvInst.tryGetValue(inst, spv)) + return spv; + + // For each OpSpecConstantOp, the operand must be: + // 1. A specialization constant + // 2. A literal constant + // 3. Another OpSpecConstantOp + + // For 1 and 2, we can just emit the specialization constant or literal constant. + if (auto param = as<IRGlobalParam>(inst)) + { + auto layout = getVarLayout(param); + if (layout) + { + if (auto offset = + layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant)) + { + return emitSpecializationConstant(param, offset); + } + } + SLANG_UNREACHABLE("Non specialization constant used in OpSpecConstantOp\n"); + } + else if (as<IRConstant>(inst)) + { + // We need to emit the constant as a specialization constant + return emitLit(inst); + } + + IRType* type = inst->getOperand(0)->getDataType(); + IRBasicType* basicType = as<IRBasicType>(type); + SpvOp opCode = _specConstantOpcodeConvert(inst->getOp(), basicType); + if (opCode == SpvOpUndef) + { + String e = "Unhandled inst in spirv-emit:\n" + + dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0}); + SLANG_UNIMPLEMENTED_X(e.getBuffer()); + } + + Array<SpvInst*, 3> operands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + SpvInst* spvInst = emitSpecializationConstantOp(operand); + operands.add(spvInst); + } + + auto resultType = inst->getFullType(); + return emitInst( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpSpecConstantOp, + resultType, + kResultID, + opCode, + operands); + } + /// Emit a global parameter definition. SpvInst* emitGlobalParam(IRGlobalParam* param) { @@ -7197,117 +7402,13 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); IRBasicType* basicType = as<IRBasicType>(elementType); - bool isFloatingPoint = false; - bool isBool = false; - switch (basicType->getBaseType()) - { - case BaseType::Float: - case BaseType::Double: - case BaseType::Half: - isFloatingPoint = true; - break; - case BaseType::Bool: - isBool = true; - break; - default: - break; - } - SpvOp opCode = SpvOpUndef; - bool isSigned = isSignedType(basicType); - switch (op) - { - case kIROp_Add: - opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd; - break; - case kIROp_Sub: - opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub; - break; - case kIROp_Mul: - opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul; - break; - case kIROp_Div: - opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv; - break; - case kIROp_IRem: - opCode = isSigned ? SpvOpSRem : SpvOpUMod; - break; - case kIROp_FRem: - opCode = SpvOpFRem; - break; - case kIROp_Less: - opCode = isFloatingPoint ? SpvOpFOrdLessThan - : isSigned ? SpvOpSLessThan - : SpvOpULessThan; - break; - case kIROp_Leq: - opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual - : isSigned ? SpvOpSLessThanEqual - : SpvOpULessThanEqual; - break; - case kIROp_Eql: - opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual; - break; - case kIROp_Neq: - opCode = isFloatingPoint ? SpvOpFUnordNotEqual - : isBool ? SpvOpLogicalNotEqual - : SpvOpINotEqual; - break; - case kIROp_Geq: - opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual - : isSigned ? SpvOpSGreaterThanEqual - : SpvOpUGreaterThanEqual; - break; - case kIROp_Greater: - opCode = isFloatingPoint ? SpvOpFOrdGreaterThan - : isSigned ? SpvOpSGreaterThan - : SpvOpUGreaterThan; - break; - case kIROp_Neg: - opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate; - break; - case kIROp_And: - opCode = SpvOpLogicalAnd; - break; - case kIROp_Or: - opCode = SpvOpLogicalOr; - break; - case kIROp_Not: - opCode = SpvOpLogicalNot; - break; - case kIROp_BitAnd: - if (isBool) - opCode = SpvOpLogicalAnd; - else - opCode = SpvOpBitwiseAnd; - break; - case kIROp_BitOr: - if (isBool) - opCode = SpvOpLogicalOr; - else - opCode = SpvOpBitwiseOr; - break; - case kIROp_BitXor: - if (isBool) - opCode = SpvOpLogicalNotEqual; - else - opCode = SpvOpBitwiseXor; - break; - case kIROp_BitNot: - if (isBool) - opCode = SpvOpLogicalNot; - else - opCode = SpvOpNot; - break; - case kIROp_Rsh: - opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical; - break; - case kIROp_Lsh: - opCode = SpvOpShiftLeftLogical; - break; - default: + + SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); + if (opCode == SpvOpUndef) SLANG_ASSERT(!"unknown arithmetic opcode"); - break; - } + + bool isFloatingPoint = (getTypeStyle(basicType->getBaseType()) == kIROp_FloatType); + if (operandCount == 1) { return emitInst(parent, instToRegister, opCode, type, kResultID, operands); @@ -7846,7 +7947,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex emitDebugType(arrayType->getElementType()), sizedArrayType ? builder.getIntValue( builder.getUIntType(), - getIntVal(sizedArrayType->getElementCount())) + getArraySizeVal(sizedArrayType->getElementCount())) : builder.getIntValue(builder.getUIntType(), 0)); } else if (auto vectorType = as<IRVectorType>(type)) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 5a62c8063..f863858e4 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -95,6 +95,7 @@ INST(Nop, nop, 0, 0) /* Rate */ INST(ConstExprRate, ConstExpr, 0, HOISTABLE) + INST(SpecConstRate, SpecConst, 0, HOISTABLE) INST(GroupSharedRate, GroupShared, 0, HOISTABLE) INST(ActualGlobalRate, ActualGlobalRate, 0, HOISTABLE) INST_RANGE(Rate, ConstExprRate, GroupSharedRate) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 268929fb9..3280dc35c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3938,6 +3938,7 @@ public: IRConstExprRate* getConstExprRate(); IRGroupSharedRate* getGroupSharedRate(); IRActualGlobalRate* getActualGlobalRate(); + IRSpecConstRate* getSpecConstRate(); IRRateQualifiedType* getRateQualifiedType(IRRate* rate, IRType* dataType); diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 6f0e22a57..1294b400d 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -309,7 +309,7 @@ struct LoweredElementTypeContext builder.emitBlock(); auto packedParam = builder.emitParam(refStructType); auto packedArray = builder.emitFieldAddress(packedParam, dataKey); - auto count = getIntVal(arrayType->getElementCount()); + auto count = getArraySizeVal(arrayType->getElementCount()); IRInst* result = nullptr; if (count <= kMaxArraySizeToUnroll) { @@ -374,7 +374,7 @@ struct LoweredElementTypeContext builder.emitBlock(); auto outParam = builder.emitParam(outLoweredType); auto originalParam = builder.emitParam(arrayType); - auto count = getIntVal(arrayType->getElementCount()); + auto count = getArraySizeVal(arrayType->getElementCount()); auto destArray = builder.emitFieldAddress(outParam, arrayStructKey); if (count <= kMaxArraySizeToUnroll) { @@ -602,7 +602,8 @@ struct LoweredElementTypeContext StringBuilder nameSB; nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_"; getTypeNameHint(nameSB, arrayType->getElementType()); - nameSB << getIntVal(arrayType->getElementCount()); + nameSB << getArraySizeVal(arrayType->getElementCount()); + builder.addNameHintDecoration( loweredType, nameSB.produceString().getUnownedSlice()); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 5aae53747..9d8773237 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2250,4 +2250,27 @@ bool isFirstBlock(IRInst* inst) return block->getParent()->getFirstBlock() == block; } +bool isSpecConstRateType(IRType* type) +{ + if (auto rateQualifiedType = as<IRRateQualifiedType>(type)) + { + if (as<IRSpecConstRate>(rateQualifiedType->getRate())) + { + return true; + } + } + return false; +} +void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst) +{ + IRInst* moduleInst = builder->getModule()->getModuleInst(); + UInt operandCount = inst->getOperandCount(); + for (UInt ii = 0; ii < operandCount; ++ii) + { + auto operand = inst->getOperand(ii); + if (operand->parent != moduleInst) + hoistInstAndOperandsToGlobal(builder, operand); + } + inst->insertAt(IRInsertLoc::atStart(moduleInst)); +} } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index b111f8abf..900e22c76 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -388,6 +388,9 @@ void legalizeDefUse(IRGlobalValueWithCode* func); UnownedStringSlice getMangledName(IRInst* inst); bool isFirstBlock(IRInst* inst); + +bool isSpecConstRateType(IRType* type); +void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst); } // namespace Slang #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index fb7d752d5..9c4cb98c0 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -314,6 +314,23 @@ IRIntegerValue getIntVal(IRInst* inst) } } +IRIntegerValue getArraySizeVal(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_IntLit: + return static_cast<IRConstant*>(inst)->value.intVal; + break; + default: + // Treat specialization constant array as the unsized array here. + if (isSpecConstRateType(inst->getFullType())) + return kUnsizedArrayMagicLength; + + SLANG_UNEXPECTED("needed a known integer value"); + UNREACHABLE_RETURN(0); + } +} + // IRCapabilitySet CapabilitySet IRCapabilitySet::getCaps() @@ -3194,6 +3211,10 @@ IRActualGlobalRate* IRBuilder::getActualGlobalRate() { return (IRActualGlobalRate*)getType(kIROp_ActualGlobalRate); } +IRSpecConstRate* IRBuilder::getSpecConstRate() +{ + return (IRSpecConstRate*)getType(kIROp_SpecConstRate); +} IRRateQualifiedType* IRBuilder::getRateQualifiedType(IRRate* rate, IRType* dataType) { diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 91c2f018a..461ed567a 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1162,6 +1162,11 @@ struct IRBoolLit : IRConstant // if it has one, and assert-fail otherwise. IRIntegerValue getIntVal(IRInst* inst); +// If it's a specialization constant sized array or unsized array, returns +// kUnsizedArrayMagicLength if it's an unsized array. Otherwise just returns +// the actual size. +IRIntegerValue getArraySizeVal(IRInst* inst); + struct IRStringLit : IRConstant { @@ -1644,6 +1649,7 @@ struct IRAtomicType : IRType SIMPLE_IR_PARENT_TYPE(Rate, Type) SIMPLE_IR_TYPE(ConstExprRate, Rate) +SIMPLE_IR_TYPE(SpecConstRate, Rate) SIMPLE_IR_TYPE(GroupSharedRate, Rate) SIMPLE_IR_TYPE(ActualGlobalRate, Rate) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a21c93f06..af285f221 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -492,6 +492,8 @@ struct SharedIRGenContext Dictionary<SourceFile*, IRInst*> mapSourceFileToDebugSourceInst; Dictionary<String, IRInst*> mapSourcePathToDebugSourceInst; + Dictionary<IntVal*, IRInst*> mapSpecConstValToIRInst; + void setGlobalValue(Decl* decl, LoweredValInfo value) { globalEnv.mapDeclToValue[decl] = value; @@ -1552,6 +1554,14 @@ static bool _isTrivialLookupFromInterfaceThis(IRGenContext* context, DeclRefBase // +static void maybePropagateRate(IRBuilder* builder, IRType* rateQulifiedType, IRInst* inst) +{ + if (isSpecConstRateType(rateQulifiedType)) + { + inst->setFullType( + builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType())); + } +} struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo> { @@ -1565,7 +1575,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower UNREACHABLE_RETURN(LoweredValInfo()); } - LoweredValInfo visitGenericParamIntVal(GenericParamIntVal* val) + LoweredValInfo visitDeclRefIntVal(DeclRefIntVal* val) { return emitDeclRef( context, @@ -1577,27 +1587,35 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower { TryClauseEnvironment tryEnv; List<IRInst*> args; + IRType* specConstRateType = nullptr; for (auto arg : val->getArgs()) { auto loweredArg = lowerVal(context, arg); args.add(loweredArg.val); + if (!specConstRateType && isSpecConstRateType(loweredArg.val->getFullType())) + specConstRateType = loweredArg.val->getFullType(); } auto funcType = lowerType(context, val->getFuncType()); - return emitCallToDeclRef( + auto resVal = emitCallToDeclRef( context, as<IRFuncType>(funcType)->getResultType(), val->getFuncDeclRef(), funcType, args, tryEnv); + maybePropagateRate(getBuilder(), specConstRateType, resVal.val); + return resVal; } LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val) { auto baseVal = lowerVal(context, val->getBase()); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); auto type = lowerType(context, val->getType()); - return LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val)); + auto resVal = LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val)); + maybePropagateRate(getBuilder(), baseVal.val->getFullType(), resVal.val); + return resVal; } LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val) @@ -1625,8 +1643,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower { termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal); } + maybePropagateRate(getBuilder(), factorVal->getFullType(), termVal); } resultVal = irBuilder->emitAdd(termVal->getDataType(), resultVal, termVal); + maybePropagateRate(getBuilder(), termVal->getFullType(), resultVal); } return LoweredValInfo::simple(resultVal); } @@ -2056,7 +2076,18 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto elementType = lowerType(context, type->getElementType()); if (!type->isUnsized()) { - auto elementCount = lowerSimpleVal(context, type->getElementCount()); + IRInst* elementCount = nullptr; + auto sizeVal = type->getElementCount(); + auto sharedContext = context->shared; + if (!sharedContext->mapSpecConstValToIRInst.tryGetValue(sizeVal, elementCount)) + { + elementCount = lowerSimpleVal(context, sizeVal); + if (isSpecConstRateType(elementCount->getFullType())) + { + sharedContext->mapSpecConstValToIRInst.add(sizeVal, elementCount); + hoistInstAndOperandsToGlobal(getBuilder(), elementCount); + } + } return getBuilder()->getArrayType(elementType, elementCount); } else @@ -2446,6 +2477,13 @@ void maybeSetRate(IRGenContext* context, IRInst* inst, Decl* decl) inst->setFullType( builder->getRateQualifiedType(builder->getActualGlobalRate(), inst->getFullType())); } + else if ( + decl->hasModifier<SpecializationConstantAttribute>() || + decl->hasModifier<VkConstantIdAttribute>()) + { + inst->setFullType( + builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType())); + } } static String getNameForNameHint(IRGenContext* context, Decl* decl) @@ -11846,9 +11884,15 @@ RefPtr<IRModule> generateIRForTranslationUnit( } #if 0 + if (compileRequest->optionSet.shouldDumpIR()) { DiagnosticSinkWriter writer(compileRequest->getSink()); - dumpIR(module, &writer, "GENERATED"); + dumpIR( + module, + compileRequest->m_irDumpOptions, + "GENERATED", + compileRequest->getSourceManager(), + &writer); } #endif diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index f08ffd75d..056c7accb 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -323,7 +323,7 @@ void emitVal(ManglingContext* context, Val* val) // to mangle in the constraints even when // the whole thing is specialized... } - else if (auto genericParamIntVal = dynamicCast<GenericParamIntVal>(val)) + else if (auto genericParamIntVal = dynamicCast<DeclRefIntVal>(val)) { // TODO: we shouldn't be including the names of generic parameters // anywhere in mangled names, since changing parameter names diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 065c2c3f6..258266da5 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -3906,7 +3906,7 @@ SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal( auto valueParamDeclRef = convert(valueParam); - Val* valResult = astBuilder->getOrCreate<GenericParamIntVal>( + Val* valResult = astBuilder->getOrCreate<DeclRefIntVal>( valueParamDeclRef.substitute( astBuilder, as<GenericValueParamDecl>(valueParamDeclRef.getDecl())->getType()), diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 9dea0f167..5f2cf4ddf 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -2336,7 +2336,7 @@ static LayoutSize GetElementCount(IntVal* val) return LayoutSize::infinite(); return LayoutSize(LayoutSize::RawValue(constantVal->getValue())); } - else if (const auto varRefVal = as<GenericParamIntVal>(val)) + else if (const auto varRefVal = as<DeclRefIntVal>(val)) { // TODO: We want to treat the case where the number of // elements in an array depends on a generic parameter @@ -2352,6 +2352,10 @@ static LayoutSize GetElementCount(IntVal* val) { return 0; } + else if (as<FuncCallIntVal>(val)) + { + return 0; + } SLANG_UNEXPECTED("unhandled integer literal kind"); UNREACHABLE_RETURN(LayoutSize(0)); } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 602446cda..67d13c34b 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -6079,7 +6079,7 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor { collectReferencedModules(type); } - else if (auto declRefVal = as<GenericParamIntVal>(val)) + else if (auto declRefVal = as<DeclRefIntVal>(val)) { collectReferencedModules(declRefVal->getDeclRef()); } |
