diff options
25 files changed, 705 insertions, 58 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index f3ab38582..1e1ef061e 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -719,19 +719,7 @@ struct Ptr __subscript(int index) -> T { - [__unsafeForceInlineEarly] - get - { - return __load(__getElementPtr(this, index)); - } - - [__unsafeForceInlineEarly] - set(T newValue) - { - __store(__getElementPtr(this, index), newValue); - } - - __intrinsic_op($(kIROp_GetElementPtr)) + __intrinsic_op($(kIROp_GetOffsetPtr)) ref; } }; @@ -748,6 +736,12 @@ Ptr<T> __getElementPtr<T>(Ptr<T> ptr, int index); __intrinsic_op($(kIROp_GetElementPtr)) Ptr<T> __getElementPtr<T>(Ptr<T> ptr, int64_t index); +__intrinsic_op($(kIROp_GetOffsetPtr)) +Ptr<T> __getOffsetPtr<T>(Ptr<T> ptr, int index); + +__intrinsic_op($(kIROp_GetOffsetPtr)) +Ptr<T> __getOffsetPtr<T>(Ptr<T> ptr, int64_t index); + __generic<T> __intrinsic_op($(kIROp_Less)) bool operator<(Ptr<T> p1, Ptr<T> p2); @@ -1543,14 +1537,14 @@ __intrinsic_op(0) __prefix Ptr<T> operator&(__ref T value); __generic<T> -__intrinsic_op($(kIROp_GetElementPtr)) +__intrinsic_op($(kIROp_GetOffsetPtr)) Ptr<T> operator+(Ptr<T> value, int64_t offset); __generic<T> [__unsafeForceInlineEarly] Ptr<T> operator-(Ptr<T> value, int64_t offset) { - return __getElementPtr(value, -offset); + return __getOffsetPtr(value, -offset); } __generic<T : IArithmetic> diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 8183c2030..156ecc194 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -13164,12 +13164,6 @@ struct ConstBufferPointer } } - __subscript(int index) -> T - { - [ForceInline] - get {return ConstBufferPointer<T>.fromUInt(toUInt() + __naturalStrideOf<T>() * index).get(); } - } - __glsl_version(450) __glsl_extension(GL_EXT_shader_explicit_arithmetic_types_int64) __glsl_extension(GL_EXT_buffer_reference) @@ -13221,4 +13215,10 @@ struct ConstBufferPointer }; } } + + __subscript(int index)->T + { + [ForceInline] + get { return ConstBufferPointer<T>.fromUInt(toUInt() + __naturalStrideOf<T>() * index).get(); } + } } diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index e2d0638e0..fc6f321e3 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -52,6 +52,10 @@ struct ASTIterator { iterator->maybeDispatchCallback(expr); } + void visitOpenRefExpr(OpenRefExpr* expr) + { + dispatchIfNotNull(expr->innerExpr); + } void visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) { iterator->maybeDispatchCallback(expr); diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 882e26078..c1984910c 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -109,6 +109,7 @@ namespace Slang kConversionCost_InRangeIntLitSignedToUnsignedConversion = 32, kConversionCost_InRangeIntLitUnsignedToSignedConversion = 81, + kConversionCost_MutablePtrToConstPtr = 20, // Conversions based on explicit sub-typing relationships are the cheapest // diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 1d2ebc566..d47e3a496 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -535,7 +535,8 @@ class PtrType : public PtrTypeBase SLANG_AST_CLASS(PtrType) }; -// A GPU pointer type that for general readonly memory access. +// A GPU pointer type into global memory. + class ConstBufferPointerType : public PtrTypeBase { SLANG_AST_CLASS(ConstBufferPointerType) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f9adcc91a..2f4906826 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -479,7 +479,10 @@ namespace Slang derefExpr->base = base; derefExpr->type = QualType(elementType); - derefExpr->type.isLeftValue = base->type.isLeftValue; + if (as<PtrType>(base->type)) + derefExpr->type.isLeftValue = true; + else + derefExpr->type.isLeftValue = base->type.isLeftValue; return derefExpr; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 62bc73c90..9b599ae2e 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -782,6 +782,8 @@ DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0") DIAGNOSTIC(58001, Error, entryPointMustReturnVoidWhenGlobalOutputPresent, "entry point must return 'void' when global output variables are present.") DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage Buffer Object contents, unsized arrays as a final parameter must be the only parameter") +DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.") + // // 8xxxx - Issues specific to a particular library/technology/platform/etc. // diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 05c525965..1c01478ed 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2055,6 +2055,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO case kIROp_MatrixReshape: case kIROp_CastPtrToInt: case kIROp_CastIntToPtr: + case kIROp_PtrCast: { // Simple constructor call auto prec = getInfo(EmitOp::Prefix); @@ -2345,6 +2346,15 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO m_writer->emit(".detach()"); break; } + case kIROp_GetOffsetPtr: + { + auto prec = getInfo(EmitOp::Add); + needClose = maybeEmitParens(outerPrec, prec); + emitOperand(inst->getOperand(0), leftSide(outerPrec, prec)); + m_writer->emit(" + "); + emitOperand(inst->getOperand(1), rightSide(prec, outerPrec)); + break; + } case kIROp_GetElement: case kIROp_GetElementPtr: case kIROp_ImageSubscript: @@ -4097,7 +4107,8 @@ void CLikeSourceEmitter::ensureGlobalInst(ComputeEmitActionsContext* ctx, IRInst } if (as<IRBasicType>(inst)) return; - + if (as<IRPtrLit>(inst)) + return; // Certain inst ops will always emit as definition. switch (inst->getOp()) { diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 891372fa6..32f47b3ef 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -362,6 +362,20 @@ SpvInst* emitOpTypeStruct(IRInst* inst, const Ts& member0TypeMember1TypeEtc) ); } +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeForwardPointer +template<typename T> +SpvInst* emitOpTypeForwardPointer(const T& type, SpvStorageClass storageClass) +{ + static_assert(isSingular<T>); + return emitInst( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + nullptr, + SpvOpTypeForwardPointer, + type, + storageClass + ); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypePointer template<typename T> SpvInst* emitOpTypePointer(IRInst* inst, SpvStorageClass storageClass, const T& type) @@ -623,6 +637,23 @@ SpvInst* emitOpAccessChain( return emitInst(parent, inst, SpvOpAccessChain, idResultType, kResultID, base, indexes); } + +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpPtrAccessChain +template<typename T1, typename T2, typename T3> +SpvInst* emitOpPtrAccessChain( + SpvInstParent* parent, + IRInst* inst, + const T1& idResultType, + const T2& base, + const T3& element +) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + static_assert(isSingular<T3>); + return emitInst(parent, inst, SpvOpPtrAccessChain, idResultType, kResultID, base, element); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate template<typename T> SpvInst* emitOpDecorate( diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 6e1e58755..d5d00e417 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -135,7 +135,6 @@ public: /// Dump all children, recursively, to a flattened list of SPIR-V words void dumpTo(List<SpvWord>& ioWords); -private: /// The first child, if any. SpvInst* m_firstChild = nullptr; @@ -145,7 +144,7 @@ private: /// while if it is non-empty it points to the `nextSibling` field /// of the last instruction. /// - SpvInst** m_link = &m_firstChild; + SpvInst* m_lastChild = nullptr; }; // A SPIR-V instruction is then (in the general case) a potential @@ -198,9 +197,13 @@ struct SpvInst : SpvInstParent // We will store the instructions in a given `SpvInstParent` // using an intrusive linked list. + SpvInstParent* parent = nullptr; + /// The next instruction in the same `SpvInstParent` SpvInst* nextSibling = nullptr; + SpvInst* prevSibling = nullptr; + /// The result <id> produced by this instruction, or zero if it has no result. SpvWord id = 0; @@ -235,6 +238,43 @@ struct SpvInst : SpvInstParent // SpvInstParent::dumpTo(ioWords); } + + void removeFromParent() + { + auto oldParent = parent; + + // If we don't currently have a parent, then + // we are doing fine. + if (!oldParent) + return; + + auto pp = prevSibling; + auto nn = nextSibling; + + if (pp) + { + SLANG_ASSERT(pp->parent == oldParent); + pp->nextSibling = nn; + } + else + { + oldParent->m_firstChild = nn; + } + + if (nn) + { + SLANG_ASSERT(nn->parent == oldParent); + nn->prevSibling = pp; + } + else + { + oldParent->m_lastChild = pp; + } + + prevSibling = nullptr; + nextSibling = nullptr; + parent = nullptr; + } }; /// A logical section of a SPIR-V module @@ -248,15 +288,22 @@ struct SpvLogicalSection : SpvInstParent void SpvInstParent::addInst(SpvInst* inst) { SLANG_ASSERT(inst); + SLANG_ASSERT(!inst->nextSibling); + + if (m_firstChild == nullptr) + { + m_firstChild = m_lastChild = inst; + return; + } // The user shouldn't be trying to add multiple instructions at once. // If they really want that then they probably wanted to give `inst` // some children. // - SLANG_ASSERT(!inst->nextSibling); - - *m_link = inst; - m_link = &inst->nextSibling; + m_lastChild->nextSibling = inst; + inst->prevSibling = m_lastChild; + inst->parent = this; + m_lastChild = inst; } void SpvInstParent::dumpTo(List<SpvWord>& ioWords) @@ -429,6 +476,11 @@ struct SPIRVEmitContext /// The next destination `<id>` to allocate. SpvWord m_nextID = 1; + OrderedHashSet<IRPtrTypeBase*> m_forwardDeclaredPointers; + + // A hash set to prevent redecorating the same spv inst. + HashSet<SpvId> m_decoratedSpvInsts; + SpvAddressingModel m_addressingMode = SpvAddressingModelLogical; // We will store the logical sections of the SPIR-V module @@ -1244,6 +1296,17 @@ struct SPIRVEmitContext return m_targetRequest->getHLSLToVulkanLayoutOptions()->shouldEmitSPIRVReflectionInfo(); } + void requireVariablePointers() + { + if (m_addressingMode == SpvAddressingModelPhysicalStorageBuffer64) + return; + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_variable_pointers")); + requireSPIRVCapability(SpvCapabilityVariablePointers); + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_physical_storage_buffer")); + requireSPIRVCapability(SpvCapabilityPhysicalStorageBufferAddresses); + m_addressingMode = SpvAddressingModelPhysicalStorageBuffer64; + } + // Next, let's look at emitting some of the instructions // that can occur at global scope. @@ -1312,11 +1375,41 @@ struct SPIRVEmitContext storageClass = (SpvStorageClass)ptrType->getAddressSpace(); if (storageClass == SpvStorageClassStorageBuffer) ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_storage_buffer_storage_class")); - return emitOpTypePointer( + if (storageClass == SpvStorageClassPhysicalStorageBuffer) + { + requireVariablePointers(); + } + auto valueType = ptrType->getValueType(); + // If we haven't emitted the inner type yet, we need to emit a forward declaration. + bool useForwardDeclaration = (!m_mapIRInstToSpvInst.containsKey(valueType) + && as<IRStructType>(valueType) + && storageClass == SpvStorageClassPhysicalStorageBuffer); + auto resultSpvType = emitOpTypePointer( inst, storageClass, - inst->getOperand(0) + useForwardDeclaration? getIRInstSpvID(valueType) : getID(ensureInst(valueType)) ); + if (useForwardDeclaration) + { + // After everything has been emitted, we will move the pointer definition to the end + // of the Types & Constants section. + if (m_forwardDeclaredPointers.add(ptrType)) + emitOpTypeForwardPointer(resultSpvType, storageClass); + } + if (storageClass == SpvStorageClassPhysicalStorageBuffer) + { + if (m_decoratedSpvInsts.add(getID(resultSpvType))) + { + IRSizeAndAlignment sizeAndAlignment; + getNaturalSizeAndAlignment(m_targetRequest, ptrType->getValueType(), &sizeAndAlignment); + emitOpDecorateArrayStride( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + resultSpvType, + SpvLiteralInteger::from32((uint32_t)sizeAndAlignment.getStride())); + } + } + return resultSpvType; } case kIROp_ConstantBufferType: SLANG_UNEXPECTED("Constant buffer type remaining in spirv emit"); @@ -1404,11 +1497,7 @@ struct SPIRVEmitContext return emitOpTypeHitObject(inst); case kIROp_HLSLConstBufferPointerType: - ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_variable_pointers")); - requireSPIRVCapability(SpvCapabilityVariablePointers); - ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_physical_storage_buffer")); - requireSPIRVCapability(SpvCapabilityPhysicalStorageBufferAddresses); - m_addressingMode = SpvAddressingModelPhysicalStorageBuffer64; + requireVariablePointers(); return emitOpTypePointer(inst, SpvStorageClassPhysicalStorageBuffer, inst->getOperand(0)); case kIROp_FuncType: @@ -1446,6 +1535,7 @@ struct SPIRVEmitContext case kIROp_IntLit: case kIROp_FloatLit: case kIROp_StringLit: + case kIROp_PtrLit: { return emitLit(inst); } @@ -1978,6 +2068,7 @@ struct SPIRVEmitContext param->getDataType(), storageClass ); + maybeEmitPointerDecoration(varInst, param); if (auto layout = getVarLayout(param)) emitVarLayout(param, varInst, layout); maybeEmitName(varInst, param); @@ -2001,6 +2092,7 @@ struct SPIRVEmitContext globalVar->getDataType(), storageClass ); + maybeEmitPointerDecoration(varInst, globalVar); if(layout) emitVarLayout(globalVar, varInst, layout); maybeEmitName(varInst, globalVar); @@ -2274,6 +2366,8 @@ struct SPIRVEmitContext return emitFieldExtract(parent, as<IRFieldExtract>(inst)); case kIROp_GetElementPtr: return emitGetElementPtr(parent, as<IRGetElementPtr>(inst)); + case kIROp_GetOffsetPtr: + return emitGetOffsetPtr(parent, inst); case kIROp_GetElement: return emitGetElement(parent, as<IRGetElement>(inst)); case kIROp_MakeStruct: @@ -2306,6 +2400,13 @@ struct SPIRVEmitContext return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst)); case kIROp_CastFloatToInt: return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst)); + case kIROp_CastPtrToInt: + return emitCastPtrToInt(parent, inst); + case kIROp_CastPtrToBool: + return emitCastPtrToBool(parent, inst); + case kIROp_CastIntToPtr: + return emitCastIntToPtr(parent, inst); + case kIROp_PtrCast: case kIROp_BitCast: return emitOpBitcast( parent, @@ -3403,10 +3504,27 @@ struct SPIRVEmitContext return nullptr; } + void maybeEmitPointerDecoration(SpvInst* varInst, IRInst* inst) + { + auto ptrType = as<IRPtrType>(inst->getDataType()); + if (!ptrType) + return; + if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + { + emitOpDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + varInst, + (as<IRVar>(inst) ? SpvDecorationAliasedPointer : SpvDecorationAliased) + ); + } + } + SpvInst* emitParam(SpvInstParent* parent, IRInst* inst) { auto paramSpvInst = emitOpFunctionParameter(parent, inst, inst->getFullType()); maybeEmitName(paramSpvInst, inst); + maybeEmitPointerDecoration(paramSpvInst, inst); return paramSpvInst; } @@ -3421,6 +3539,7 @@ struct SPIRVEmitContext } auto varSpvInst = emitOpVariable(parent, inst, inst->getFullType(), storageClass); maybeEmitName(varSpvInst, inst); + maybeEmitPointerDecoration(varSpvInst, inst); return varSpvInst; } @@ -3962,6 +4081,11 @@ struct SPIRVEmitContext ); } + SpvInst* emitGetOffsetPtr(SpvInstParent* parent, IRInst* inst) + { + return emitOpPtrAccessChain(parent, inst, inst->getDataType(), inst->getOperand(0), inst->getOperand(1)); + } + SpvInst* emitGetElementPtr(SpvInstParent* parent, IRGetElementPtr* inst) { IRBuilder builder(m_irModule); @@ -4025,12 +4149,32 @@ struct SPIRVEmitContext SpvInst* emitLoad(SpvInstParent* parent, IRLoad* inst) { - return emitOpLoad(parent, inst, inst->getDataType(), inst->getPtr()); + auto ptrType = as<IRPtrTypeBase>(inst->getPtr()->getDataType()); + if (ptrType && ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + { + IRSizeAndAlignment sizeAndAlignment; + getNaturalSizeAndAlignment(m_targetRequest, ptrType->getValueType(), &sizeAndAlignment); + return emitOpLoadAligned(parent, inst, inst->getDataType(), inst->getPtr(), SpvLiteralInteger::from32(sizeAndAlignment.alignment)); + } + else + { + return emitOpLoad(parent, inst, inst->getDataType(), inst->getPtr()); + } } SpvInst* emitStore(SpvInstParent* parent, IRStore* inst) { - return emitOpStore(parent, inst, inst->getPtr(), inst->getVal()); + auto ptrType = as<IRPtrTypeBase>(inst->getPtr()->getDataType()); + if (ptrType && ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + { + IRSizeAndAlignment sizeAndAlignment; + getNaturalSizeAndAlignment(m_targetRequest, ptrType->getValueType(), &sizeAndAlignment); + return emitOpStoreAligned(parent, inst, inst->getPtr(), inst->getVal(), SpvLiteralInteger::from32(sizeAndAlignment.alignment)); + } + else + { + return emitOpStore(parent, inst, inst->getPtr(), inst->getVal()); + } } SpvInst* emitSwizzledStore(SpvInstParent* parent, IRSwizzledStore* inst) @@ -4322,6 +4466,23 @@ struct SPIRVEmitContext : emitOpConvertFToU(parent, inst, toTypeV, inst->getOperand(0)); } + SpvInst* emitCastPtrToInt(SpvInstParent* parent, IRInst* inst) + { + return emitInst(parent, inst, SpvOpConvertPtrToU, inst->getFullType(), kResultID, inst->getOperand(0)); + } + + SpvInst* emitCastPtrToBool(SpvInstParent* parent, IRInst* inst) + { + IRBuilder builder(inst); + auto uintVal = emitInst(parent, nullptr, SpvOpConvertPtrToU, builder.getUInt64Type(), kResultID, inst->getOperand(0)); + return emitOpINotEqual(parent, inst, kResultID, uintVal, builder.getIntValue(builder.getUInt64Type(), 0)); + } + + SpvInst* emitCastIntToPtr(SpvInstParent* parent, IRInst* inst) + { + return emitInst(parent, inst, SpvOpConvertUToPtr, inst->getFullType(), kResultID, inst->getOperand(0)); + } + template<typename T, typename Ts> SpvInst* emitCompositeConstruct( SpvInstParent* parent, @@ -5124,6 +5285,16 @@ SlangResult emitSPIRVFromIR( { context.ensureInst(irEntryPoint); } + + // Move forward delcared pointers to the end. + for (auto ptrType : context.m_forwardDeclaredPointers) + { + auto spvPtrType = context.m_mapIRInstToSpvInst[ptrType]; + auto parent = spvPtrType->parent; + spvPtrType->removeFromParent(); + parent->addInst(spvPtrType); + } + context.emitFrontMatter(); context.emitPhysicalLayout(); diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index 63ca32650..7a93a312c 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -112,6 +112,7 @@ bool opCanBeConstExpr(IROp op) case kIROp_CastIntToPtr: case kIROp_CastPtrToInt: case kIROp_CastPtrToBool: + case kIROp_PtrCast: case kIROp_Reinterpret: case kIROp_BitCast: case kIROp_MakeTuple: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index e183058ac..0c962b7a4 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -386,6 +386,8 @@ INST(FieldAddress, get_field_addr, 2, 0) INST(GetElement, getElement, 2, 0) INST(GetElementPtr, getElementPtr, 2, 0) +// Pointer offset: computes pBase + offset_in_elements +INST(GetOffsetPtr, getOffsetPtr, 2, 0) INST(GetAddr, getAddr, 1, 0) // Get an unowned NativeString from a String. @@ -1011,6 +1013,7 @@ INST(CastPtrToBool, CastPtrToBool, 1, 0) INST(CastPtrToInt, CastPtrToInt, 1, 0) INST(CastIntToPtr, CastIntToPtr, 1, 0) INST(CastToVoid, castToVoid, 1, 0) +INST(PtrCast, PtrCast, 1, 0) INST(SizeOf, sizeOf, 1, 0) INST(AlignOf, alignOf, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 6e3821c18..82d891459 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2216,7 +2216,6 @@ struct IRFieldAddress : IRInst IRInst* getBase() { return base.get(); } IRInst* getField() { return field.get(); } IR_LEAF_ISA(FieldAddress) - }; struct IRGetElement : IRInst @@ -4065,6 +4064,8 @@ public: IRInst* sizedType); IRInst* emitCastPtrToBool(IRInst* val); + IRInst* emitCastPtrToInt(IRInst* val); + IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val); IRGlobalConstant* emitGlobalConstant( IRType* type); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 36769cc34..eb0068657 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -276,7 +276,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) { IRConstant* c = (IRConstant*)originalValue; SLANG_RELEASE_ASSERT(c->value.ptrVal == nullptr); - return builder->getNullVoidPtrValue(); + return builder->getNullPtrValue(cloneType(this, c->getFullType())); } break; diff --git a/source/slang/slang-ir-liveness.cpp b/source/slang/slang-ir-liveness.cpp index 9cd4462af..28cd64a08 100644 --- a/source/slang/slang-ir-liveness.cpp +++ b/source/slang/slang-ir-liveness.cpp @@ -1029,6 +1029,7 @@ bool LivenessContext::_isAccessTerminator(IRTerminatorInst* terminator) case kIROp_CastIntToPtr: case kIROp_CastPtrToInt: case kIROp_CastPtrToBool: + case kIROp_PtrCast: val = val->getOperand(0); break; } diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index fb67c6842..39a137490 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -581,7 +581,7 @@ struct PeepholeContext : InstPassBase auto ptr = inst->getOperand(0); IRBuilder builder(module); builder.setInsertBefore(inst); - auto neq = builder.emitNeq(ptr, builder.getNullVoidPtrValue()); + auto neq = builder.emitNeq(ptr, builder.getNullPtrValue(ptr->getDataType())); inst->replaceUsesWith(neq); maybeRemoveOldInst(inst); changed = true; diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index b82daa9a4..60001661c 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1995,11 +1995,6 @@ struct SpecializationContext { return 2; } - else if (auto ptrType = as<IRPtrTypeBase>(type)) - { - type = ptrType->getValueType(); - goto top; - } else if (auto ptrLikeType = as<IRPointerLikeType>(type)) { type = ptrLikeType->getElementType(); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 1675fe279..474ebc71c 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -741,6 +741,66 @@ struct SPIRVLegalizationContext : public SourceEmitterBase return result; } + void processVar(IRInst* inst) + { + auto oldPtrType = as<IRPtrType>(inst->getDataType()); + if (!oldPtrType->hasAddressSpace()) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newPtrType = builder.getPtrType( + oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassFunction); + inst->setFullType(newPtrType); + addUsersToWorkList(inst); + } + } + + void processParam(IRInst* inst) + { + auto block = getBlock(inst); + auto func = getParentFunc(block); + if (!block || !func) + return; + auto oldPtrType = as<IRPtrType>(inst->getDataType()); + if (!oldPtrType) + return; + if (!oldPtrType->hasAddressSpace()) + { + SpvStorageClass addressSpace = (SpvStorageClass)-1; + + if (block == func->getFirstBlock()) + { + // A pointer typed function parameter should always be in the storage buffer address space. + addressSpace = SpvStorageClassPhysicalStorageBuffer; + } + else + { + // The address space of a phi inst should always be the same as arguments. + auto args = getPhiArgs(inst); + for (auto arg : args) + { + auto argPtrType = as<IRPtrType>(arg->getDataType()); + if (argPtrType->hasAddressSpace()) + { + if (addressSpace == (SpvStorageClass)-1) + addressSpace = (SpvStorageClass)argPtrType->getAddressSpace(); + else if (addressSpace != argPtrType->getAddressSpace()) + m_sharedContext->m_sink->diagnose(inst, Diagnostics::inconsistentPointerAddressSpace, inst); + } + } + } + if (addressSpace != (SpvStorageClass)-1) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newPtrType = builder.getPtrType( + oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassPhysicalStorageBuffer); + inst->setFullType(newPtrType); + addUsersToWorkList(inst); + } + } + } + void processGlobalVar(IRInst* inst) { auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType()); @@ -844,6 +904,16 @@ struct SPIRVLegalizationContext : public SourceEmitterBase for (UInt i = 0; i < inst->getArgCount(); i++) { auto arg = inst->getArg(i); + auto paramType = funcType->getParamType(i); + if (as<IRPtrType>(paramType)) + { + // If the parameter has an explicit pointer type, + // then we know the user is using the variable pointer + // capability to pass a true pointer. + // In this case we should not rewrite the call. + newArgs.add(arg); + continue; + } auto ptrType = as<IRPtrTypeBase>(arg->getDataType()); if (!as<IRPtrTypeBase>(arg->getDataType())) { @@ -898,7 +968,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase SLANG_ASSERT((UInt)newArgs.getCount() == inst->getArgCount()); if (writeBacks.getCount()) { - auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs); + auto newCall = builder.emitCallInst( + translateToStorageBufferPointer(inst->getFullType()), + inst->getCallee(), + newArgs); for (auto wb : writeBacks) { auto newVal = builder.emitLoad(wb.tempVar); @@ -908,6 +981,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase inst->removeAndDeallocate(); addUsersToWorkList(newCall); } + else + { + translatePtrResultType(inst); + } } Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar; @@ -989,6 +1066,28 @@ struct SPIRVLegalizationContext : public SourceEmitterBase processGetElementPtrImpl(gepInst, gepInst->getBase(), gepInst->getIndex()); } + void processGetOffsetPtr(IRInst* offsetPtrInst) + { + auto ptrOperandType = as<IRPtrType>(offsetPtrInst->getOperand(0)->getDataType()); + if (!ptrOperandType) + return; + if (!ptrOperandType->hasAddressSpace()) + return; + auto resultPtrType = as<IRPtrType>(offsetPtrInst->getDataType()); + if (!resultPtrType) + return; + if (resultPtrType->getAddressSpace() != ptrOperandType->getAddressSpace()) + { + IRBuilder builder(offsetPtrInst); + builder.setInsertBefore(offsetPtrInst); + auto newResultType = builder.getPtrType(resultPtrType->getOp(), + resultPtrType->getValueType(), + ptrOperandType->getAddressSpace()); + auto newInst = builder.replaceOperand(&offsetPtrInst->typeUse, newResultType); + addUsersToWorkList(newInst); + } + } + void processStructuredBufferLoad(IRInst* loadInst) { auto sb = loadInst->getOperand(0); @@ -1060,13 +1159,16 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (!ptrType->hasAddressSpace()) return; auto oldResultType = as<IRPtrTypeBase>(inst->getDataType()); - if (oldResultType->getAddressSpace() != ptrType->getAddressSpace()) + auto oldValueType = oldResultType->getValueType(); + auto newValueType = translateToStorageBufferPointer(oldValueType); + + if (oldValueType != newValueType || oldResultType->getAddressSpace() != ptrType->getAddressSpace()) { IRBuilder builder(m_sharedContext->m_irModule); builder.setInsertBefore(inst); auto newPtrType = builder.getPtrType( oldResultType->getOp(), - oldResultType->getValueType(), + newValueType, ptrType->getAddressSpace()); auto newInst = builder.emitFieldAddress(newPtrType, inst->getBase(), inst->getField()); @@ -1077,6 +1179,19 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } + void processFieldExtract(IRFieldExtract* inst) + { + auto ptrType = as<IRPtrType>(inst->getDataType()); + if (!ptrType) + return; + auto newPtrType = translateToStorageBufferPointer(ptrType); + if (newPtrType == ptrType) + return; + IRBuilder builder(inst); + auto newInst = builder.replaceOperand(&inst->typeUse, newPtrType); + addUsersToWorkList(newInst); + } + void duplicateMergeBlockIfNeeded(IRUse* breakBlockUse) { auto breakBlock = as<IRBlock>(breakBlockUse->get()); @@ -1106,7 +1221,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void processLoop(IRLoop* loop) { - // 2.11.1. Rules for Structured Control-flow Declarations // Structured control flow declarations must satisfy the following // rules: @@ -1186,6 +1300,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // Insert a new continue block at the end of the loop const auto newContinueBlock = builder.emitBlock(); + addToWorkList(newContinueBlock); + newContinueBlock->insertBefore(loop->getBreakBlock()); // This block simply branches to the loop header, forwarding @@ -1204,10 +1320,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase loop->block.set(t); // Branch to the target in our new continue block - builder.emitBranch(t, ps.getCount(), ps.getBuffer()); + auto branch = builder.emitBranch(t, ps.getCount(), ps.getBuffer()); + addToWorkList(branch); } } duplicateMergeBlockIfNeeded(&loop->breakBlock); + addToWorkList(loop->getTargetBlock()); } void processIfElse(IRIfElse* inst) @@ -1223,6 +1341,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newBlock = builder.emitBlock(); builder.emitBranch(inst->getAfterBlock()); inst->trueBlock.set(newBlock); + addToWorkList(newBlock); } if (inst->getFalseBlock() == inst->getAfterBlock()) { @@ -1230,6 +1349,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newBlock = builder.emitBlock(); builder.emitBranch(inst->getAfterBlock()); inst->falseBlock.set(newBlock); + addToWorkList(newBlock); } } @@ -1246,6 +1366,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newBlock = builder.emitBlock(); builder.emitBranch(inst->getBreakLabel()); inst->defaultLabel.set(newBlock); + addToWorkList(newBlock); } for (UInt i = 0; i < inst->getCaseCount(); i++) { @@ -1255,6 +1376,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newBlock = builder.emitBlock(); builder.emitBranch(inst->getBreakLabel()); inst->getCaseLabelUse(i)->set(newBlock); + addToWorkList(newBlock); } } } @@ -1386,6 +1508,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_FieldAddress: case kIROp_GetElement: case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: case kIROp_UpdateElement: case kIROp_MakeTuple: case kIROp_GetTupleElement: @@ -1407,6 +1530,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_CastFloatToInt: case kIROp_CastIntToFloat: case kIROp_CastIntToPtr: + case kIROp_PtrCast: case kIROp_CastPtrToBool: case kIROp_CastPtrToInt: case kIROp_BitAnd: @@ -1467,7 +1591,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase cloneEnv.mapOldValToNew[inst] = result; return result; } - + // If the global value is inlinable, we make all its operands avaialble locally, and then copy it // to the local scope. ShortList<IRInst*> args; @@ -1482,21 +1606,123 @@ struct SPIRVLegalizationContext : public SourceEmitterBase return result; } - void processWorkList() + void processBranch(IRInst* branch) + { + addToWorkList(branch->getOperand(0)); + } + + IRType* translateToStorageBufferPointer(IRType* pointerType) + { + auto ptrType = as<IRPtrType>(pointerType); + if (!ptrType) + return pointerType; + auto oldValueType = ptrType->getValueType(); + auto newValueType = translateToStorageBufferPointer(oldValueType); + if (oldValueType != newValueType || !ptrType->hasAddressSpace()) + { + IRBuilder builder(m_module); + return builder.getPtrType(ptrType->getOp(), newValueType, SpvStorageClassPhysicalStorageBuffer); + } + return ptrType; + } + + void translatePtrResultType(IRInst* inst) + { + auto ptrType = as<IRPtrType>(inst->getDataType()); + auto newPtrType = translateToStorageBufferPointer(ptrType); + if (newPtrType == ptrType) + return; + IRBuilder builder(inst); + auto newInst = builder.replaceOperand(&inst->typeUse, newPtrType); + addUsersToWorkList(newInst); + } + + void processPtrLit(IRInst* inst) { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newPtrType = translateToStorageBufferPointer(as<IRPtrType>(inst->getFullType())); + auto newInst = builder.emitCastIntToPtr(newPtrType, builder.getIntValue(builder.getUInt64Type(), 0)); + inst->replaceUsesWith(newInst); + addUsersToWorkList(newInst); + } + void processPtrCast(IRInst* cast) + { + translatePtrResultType(cast); + } + + void processLoad(IRInst* inst) + { + translatePtrResultType(inst); + } + + void processStructField(IRStructField* field) + { + auto ptrType = as<IRPtrTypeBase>(field->getFieldType()); + if (!ptrType) + return; + if (ptrType->hasAddressSpace()) + return; + IRBuilder builder(field); + auto newPtrType = builder.getPtrType( + ptrType->getOp(), + ptrType->getValueType(), + SpvStorageClassPhysicalStorageBuffer); + field->setFieldType(newPtrType); + } + + void processComparison(IRInst* inst) + { + auto operand0 = inst->getOperand(0); + if (as<IRPtrType>(operand0->getDataType())) + { + // If we are doing pointer comparison, convert the operands into uints first. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto castToUInt = [&](IRInst* operand) + { + if (as<IRPtrLit>(operand)) + return builder.getIntValue(builder.getUInt64Type(), 0); + else + return builder.emitCastPtrToInt(operand); + }; + auto newOperand0 = castToUInt(operand0); + SLANG_ASSERT(as<IRPtrType>(inst->getOperand(1)->getDataType())); + auto newOperand1 = castToUInt(inst->getOperand(1)); + inst = builder.replaceOperand(inst->getOperands(), newOperand0); + inst = builder.replaceOperand(inst->getOperands() + 1, newOperand1); + } + } + + void processWorkList() + { while (workList.getCount() != 0) { IRInst* inst = workList.getLast(); workList.removeLast(); + + // Skip if inst has already been removed. + if (!inst->parent) + continue; + switch (inst->getOp()) { + case kIROp_StructField: + processStructField(as<IRStructField>(inst)); + break; case kIROp_GlobalParam: processGlobalParam(as<IRGlobalParam>(inst)); break; case kIROp_GlobalVar: processGlobalVar(as<IRGlobalVar>(inst)); break; + case kIROp_Var: + processVar(as<IRVar>(inst)); + break; + case kIROp_Param: + processParam(as<IRParam>(inst)); + break; case kIROp_Call: processCall(as<IRCall>(inst)); break; @@ -1506,9 +1732,15 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_GetElementPtr: processGetElementPtr(as<IRGetElementPtr>(inst)); break; + case kIROp_GetOffsetPtr: + processGetOffsetPtr(inst); + break; case kIROp_FieldAddress: processFieldAddress(as<IRFieldAddress>(inst)); break; + case kIROp_FieldExtract: + processFieldExtract(as<IRFieldExtract>(inst)); + break; case kIROp_ImageSubscript: processImageSubscript(as<IRImageSubscript>(inst)); break; @@ -1533,7 +1765,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_Switch: processSwitch(as<IRSwitch>(inst)); break; - + case kIROp_Less: + case kIROp_Leq: + case kIROp_Eql: + case kIROp_Geq: + case kIROp_Greater: + case kIROp_Neq: + processComparison(inst); + break; case kIROp_MakeVectorFromScalar: case kIROp_MakeUInt64: case kIROp_MakeVector: @@ -1551,6 +1790,20 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_MakeOptionalNone: processConstructor(inst); break; + case kIROp_BitCast: + case kIROp_PtrCast: + case kIROp_CastIntToPtr: + processPtrCast(inst); + break; + case kIROp_PtrLit: + processPtrLit(inst); + break; + case kIROp_Load: + processLoad(inst); + break; + case kIROp_unconditionalBranch: + processBranch(inst); + break; case kIROp_SPIRVAsm: processSPIRVAsm(as<IRSPIRVAsm>(inst)); break; @@ -1584,7 +1837,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void processModule() { - convertCompositeTypeParametersToPointers(m_module); + //convertCompositeTypeParametersToPointers(m_module); // Process global params before anything else, so we don't generate inefficient // array marhalling code for array-typed global params. @@ -1631,6 +1884,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase t->replaceUsesWith(lowered); } + // Inline global values that can't represented by SPIRV constant inst + // to their use sites. List<IRUse*> globalInstUsesToInline; for (auto globalInst : m_module->getGlobalInsts()) @@ -1666,6 +1921,63 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (val != use->get()) builder.replaceOperand(use, val); } + + // Some legalization processing may change the function parameter types, + // so we need to update the function types to match that. + updateFunctionTypes(); + } + + void updateFunctionTypes() + { + IRBuilder builder(m_module); + for (auto globalInst : m_module->getGlobalInsts()) + { + auto func = as<IRFunc>(globalInst); + if (!func) + continue; + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + continue; + + builder.setInsertBefore(func); + auto type = func->getDataType(); + auto oldFuncType = as<IRFuncType>(type); + auto resultType = oldFuncType->getResultType(); + List<IRType*> newOperands; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto retInst = as<IRReturn>(inst)) + { + resultType = retInst->getVal()->getFullType(); + break; + } + } + } + for (auto param : firstBlock->getParams()) + { + newOperands.add(param->getDataType()); + } + bool changed = resultType != oldFuncType->getResultType(); + if (!changed) + { + for (UInt i = 0; i < oldFuncType->getParamCount(); i++) + { + if (oldFuncType->getParamType(i) != newOperands[i]) + { + changed = true; + break; + } + } + } + if (changed) + { + builder.setInsertBefore(func); + auto newFuncType = builder.getFuncType(newOperands, resultType); + func->setFullType(newFuncType); + } + } } }; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index d859f86a6..f514fea1d 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1030,6 +1030,26 @@ IRInst* getInstInBlock(IRInst* inst) return getInstInBlock(inst->getParent()); } +ShortList<IRInst*> getPhiArgs(IRInst* phiParam) +{ + ShortList<IRInst*> result; + auto block = cast<IRBlock>(phiParam->getParent()); + UInt paramIndex = 0; + for (auto p = block->getFirstParam(); p; p = p->getNextParam()) + { + if (p == phiParam) + break; + paramIndex++; + } + for (auto predBlock : block->getPredecessors()) + { + auto termInst = as<IRUnconditionalBranch>(predBlock->getTerminator()); + SLANG_ASSERT(paramIndex < termInst->getArgCount()); + result.add(termInst->getArg(paramIndex)); + } + return result; +} + void removePhiArgs(IRInst* phiParam) { auto block = cast<IRBlock>(phiParam->getParent()); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index c76898aa2..c290f9392 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -221,6 +221,8 @@ IRInst* getInstInBlock(IRInst* inst); void removePhiArgs(IRInst* phiParam); +ShortList<IRInst*> getPhiArgs(IRInst* phiParam); + int getParamIndexInBlock(IRParam* paramInst); bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* inst); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 94de28089..035b2aade 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -17,7 +17,14 @@ namespace Slang SourceLoc const& getDiagnosticPos(IRInst* inst) { - return inst->sourceLoc; + while (inst) + { + if (inst->sourceLoc.isValid()) + return inst->sourceLoc; + inst = inst->parent; + } + static SourceLoc invalid = SourceLoc(); + return invalid; } void printDiagnosticArg(StringBuilder& sb, IRInst* irObject) @@ -4900,7 +4907,7 @@ namespace Slang IRType* type = nullptr; auto basePtrType = as<IRPtrTypeBase>(basePtr->getDataType()); auto valueType = unwrapAttributedType(basePtrType->getValueType()); - if (auto arrayType = as<IRArrayType>(valueType)) + if (auto arrayType = as<IRArrayTypeBase>(valueType)) { type = arrayType->getElementType(); } @@ -5507,6 +5514,28 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitCastPtrToInt(IRInst* val) + { + auto inst = createInst<IRInst>( + this, + kIROp_CastPtrToInt, + getUInt64Type(), + val); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitCastIntToPtr(IRType* ptrType, IRInst* val) + { + auto inst = createInst<IRInst>( + this, + kIROp_CastIntToPtr, + ptrType, + val); + addInst(inst); + return inst; + } + IRGlobalConstant* IRBuilder::emitGlobalConstant( IRType* type) { @@ -7873,6 +7902,7 @@ namespace Slang case kIROp_FieldAddress: case kIROp_GetElement: case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: case kIROp_UpdateElement: case kIROp_MeshOutputRef: case kIROp_MakeVectorFromScalar: @@ -7910,6 +7940,7 @@ namespace Slang case kIROp_FloatCast: case kIROp_CastPtrToInt: case kIROp_CastIntToPtr: + case kIROp_PtrCast: case kIROp_AllocObj: case kIROp_PackAnyValue: case kIROp_UnpackAnyValue: @@ -8319,6 +8350,7 @@ namespace Slang case kIROp_FieldAddress: case kIROp_GetElement: case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: case kIROp_UpdateElement: case kIROp_Specialize: case kIROp_LookupWitness: @@ -8347,6 +8379,7 @@ namespace Slang case kIROp_CastIntToPtr: case kIROp_CastPtrToBool: case kIROp_CastPtrToInt: + case kIROp_PtrCast: case kIROp_BitAnd: case kIROp_BitNot: case kIROp_BitOr: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 0766bb168..bbb9dfeeb 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1813,6 +1813,10 @@ struct IRStructField : IRInst // return (IRType*) getOperand(1); } + void setFieldType(IRType* type) + { + setOperand(1, type); + } IR_LEAF_ISA(StructField) }; diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 3da4f8554..13db8af18 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -358,6 +358,11 @@ public: return dispatchIfNotNull(expr->baseExpression); } + bool visitOpenRefExpr(OpenRefExpr* expr) + { + return dispatchIfNotNull(expr->innerExpr); + } + bool visitInitializerListExpr(InitializerListExpr* expr) { for (auto arg : expr->args) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4e8c9b340..f5d743bb1 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4264,6 +4264,10 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple( getBuilder()->emitMakeArrayFromElement(irType, irDefaultElement)); } + else if (auto ptrType = as<PtrType>(type)) + { + return LoweredValInfo::simple(getBuilder()->getNullPtrValue(irType)); + } else if (auto declRefType = as<DeclRefType>(type)) { DeclRef<Decl> declRef = declRefType->getDeclRef(); diff --git a/tests/spirv/pointer.slang b/tests/spirv/pointer.slang new file mode 100644 index 000000000..cb2d56f66 --- /dev/null +++ b/tests/spirv/pointer.slang @@ -0,0 +1,48 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry main -stage compute -emit-spirv-directly + + +struct PP +{ + int data; + int data2; +} +struct Data +{ + int data; + PP* pNext; +}; + +void funcThatTakesPointer(PP* p) +{ + p.data = 2; +} +int* funcThatReturnsPointer(PP* p) +{ + return &p.data; +} + +// CHECK: OpEntryPoint + +StructuredBuffer<Data> buffer; +RWStructuredBuffer<int> output; +void main(int id : SV_DispatchThreadID) +{ + output[0] = buffer[0].pNext.data; + let pData = &(buffer[0].pNext.data); + // CHECK: OpPtrAccessChain + int* pData1 = pData + 1; + *pData1 = 3; + *(int2*)pData = int2(1, 2); + pData1[-1] = 2; + buffer[0].pNext[1] = {5}; + // CHECK: OpConvertPtrToU + // CHECK: OpINotEqual + if (pData1) + { + *(funcThatReturnsPointer(buffer[0].pNext)) = 4; + } + if (pData1 > pData) + { + funcThatTakesPointer(buffer[0].pNext); + } +} |
