summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-02-08 18:29:32 -0800
committerGitHub <noreply@github.com>2024-02-08 18:29:32 -0800
commitf44da6cc5c0f211c13bd1eb0743d79c7861ea64e (patch)
tree3ad4edb5e7806c41003280ebf60fd6419a742105
parenta16f712bb99e426519c9a556b17b54bcc4d1d22d (diff)
Support pointers in SPIRV. (#3561)
* Support pointers in SPIRV. * Fix test. * Enhance test. * Fix test. * Cleanup.
-rw-r--r--source/slang/core.meta.slang24
-rw-r--r--source/slang/hlsl.meta.slang12
-rw-r--r--source/slang/slang-ast-iterator.h4
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-ast-type.h3
-rw-r--r--source/slang/slang-check-expr.cpp5
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-c-like.cpp13
-rw-r--r--source/slang/slang-emit-spirv-ops.h31
-rw-r--r--source/slang/slang-emit-spirv.cpp201
-rw-r--r--source/slang/slang-ir-constexpr.cpp1
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h3
-rw-r--r--source/slang/slang-ir-link.cpp2
-rw-r--r--source/slang/slang-ir-liveness.cpp1
-rw-r--r--source/slang/slang-ir-peephole.cpp2
-rw-r--r--source/slang/slang-ir-specialize.cpp5
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp330
-rw-r--r--source/slang/slang-ir-util.cpp20
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-ir.cpp37
-rw-r--r--source/slang/slang-ir.h4
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp5
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
-rw-r--r--tests/spirv/pointer.slang48
25 files changed, 705 insertions, 58 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index f3ab38582..1e1ef061e 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -719,19 +719,7 @@ struct Ptr
__subscript(int index) -> T
{
- [__unsafeForceInlineEarly]
- get
- {
- return __load(__getElementPtr(this, index));
- }
-
- [__unsafeForceInlineEarly]
- set(T newValue)
- {
- __store(__getElementPtr(this, index), newValue);
- }
-
- __intrinsic_op($(kIROp_GetElementPtr))
+ __intrinsic_op($(kIROp_GetOffsetPtr))
ref;
}
};
@@ -748,6 +736,12 @@ Ptr<T> __getElementPtr<T>(Ptr<T> ptr, int index);
__intrinsic_op($(kIROp_GetElementPtr))
Ptr<T> __getElementPtr<T>(Ptr<T> ptr, int64_t index);
+__intrinsic_op($(kIROp_GetOffsetPtr))
+Ptr<T> __getOffsetPtr<T>(Ptr<T> ptr, int index);
+
+__intrinsic_op($(kIROp_GetOffsetPtr))
+Ptr<T> __getOffsetPtr<T>(Ptr<T> ptr, int64_t index);
+
__generic<T>
__intrinsic_op($(kIROp_Less))
bool operator<(Ptr<T> p1, Ptr<T> p2);
@@ -1543,14 +1537,14 @@ __intrinsic_op(0)
__prefix Ptr<T> operator&(__ref T value);
__generic<T>
-__intrinsic_op($(kIROp_GetElementPtr))
+__intrinsic_op($(kIROp_GetOffsetPtr))
Ptr<T> operator+(Ptr<T> value, int64_t offset);
__generic<T>
[__unsafeForceInlineEarly]
Ptr<T> operator-(Ptr<T> value, int64_t offset)
{
- return __getElementPtr(value, -offset);
+ return __getOffsetPtr(value, -offset);
}
__generic<T : IArithmetic>
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 8183c2030..156ecc194 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -13164,12 +13164,6 @@ struct ConstBufferPointer
}
}
- __subscript(int index) -> T
- {
- [ForceInline]
- get {return ConstBufferPointer<T>.fromUInt(toUInt() + __naturalStrideOf<T>() * index).get(); }
- }
-
__glsl_version(450)
__glsl_extension(GL_EXT_shader_explicit_arithmetic_types_int64)
__glsl_extension(GL_EXT_buffer_reference)
@@ -13221,4 +13215,10 @@ struct ConstBufferPointer
};
}
}
+
+ __subscript(int index)->T
+ {
+ [ForceInline]
+ get { return ConstBufferPointer<T>.fromUInt(toUInt() + __naturalStrideOf<T>() * index).get(); }
+ }
}
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index e2d0638e0..fc6f321e3 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -52,6 +52,10 @@ struct ASTIterator
{
iterator->maybeDispatchCallback(expr);
}
+ void visitOpenRefExpr(OpenRefExpr* expr)
+ {
+ dispatchIfNotNull(expr->innerExpr);
+ }
void visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr)
{
iterator->maybeDispatchCallback(expr);
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 882e26078..c1984910c 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -109,6 +109,7 @@ namespace Slang
kConversionCost_InRangeIntLitSignedToUnsignedConversion = 32,
kConversionCost_InRangeIntLitUnsignedToSignedConversion = 81,
+ kConversionCost_MutablePtrToConstPtr = 20,
// Conversions based on explicit sub-typing relationships are the cheapest
//
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 1d2ebc566..d47e3a496 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -535,7 +535,8 @@ class PtrType : public PtrTypeBase
SLANG_AST_CLASS(PtrType)
};
-// A GPU pointer type that for general readonly memory access.
+// A GPU pointer type into global memory.
+
class ConstBufferPointerType : public PtrTypeBase
{
SLANG_AST_CLASS(ConstBufferPointerType)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index f9adcc91a..2f4906826 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -479,7 +479,10 @@ namespace Slang
derefExpr->base = base;
derefExpr->type = QualType(elementType);
- derefExpr->type.isLeftValue = base->type.isLeftValue;
+ if (as<PtrType>(base->type))
+ derefExpr->type.isLeftValue = true;
+ else
+ derefExpr->type.isLeftValue = base->type.isLeftValue;
return derefExpr;
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 62bc73c90..9b599ae2e 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -782,6 +782,8 @@ DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
DIAGNOSTIC(58001, Error, entryPointMustReturnVoidWhenGlobalOutputPresent, "entry point must return 'void' when global output variables are present.")
DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage Buffer Object contents, unsized arrays as a final parameter must be the only parameter")
+DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.")
+
//
// 8xxxx - Issues specific to a particular library/technology/platform/etc.
//
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 05c525965..1c01478ed 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -2055,6 +2055,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
case kIROp_MatrixReshape:
case kIROp_CastPtrToInt:
case kIROp_CastIntToPtr:
+ case kIROp_PtrCast:
{
// Simple constructor call
auto prec = getInfo(EmitOp::Prefix);
@@ -2345,6 +2346,15 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
m_writer->emit(".detach()");
break;
}
+ case kIROp_GetOffsetPtr:
+ {
+ auto prec = getInfo(EmitOp::Add);
+ needClose = maybeEmitParens(outerPrec, prec);
+ emitOperand(inst->getOperand(0), leftSide(outerPrec, prec));
+ m_writer->emit(" + ");
+ emitOperand(inst->getOperand(1), rightSide(prec, outerPrec));
+ break;
+ }
case kIROp_GetElement:
case kIROp_GetElementPtr:
case kIROp_ImageSubscript:
@@ -4097,7 +4107,8 @@ void CLikeSourceEmitter::ensureGlobalInst(ComputeEmitActionsContext* ctx, IRInst
}
if (as<IRBasicType>(inst))
return;
-
+ if (as<IRPtrLit>(inst))
+ return;
// Certain inst ops will always emit as definition.
switch (inst->getOp())
{
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 891372fa6..32f47b3ef 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -362,6 +362,20 @@ SpvInst* emitOpTypeStruct(IRInst* inst, const Ts& member0TypeMember1TypeEtc)
);
}
+// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeForwardPointer
+template<typename T>
+SpvInst* emitOpTypeForwardPointer(const T& type, SpvStorageClass storageClass)
+{
+ static_assert(isSingular<T>);
+ return emitInst(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ nullptr,
+ SpvOpTypeForwardPointer,
+ type,
+ storageClass
+ );
+}
+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypePointer
template<typename T>
SpvInst* emitOpTypePointer(IRInst* inst, SpvStorageClass storageClass, const T& type)
@@ -623,6 +637,23 @@ SpvInst* emitOpAccessChain(
return emitInst(parent, inst, SpvOpAccessChain, idResultType, kResultID, base, indexes);
}
+
+// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpPtrAccessChain
+template<typename T1, typename T2, typename T3>
+SpvInst* emitOpPtrAccessChain(
+ SpvInstParent* parent,
+ IRInst* inst,
+ const T1& idResultType,
+ const T2& base,
+ const T3& element
+)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ static_assert(isSingular<T3>);
+ return emitInst(parent, inst, SpvOpPtrAccessChain, idResultType, kResultID, base, element);
+}
+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
template<typename T>
SpvInst* emitOpDecorate(
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 6e1e58755..d5d00e417 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -135,7 +135,6 @@ public:
/// Dump all children, recursively, to a flattened list of SPIR-V words
void dumpTo(List<SpvWord>& ioWords);
-private:
/// The first child, if any.
SpvInst* m_firstChild = nullptr;
@@ -145,7 +144,7 @@ private:
/// while if it is non-empty it points to the `nextSibling` field
/// of the last instruction.
///
- SpvInst** m_link = &m_firstChild;
+ SpvInst* m_lastChild = nullptr;
};
// A SPIR-V instruction is then (in the general case) a potential
@@ -198,9 +197,13 @@ struct SpvInst : SpvInstParent
// We will store the instructions in a given `SpvInstParent`
// using an intrusive linked list.
+ SpvInstParent* parent = nullptr;
+
/// The next instruction in the same `SpvInstParent`
SpvInst* nextSibling = nullptr;
+ SpvInst* prevSibling = nullptr;
+
/// The result <id> produced by this instruction, or zero if it has no result.
SpvWord id = 0;
@@ -235,6 +238,43 @@ struct SpvInst : SpvInstParent
//
SpvInstParent::dumpTo(ioWords);
}
+
+ void removeFromParent()
+ {
+ auto oldParent = parent;
+
+ // If we don't currently have a parent, then
+ // we are doing fine.
+ if (!oldParent)
+ return;
+
+ auto pp = prevSibling;
+ auto nn = nextSibling;
+
+ if (pp)
+ {
+ SLANG_ASSERT(pp->parent == oldParent);
+ pp->nextSibling = nn;
+ }
+ else
+ {
+ oldParent->m_firstChild = nn;
+ }
+
+ if (nn)
+ {
+ SLANG_ASSERT(nn->parent == oldParent);
+ nn->prevSibling = pp;
+ }
+ else
+ {
+ oldParent->m_lastChild = pp;
+ }
+
+ prevSibling = nullptr;
+ nextSibling = nullptr;
+ parent = nullptr;
+ }
};
/// A logical section of a SPIR-V module
@@ -248,15 +288,22 @@ struct SpvLogicalSection : SpvInstParent
void SpvInstParent::addInst(SpvInst* inst)
{
SLANG_ASSERT(inst);
+ SLANG_ASSERT(!inst->nextSibling);
+
+ if (m_firstChild == nullptr)
+ {
+ m_firstChild = m_lastChild = inst;
+ return;
+ }
// The user shouldn't be trying to add multiple instructions at once.
// If they really want that then they probably wanted to give `inst`
// some children.
//
- SLANG_ASSERT(!inst->nextSibling);
-
- *m_link = inst;
- m_link = &inst->nextSibling;
+ m_lastChild->nextSibling = inst;
+ inst->prevSibling = m_lastChild;
+ inst->parent = this;
+ m_lastChild = inst;
}
void SpvInstParent::dumpTo(List<SpvWord>& ioWords)
@@ -429,6 +476,11 @@ struct SPIRVEmitContext
/// The next destination `<id>` to allocate.
SpvWord m_nextID = 1;
+ OrderedHashSet<IRPtrTypeBase*> m_forwardDeclaredPointers;
+
+ // A hash set to prevent redecorating the same spv inst.
+ HashSet<SpvId> m_decoratedSpvInsts;
+
SpvAddressingModel m_addressingMode = SpvAddressingModelLogical;
// We will store the logical sections of the SPIR-V module
@@ -1244,6 +1296,17 @@ struct SPIRVEmitContext
return m_targetRequest->getHLSLToVulkanLayoutOptions()->shouldEmitSPIRVReflectionInfo();
}
+ void requireVariablePointers()
+ {
+ if (m_addressingMode == SpvAddressingModelPhysicalStorageBuffer64)
+ return;
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_variable_pointers"));
+ requireSPIRVCapability(SpvCapabilityVariablePointers);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_physical_storage_buffer"));
+ requireSPIRVCapability(SpvCapabilityPhysicalStorageBufferAddresses);
+ m_addressingMode = SpvAddressingModelPhysicalStorageBuffer64;
+ }
+
// Next, let's look at emitting some of the instructions
// that can occur at global scope.
@@ -1312,11 +1375,41 @@ struct SPIRVEmitContext
storageClass = (SpvStorageClass)ptrType->getAddressSpace();
if (storageClass == SpvStorageClassStorageBuffer)
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_storage_buffer_storage_class"));
- return emitOpTypePointer(
+ if (storageClass == SpvStorageClassPhysicalStorageBuffer)
+ {
+ requireVariablePointers();
+ }
+ auto valueType = ptrType->getValueType();
+ // If we haven't emitted the inner type yet, we need to emit a forward declaration.
+ bool useForwardDeclaration = (!m_mapIRInstToSpvInst.containsKey(valueType)
+ && as<IRStructType>(valueType)
+ && storageClass == SpvStorageClassPhysicalStorageBuffer);
+ auto resultSpvType = emitOpTypePointer(
inst,
storageClass,
- inst->getOperand(0)
+ useForwardDeclaration? getIRInstSpvID(valueType) : getID(ensureInst(valueType))
);
+ if (useForwardDeclaration)
+ {
+ // After everything has been emitted, we will move the pointer definition to the end
+ // of the Types & Constants section.
+ if (m_forwardDeclaredPointers.add(ptrType))
+ emitOpTypeForwardPointer(resultSpvType, storageClass);
+ }
+ if (storageClass == SpvStorageClassPhysicalStorageBuffer)
+ {
+ if (m_decoratedSpvInsts.add(getID(resultSpvType)))
+ {
+ IRSizeAndAlignment sizeAndAlignment;
+ getNaturalSizeAndAlignment(m_targetRequest, ptrType->getValueType(), &sizeAndAlignment);
+ emitOpDecorateArrayStride(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ resultSpvType,
+ SpvLiteralInteger::from32((uint32_t)sizeAndAlignment.getStride()));
+ }
+ }
+ return resultSpvType;
}
case kIROp_ConstantBufferType:
SLANG_UNEXPECTED("Constant buffer type remaining in spirv emit");
@@ -1404,11 +1497,7 @@ struct SPIRVEmitContext
return emitOpTypeHitObject(inst);
case kIROp_HLSLConstBufferPointerType:
- ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_variable_pointers"));
- requireSPIRVCapability(SpvCapabilityVariablePointers);
- ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_physical_storage_buffer"));
- requireSPIRVCapability(SpvCapabilityPhysicalStorageBufferAddresses);
- m_addressingMode = SpvAddressingModelPhysicalStorageBuffer64;
+ requireVariablePointers();
return emitOpTypePointer(inst, SpvStorageClassPhysicalStorageBuffer, inst->getOperand(0));
case kIROp_FuncType:
@@ -1446,6 +1535,7 @@ struct SPIRVEmitContext
case kIROp_IntLit:
case kIROp_FloatLit:
case kIROp_StringLit:
+ case kIROp_PtrLit:
{
return emitLit(inst);
}
@@ -1978,6 +2068,7 @@ struct SPIRVEmitContext
param->getDataType(),
storageClass
);
+ maybeEmitPointerDecoration(varInst, param);
if (auto layout = getVarLayout(param))
emitVarLayout(param, varInst, layout);
maybeEmitName(varInst, param);
@@ -2001,6 +2092,7 @@ struct SPIRVEmitContext
globalVar->getDataType(),
storageClass
);
+ maybeEmitPointerDecoration(varInst, globalVar);
if(layout)
emitVarLayout(globalVar, varInst, layout);
maybeEmitName(varInst, globalVar);
@@ -2274,6 +2366,8 @@ struct SPIRVEmitContext
return emitFieldExtract(parent, as<IRFieldExtract>(inst));
case kIROp_GetElementPtr:
return emitGetElementPtr(parent, as<IRGetElementPtr>(inst));
+ case kIROp_GetOffsetPtr:
+ return emitGetOffsetPtr(parent, inst);
case kIROp_GetElement:
return emitGetElement(parent, as<IRGetElement>(inst));
case kIROp_MakeStruct:
@@ -2306,6 +2400,13 @@ struct SPIRVEmitContext
return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst));
case kIROp_CastFloatToInt:
return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst));
+ case kIROp_CastPtrToInt:
+ return emitCastPtrToInt(parent, inst);
+ case kIROp_CastPtrToBool:
+ return emitCastPtrToBool(parent, inst);
+ case kIROp_CastIntToPtr:
+ return emitCastIntToPtr(parent, inst);
+ case kIROp_PtrCast:
case kIROp_BitCast:
return emitOpBitcast(
parent,
@@ -3403,10 +3504,27 @@ struct SPIRVEmitContext
return nullptr;
}
+ void maybeEmitPointerDecoration(SpvInst* varInst, IRInst* inst)
+ {
+ auto ptrType = as<IRPtrType>(inst->getDataType());
+ if (!ptrType)
+ return;
+ if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ {
+ emitOpDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ varInst,
+ (as<IRVar>(inst) ? SpvDecorationAliasedPointer : SpvDecorationAliased)
+ );
+ }
+ }
+
SpvInst* emitParam(SpvInstParent* parent, IRInst* inst)
{
auto paramSpvInst = emitOpFunctionParameter(parent, inst, inst->getFullType());
maybeEmitName(paramSpvInst, inst);
+ maybeEmitPointerDecoration(paramSpvInst, inst);
return paramSpvInst;
}
@@ -3421,6 +3539,7 @@ struct SPIRVEmitContext
}
auto varSpvInst = emitOpVariable(parent, inst, inst->getFullType(), storageClass);
maybeEmitName(varSpvInst, inst);
+ maybeEmitPointerDecoration(varSpvInst, inst);
return varSpvInst;
}
@@ -3962,6 +4081,11 @@ struct SPIRVEmitContext
);
}
+ SpvInst* emitGetOffsetPtr(SpvInstParent* parent, IRInst* inst)
+ {
+ return emitOpPtrAccessChain(parent, inst, inst->getDataType(), inst->getOperand(0), inst->getOperand(1));
+ }
+
SpvInst* emitGetElementPtr(SpvInstParent* parent, IRGetElementPtr* inst)
{
IRBuilder builder(m_irModule);
@@ -4025,12 +4149,32 @@ struct SPIRVEmitContext
SpvInst* emitLoad(SpvInstParent* parent, IRLoad* inst)
{
- return emitOpLoad(parent, inst, inst->getDataType(), inst->getPtr());
+ auto ptrType = as<IRPtrTypeBase>(inst->getPtr()->getDataType());
+ if (ptrType && ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ {
+ IRSizeAndAlignment sizeAndAlignment;
+ getNaturalSizeAndAlignment(m_targetRequest, ptrType->getValueType(), &sizeAndAlignment);
+ return emitOpLoadAligned(parent, inst, inst->getDataType(), inst->getPtr(), SpvLiteralInteger::from32(sizeAndAlignment.alignment));
+ }
+ else
+ {
+ return emitOpLoad(parent, inst, inst->getDataType(), inst->getPtr());
+ }
}
SpvInst* emitStore(SpvInstParent* parent, IRStore* inst)
{
- return emitOpStore(parent, inst, inst->getPtr(), inst->getVal());
+ auto ptrType = as<IRPtrTypeBase>(inst->getPtr()->getDataType());
+ if (ptrType && ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ {
+ IRSizeAndAlignment sizeAndAlignment;
+ getNaturalSizeAndAlignment(m_targetRequest, ptrType->getValueType(), &sizeAndAlignment);
+ return emitOpStoreAligned(parent, inst, inst->getPtr(), inst->getVal(), SpvLiteralInteger::from32(sizeAndAlignment.alignment));
+ }
+ else
+ {
+ return emitOpStore(parent, inst, inst->getPtr(), inst->getVal());
+ }
}
SpvInst* emitSwizzledStore(SpvInstParent* parent, IRSwizzledStore* inst)
@@ -4322,6 +4466,23 @@ struct SPIRVEmitContext
: emitOpConvertFToU(parent, inst, toTypeV, inst->getOperand(0));
}
+ SpvInst* emitCastPtrToInt(SpvInstParent* parent, IRInst* inst)
+ {
+ return emitInst(parent, inst, SpvOpConvertPtrToU, inst->getFullType(), kResultID, inst->getOperand(0));
+ }
+
+ SpvInst* emitCastPtrToBool(SpvInstParent* parent, IRInst* inst)
+ {
+ IRBuilder builder(inst);
+ auto uintVal = emitInst(parent, nullptr, SpvOpConvertPtrToU, builder.getUInt64Type(), kResultID, inst->getOperand(0));
+ return emitOpINotEqual(parent, inst, kResultID, uintVal, builder.getIntValue(builder.getUInt64Type(), 0));
+ }
+
+ SpvInst* emitCastIntToPtr(SpvInstParent* parent, IRInst* inst)
+ {
+ return emitInst(parent, inst, SpvOpConvertUToPtr, inst->getFullType(), kResultID, inst->getOperand(0));
+ }
+
template<typename T, typename Ts>
SpvInst* emitCompositeConstruct(
SpvInstParent* parent,
@@ -5124,6 +5285,16 @@ SlangResult emitSPIRVFromIR(
{
context.ensureInst(irEntryPoint);
}
+
+ // Move forward delcared pointers to the end.
+ for (auto ptrType : context.m_forwardDeclaredPointers)
+ {
+ auto spvPtrType = context.m_mapIRInstToSpvInst[ptrType];
+ auto parent = spvPtrType->parent;
+ spvPtrType->removeFromParent();
+ parent->addInst(spvPtrType);
+ }
+
context.emitFrontMatter();
context.emitPhysicalLayout();
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index 63ca32650..7a93a312c 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -112,6 +112,7 @@ bool opCanBeConstExpr(IROp op)
case kIROp_CastIntToPtr:
case kIROp_CastPtrToInt:
case kIROp_CastPtrToBool:
+ case kIROp_PtrCast:
case kIROp_Reinterpret:
case kIROp_BitCast:
case kIROp_MakeTuple:
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index e183058ac..0c962b7a4 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -386,6 +386,8 @@ INST(FieldAddress, get_field_addr, 2, 0)
INST(GetElement, getElement, 2, 0)
INST(GetElementPtr, getElementPtr, 2, 0)
+// Pointer offset: computes pBase + offset_in_elements
+INST(GetOffsetPtr, getOffsetPtr, 2, 0)
INST(GetAddr, getAddr, 1, 0)
// Get an unowned NativeString from a String.
@@ -1011,6 +1013,7 @@ INST(CastPtrToBool, CastPtrToBool, 1, 0)
INST(CastPtrToInt, CastPtrToInt, 1, 0)
INST(CastIntToPtr, CastIntToPtr, 1, 0)
INST(CastToVoid, castToVoid, 1, 0)
+INST(PtrCast, PtrCast, 1, 0)
INST(SizeOf, sizeOf, 1, 0)
INST(AlignOf, alignOf, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 6e3821c18..82d891459 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2216,7 +2216,6 @@ struct IRFieldAddress : IRInst
IRInst* getBase() { return base.get(); }
IRInst* getField() { return field.get(); }
IR_LEAF_ISA(FieldAddress)
-
};
struct IRGetElement : IRInst
@@ -4065,6 +4064,8 @@ public:
IRInst* sizedType);
IRInst* emitCastPtrToBool(IRInst* val);
+ IRInst* emitCastPtrToInt(IRInst* val);
+ IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val);
IRGlobalConstant* emitGlobalConstant(
IRType* type);
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 36769cc34..eb0068657 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -276,7 +276,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
{
IRConstant* c = (IRConstant*)originalValue;
SLANG_RELEASE_ASSERT(c->value.ptrVal == nullptr);
- return builder->getNullVoidPtrValue();
+ return builder->getNullPtrValue(cloneType(this, c->getFullType()));
}
break;
diff --git a/source/slang/slang-ir-liveness.cpp b/source/slang/slang-ir-liveness.cpp
index 9cd4462af..28cd64a08 100644
--- a/source/slang/slang-ir-liveness.cpp
+++ b/source/slang/slang-ir-liveness.cpp
@@ -1029,6 +1029,7 @@ bool LivenessContext::_isAccessTerminator(IRTerminatorInst* terminator)
case kIROp_CastIntToPtr:
case kIROp_CastPtrToInt:
case kIROp_CastPtrToBool:
+ case kIROp_PtrCast:
val = val->getOperand(0);
break;
}
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index fb67c6842..39a137490 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -581,7 +581,7 @@ struct PeepholeContext : InstPassBase
auto ptr = inst->getOperand(0);
IRBuilder builder(module);
builder.setInsertBefore(inst);
- auto neq = builder.emitNeq(ptr, builder.getNullVoidPtrValue());
+ auto neq = builder.emitNeq(ptr, builder.getNullPtrValue(ptr->getDataType()));
inst->replaceUsesWith(neq);
maybeRemoveOldInst(inst);
changed = true;
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index b82daa9a4..60001661c 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -1995,11 +1995,6 @@ struct SpecializationContext
{
return 2;
}
- else if (auto ptrType = as<IRPtrTypeBase>(type))
- {
- type = ptrType->getValueType();
- goto top;
- }
else if (auto ptrLikeType = as<IRPointerLikeType>(type))
{
type = ptrLikeType->getElementType();
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 1675fe279..474ebc71c 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -741,6 +741,66 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
return result;
}
+ void processVar(IRInst* inst)
+ {
+ auto oldPtrType = as<IRPtrType>(inst->getDataType());
+ if (!oldPtrType->hasAddressSpace())
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newPtrType = builder.getPtrType(
+ oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassFunction);
+ inst->setFullType(newPtrType);
+ addUsersToWorkList(inst);
+ }
+ }
+
+ void processParam(IRInst* inst)
+ {
+ auto block = getBlock(inst);
+ auto func = getParentFunc(block);
+ if (!block || !func)
+ return;
+ auto oldPtrType = as<IRPtrType>(inst->getDataType());
+ if (!oldPtrType)
+ return;
+ if (!oldPtrType->hasAddressSpace())
+ {
+ SpvStorageClass addressSpace = (SpvStorageClass)-1;
+
+ if (block == func->getFirstBlock())
+ {
+ // A pointer typed function parameter should always be in the storage buffer address space.
+ addressSpace = SpvStorageClassPhysicalStorageBuffer;
+ }
+ else
+ {
+ // The address space of a phi inst should always be the same as arguments.
+ auto args = getPhiArgs(inst);
+ for (auto arg : args)
+ {
+ auto argPtrType = as<IRPtrType>(arg->getDataType());
+ if (argPtrType->hasAddressSpace())
+ {
+ if (addressSpace == (SpvStorageClass)-1)
+ addressSpace = (SpvStorageClass)argPtrType->getAddressSpace();
+ else if (addressSpace != argPtrType->getAddressSpace())
+ m_sharedContext->m_sink->diagnose(inst, Diagnostics::inconsistentPointerAddressSpace, inst);
+ }
+ }
+ }
+ if (addressSpace != (SpvStorageClass)-1)
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newPtrType = builder.getPtrType(
+ oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassPhysicalStorageBuffer);
+ inst->setFullType(newPtrType);
+ addUsersToWorkList(inst);
+ }
+ }
+ }
+
void processGlobalVar(IRInst* inst)
{
auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType());
@@ -844,6 +904,16 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
for (UInt i = 0; i < inst->getArgCount(); i++)
{
auto arg = inst->getArg(i);
+ auto paramType = funcType->getParamType(i);
+ if (as<IRPtrType>(paramType))
+ {
+ // If the parameter has an explicit pointer type,
+ // then we know the user is using the variable pointer
+ // capability to pass a true pointer.
+ // In this case we should not rewrite the call.
+ newArgs.add(arg);
+ continue;
+ }
auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
if (!as<IRPtrTypeBase>(arg->getDataType()))
{
@@ -898,7 +968,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
SLANG_ASSERT((UInt)newArgs.getCount() == inst->getArgCount());
if (writeBacks.getCount())
{
- auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs);
+ auto newCall = builder.emitCallInst(
+ translateToStorageBufferPointer(inst->getFullType()),
+ inst->getCallee(),
+ newArgs);
for (auto wb : writeBacks)
{
auto newVal = builder.emitLoad(wb.tempVar);
@@ -908,6 +981,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
inst->removeAndDeallocate();
addUsersToWorkList(newCall);
}
+ else
+ {
+ translatePtrResultType(inst);
+ }
}
Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar;
@@ -989,6 +1066,28 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
processGetElementPtrImpl(gepInst, gepInst->getBase(), gepInst->getIndex());
}
+ void processGetOffsetPtr(IRInst* offsetPtrInst)
+ {
+ auto ptrOperandType = as<IRPtrType>(offsetPtrInst->getOperand(0)->getDataType());
+ if (!ptrOperandType)
+ return;
+ if (!ptrOperandType->hasAddressSpace())
+ return;
+ auto resultPtrType = as<IRPtrType>(offsetPtrInst->getDataType());
+ if (!resultPtrType)
+ return;
+ if (resultPtrType->getAddressSpace() != ptrOperandType->getAddressSpace())
+ {
+ IRBuilder builder(offsetPtrInst);
+ builder.setInsertBefore(offsetPtrInst);
+ auto newResultType = builder.getPtrType(resultPtrType->getOp(),
+ resultPtrType->getValueType(),
+ ptrOperandType->getAddressSpace());
+ auto newInst = builder.replaceOperand(&offsetPtrInst->typeUse, newResultType);
+ addUsersToWorkList(newInst);
+ }
+ }
+
void processStructuredBufferLoad(IRInst* loadInst)
{
auto sb = loadInst->getOperand(0);
@@ -1060,13 +1159,16 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (!ptrType->hasAddressSpace())
return;
auto oldResultType = as<IRPtrTypeBase>(inst->getDataType());
- if (oldResultType->getAddressSpace() != ptrType->getAddressSpace())
+ auto oldValueType = oldResultType->getValueType();
+ auto newValueType = translateToStorageBufferPointer(oldValueType);
+
+ if (oldValueType != newValueType || oldResultType->getAddressSpace() != ptrType->getAddressSpace())
{
IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
auto newPtrType = builder.getPtrType(
oldResultType->getOp(),
- oldResultType->getValueType(),
+ newValueType,
ptrType->getAddressSpace());
auto newInst =
builder.emitFieldAddress(newPtrType, inst->getBase(), inst->getField());
@@ -1077,6 +1179,19 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
+ void processFieldExtract(IRFieldExtract* inst)
+ {
+ auto ptrType = as<IRPtrType>(inst->getDataType());
+ if (!ptrType)
+ return;
+ auto newPtrType = translateToStorageBufferPointer(ptrType);
+ if (newPtrType == ptrType)
+ return;
+ IRBuilder builder(inst);
+ auto newInst = builder.replaceOperand(&inst->typeUse, newPtrType);
+ addUsersToWorkList(newInst);
+ }
+
void duplicateMergeBlockIfNeeded(IRUse* breakBlockUse)
{
auto breakBlock = as<IRBlock>(breakBlockUse->get());
@@ -1106,7 +1221,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
void processLoop(IRLoop* loop)
{
-
// 2.11.1. Rules for Structured Control-flow Declarations
// Structured control flow declarations must satisfy the following
// rules:
@@ -1186,6 +1300,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Insert a new continue block at the end of the loop
const auto newContinueBlock = builder.emitBlock();
+ addToWorkList(newContinueBlock);
+
newContinueBlock->insertBefore(loop->getBreakBlock());
// This block simply branches to the loop header, forwarding
@@ -1204,10 +1320,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
loop->block.set(t);
// Branch to the target in our new continue block
- builder.emitBranch(t, ps.getCount(), ps.getBuffer());
+ auto branch = builder.emitBranch(t, ps.getCount(), ps.getBuffer());
+ addToWorkList(branch);
}
}
duplicateMergeBlockIfNeeded(&loop->breakBlock);
+ addToWorkList(loop->getTargetBlock());
}
void processIfElse(IRIfElse* inst)
@@ -1223,6 +1341,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newBlock = builder.emitBlock();
builder.emitBranch(inst->getAfterBlock());
inst->trueBlock.set(newBlock);
+ addToWorkList(newBlock);
}
if (inst->getFalseBlock() == inst->getAfterBlock())
{
@@ -1230,6 +1349,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newBlock = builder.emitBlock();
builder.emitBranch(inst->getAfterBlock());
inst->falseBlock.set(newBlock);
+ addToWorkList(newBlock);
}
}
@@ -1246,6 +1366,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newBlock = builder.emitBlock();
builder.emitBranch(inst->getBreakLabel());
inst->defaultLabel.set(newBlock);
+ addToWorkList(newBlock);
}
for (UInt i = 0; i < inst->getCaseCount(); i++)
{
@@ -1255,6 +1376,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newBlock = builder.emitBlock();
builder.emitBranch(inst->getBreakLabel());
inst->getCaseLabelUse(i)->set(newBlock);
+ addToWorkList(newBlock);
}
}
}
@@ -1386,6 +1508,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_FieldAddress:
case kIROp_GetElement:
case kIROp_GetElementPtr:
+ case kIROp_GetOffsetPtr:
case kIROp_UpdateElement:
case kIROp_MakeTuple:
case kIROp_GetTupleElement:
@@ -1407,6 +1530,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
case kIROp_CastIntToPtr:
+ case kIROp_PtrCast:
case kIROp_CastPtrToBool:
case kIROp_CastPtrToInt:
case kIROp_BitAnd:
@@ -1467,7 +1591,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
cloneEnv.mapOldValToNew[inst] = result;
return result;
}
-
+
// If the global value is inlinable, we make all its operands avaialble locally, and then copy it
// to the local scope.
ShortList<IRInst*> args;
@@ -1482,21 +1606,123 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
return result;
}
- void processWorkList()
+ void processBranch(IRInst* branch)
+ {
+ addToWorkList(branch->getOperand(0));
+ }
+
+ IRType* translateToStorageBufferPointer(IRType* pointerType)
+ {
+ auto ptrType = as<IRPtrType>(pointerType);
+ if (!ptrType)
+ return pointerType;
+ auto oldValueType = ptrType->getValueType();
+ auto newValueType = translateToStorageBufferPointer(oldValueType);
+ if (oldValueType != newValueType || !ptrType->hasAddressSpace())
+ {
+ IRBuilder builder(m_module);
+ return builder.getPtrType(ptrType->getOp(), newValueType, SpvStorageClassPhysicalStorageBuffer);
+ }
+ return ptrType;
+ }
+
+ void translatePtrResultType(IRInst* inst)
+ {
+ auto ptrType = as<IRPtrType>(inst->getDataType());
+ auto newPtrType = translateToStorageBufferPointer(ptrType);
+ if (newPtrType == ptrType)
+ return;
+ IRBuilder builder(inst);
+ auto newInst = builder.replaceOperand(&inst->typeUse, newPtrType);
+ addUsersToWorkList(newInst);
+ }
+
+ void processPtrLit(IRInst* inst)
{
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newPtrType = translateToStorageBufferPointer(as<IRPtrType>(inst->getFullType()));
+ auto newInst = builder.emitCastIntToPtr(newPtrType, builder.getIntValue(builder.getUInt64Type(), 0));
+ inst->replaceUsesWith(newInst);
+ addUsersToWorkList(newInst);
+ }
+ void processPtrCast(IRInst* cast)
+ {
+ translatePtrResultType(cast);
+ }
+
+ void processLoad(IRInst* inst)
+ {
+ translatePtrResultType(inst);
+ }
+
+ void processStructField(IRStructField* field)
+ {
+ auto ptrType = as<IRPtrTypeBase>(field->getFieldType());
+ if (!ptrType)
+ return;
+ if (ptrType->hasAddressSpace())
+ return;
+ IRBuilder builder(field);
+ auto newPtrType = builder.getPtrType(
+ ptrType->getOp(),
+ ptrType->getValueType(),
+ SpvStorageClassPhysicalStorageBuffer);
+ field->setFieldType(newPtrType);
+ }
+
+ void processComparison(IRInst* inst)
+ {
+ auto operand0 = inst->getOperand(0);
+ if (as<IRPtrType>(operand0->getDataType()))
+ {
+ // If we are doing pointer comparison, convert the operands into uints first.
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto castToUInt = [&](IRInst* operand)
+ {
+ if (as<IRPtrLit>(operand))
+ return builder.getIntValue(builder.getUInt64Type(), 0);
+ else
+ return builder.emitCastPtrToInt(operand);
+ };
+ auto newOperand0 = castToUInt(operand0);
+ SLANG_ASSERT(as<IRPtrType>(inst->getOperand(1)->getDataType()));
+ auto newOperand1 = castToUInt(inst->getOperand(1));
+ inst = builder.replaceOperand(inst->getOperands(), newOperand0);
+ inst = builder.replaceOperand(inst->getOperands() + 1, newOperand1);
+ }
+ }
+
+ void processWorkList()
+ {
while (workList.getCount() != 0)
{
IRInst* inst = workList.getLast();
workList.removeLast();
+
+ // Skip if inst has already been removed.
+ if (!inst->parent)
+ continue;
+
switch (inst->getOp())
{
+ case kIROp_StructField:
+ processStructField(as<IRStructField>(inst));
+ break;
case kIROp_GlobalParam:
processGlobalParam(as<IRGlobalParam>(inst));
break;
case kIROp_GlobalVar:
processGlobalVar(as<IRGlobalVar>(inst));
break;
+ case kIROp_Var:
+ processVar(as<IRVar>(inst));
+ break;
+ case kIROp_Param:
+ processParam(as<IRParam>(inst));
+ break;
case kIROp_Call:
processCall(as<IRCall>(inst));
break;
@@ -1506,9 +1732,15 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_GetElementPtr:
processGetElementPtr(as<IRGetElementPtr>(inst));
break;
+ case kIROp_GetOffsetPtr:
+ processGetOffsetPtr(inst);
+ break;
case kIROp_FieldAddress:
processFieldAddress(as<IRFieldAddress>(inst));
break;
+ case kIROp_FieldExtract:
+ processFieldExtract(as<IRFieldExtract>(inst));
+ break;
case kIROp_ImageSubscript:
processImageSubscript(as<IRImageSubscript>(inst));
break;
@@ -1533,7 +1765,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_Switch:
processSwitch(as<IRSwitch>(inst));
break;
-
+ case kIROp_Less:
+ case kIROp_Leq:
+ case kIROp_Eql:
+ case kIROp_Geq:
+ case kIROp_Greater:
+ case kIROp_Neq:
+ processComparison(inst);
+ break;
case kIROp_MakeVectorFromScalar:
case kIROp_MakeUInt64:
case kIROp_MakeVector:
@@ -1551,6 +1790,20 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_MakeOptionalNone:
processConstructor(inst);
break;
+ case kIROp_BitCast:
+ case kIROp_PtrCast:
+ case kIROp_CastIntToPtr:
+ processPtrCast(inst);
+ break;
+ case kIROp_PtrLit:
+ processPtrLit(inst);
+ break;
+ case kIROp_Load:
+ processLoad(inst);
+ break;
+ case kIROp_unconditionalBranch:
+ processBranch(inst);
+ break;
case kIROp_SPIRVAsm:
processSPIRVAsm(as<IRSPIRVAsm>(inst));
break;
@@ -1584,7 +1837,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
void processModule()
{
- convertCompositeTypeParametersToPointers(m_module);
+ //convertCompositeTypeParametersToPointers(m_module);
// Process global params before anything else, so we don't generate inefficient
// array marhalling code for array-typed global params.
@@ -1631,6 +1884,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
t->replaceUsesWith(lowered);
}
+ // Inline global values that can't represented by SPIRV constant inst
+ // to their use sites.
List<IRUse*> globalInstUsesToInline;
for (auto globalInst : m_module->getGlobalInsts())
@@ -1666,6 +1921,63 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (val != use->get())
builder.replaceOperand(use, val);
}
+
+ // Some legalization processing may change the function parameter types,
+ // so we need to update the function types to match that.
+ updateFunctionTypes();
+ }
+
+ void updateFunctionTypes()
+ {
+ IRBuilder builder(m_module);
+ for (auto globalInst : m_module->getGlobalInsts())
+ {
+ auto func = as<IRFunc>(globalInst);
+ if (!func)
+ continue;
+ auto firstBlock = func->getFirstBlock();
+ if (!firstBlock)
+ continue;
+
+ builder.setInsertBefore(func);
+ auto type = func->getDataType();
+ auto oldFuncType = as<IRFuncType>(type);
+ auto resultType = oldFuncType->getResultType();
+ List<IRType*> newOperands;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (auto retInst = as<IRReturn>(inst))
+ {
+ resultType = retInst->getVal()->getFullType();
+ break;
+ }
+ }
+ }
+ for (auto param : firstBlock->getParams())
+ {
+ newOperands.add(param->getDataType());
+ }
+ bool changed = resultType != oldFuncType->getResultType();
+ if (!changed)
+ {
+ for (UInt i = 0; i < oldFuncType->getParamCount(); i++)
+ {
+ if (oldFuncType->getParamType(i) != newOperands[i])
+ {
+ changed = true;
+ break;
+ }
+ }
+ }
+ if (changed)
+ {
+ builder.setInsertBefore(func);
+ auto newFuncType = builder.getFuncType(newOperands, resultType);
+ func->setFullType(newFuncType);
+ }
+ }
}
};
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index d859f86a6..f514fea1d 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1030,6 +1030,26 @@ IRInst* getInstInBlock(IRInst* inst)
return getInstInBlock(inst->getParent());
}
+ShortList<IRInst*> getPhiArgs(IRInst* phiParam)
+{
+ ShortList<IRInst*> result;
+ auto block = cast<IRBlock>(phiParam->getParent());
+ UInt paramIndex = 0;
+ for (auto p = block->getFirstParam(); p; p = p->getNextParam())
+ {
+ if (p == phiParam)
+ break;
+ paramIndex++;
+ }
+ for (auto predBlock : block->getPredecessors())
+ {
+ auto termInst = as<IRUnconditionalBranch>(predBlock->getTerminator());
+ SLANG_ASSERT(paramIndex < termInst->getArgCount());
+ result.add(termInst->getArg(paramIndex));
+ }
+ return result;
+}
+
void removePhiArgs(IRInst* phiParam)
{
auto block = cast<IRBlock>(phiParam->getParent());
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index c76898aa2..c290f9392 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -221,6 +221,8 @@ IRInst* getInstInBlock(IRInst* inst);
void removePhiArgs(IRInst* phiParam);
+ShortList<IRInst*> getPhiArgs(IRInst* phiParam);
+
int getParamIndexInBlock(IRParam* paramInst);
bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* inst);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 94de28089..035b2aade 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -17,7 +17,14 @@ namespace Slang
SourceLoc const& getDiagnosticPos(IRInst* inst)
{
- return inst->sourceLoc;
+ while (inst)
+ {
+ if (inst->sourceLoc.isValid())
+ return inst->sourceLoc;
+ inst = inst->parent;
+ }
+ static SourceLoc invalid = SourceLoc();
+ return invalid;
}
void printDiagnosticArg(StringBuilder& sb, IRInst* irObject)
@@ -4900,7 +4907,7 @@ namespace Slang
IRType* type = nullptr;
auto basePtrType = as<IRPtrTypeBase>(basePtr->getDataType());
auto valueType = unwrapAttributedType(basePtrType->getValueType());
- if (auto arrayType = as<IRArrayType>(valueType))
+ if (auto arrayType = as<IRArrayTypeBase>(valueType))
{
type = arrayType->getElementType();
}
@@ -5507,6 +5514,28 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitCastPtrToInt(IRInst* val)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_CastPtrToInt,
+ getUInt64Type(),
+ val);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitCastIntToPtr(IRType* ptrType, IRInst* val)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_CastIntToPtr,
+ ptrType,
+ val);
+ addInst(inst);
+ return inst;
+ }
+
IRGlobalConstant* IRBuilder::emitGlobalConstant(
IRType* type)
{
@@ -7873,6 +7902,7 @@ namespace Slang
case kIROp_FieldAddress:
case kIROp_GetElement:
case kIROp_GetElementPtr:
+ case kIROp_GetOffsetPtr:
case kIROp_UpdateElement:
case kIROp_MeshOutputRef:
case kIROp_MakeVectorFromScalar:
@@ -7910,6 +7940,7 @@ namespace Slang
case kIROp_FloatCast:
case kIROp_CastPtrToInt:
case kIROp_CastIntToPtr:
+ case kIROp_PtrCast:
case kIROp_AllocObj:
case kIROp_PackAnyValue:
case kIROp_UnpackAnyValue:
@@ -8319,6 +8350,7 @@ namespace Slang
case kIROp_FieldAddress:
case kIROp_GetElement:
case kIROp_GetElementPtr:
+ case kIROp_GetOffsetPtr:
case kIROp_UpdateElement:
case kIROp_Specialize:
case kIROp_LookupWitness:
@@ -8347,6 +8379,7 @@ namespace Slang
case kIROp_CastIntToPtr:
case kIROp_CastPtrToBool:
case kIROp_CastPtrToInt:
+ case kIROp_PtrCast:
case kIROp_BitAnd:
case kIROp_BitNot:
case kIROp_BitOr:
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 0766bb168..bbb9dfeeb 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1813,6 +1813,10 @@ struct IRStructField : IRInst
//
return (IRType*) getOperand(1);
}
+ void setFieldType(IRType* type)
+ {
+ setOperand(1, type);
+ }
IR_LEAF_ISA(StructField)
};
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index 3da4f8554..13db8af18 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -358,6 +358,11 @@ public:
return dispatchIfNotNull(expr->baseExpression);
}
+ bool visitOpenRefExpr(OpenRefExpr* expr)
+ {
+ return dispatchIfNotNull(expr->innerExpr);
+ }
+
bool visitInitializerListExpr(InitializerListExpr* expr)
{
for (auto arg : expr->args)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 4e8c9b340..f5d743bb1 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4264,6 +4264,10 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
return LoweredValInfo::simple(
getBuilder()->emitMakeArrayFromElement(irType, irDefaultElement));
}
+ else if (auto ptrType = as<PtrType>(type))
+ {
+ return LoweredValInfo::simple(getBuilder()->getNullPtrValue(irType));
+ }
else if (auto declRefType = as<DeclRefType>(type))
{
DeclRef<Decl> declRef = declRefType->getDeclRef();
diff --git a/tests/spirv/pointer.slang b/tests/spirv/pointer.slang
new file mode 100644
index 000000000..cb2d56f66
--- /dev/null
+++ b/tests/spirv/pointer.slang
@@ -0,0 +1,48 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry main -stage compute -emit-spirv-directly
+
+
+struct PP
+{
+ int data;
+ int data2;
+}
+struct Data
+{
+ int data;
+ PP* pNext;
+};
+
+void funcThatTakesPointer(PP* p)
+{
+ p.data = 2;
+}
+int* funcThatReturnsPointer(PP* p)
+{
+ return &p.data;
+}
+
+// CHECK: OpEntryPoint
+
+StructuredBuffer<Data> buffer;
+RWStructuredBuffer<int> output;
+void main(int id : SV_DispatchThreadID)
+{
+ output[0] = buffer[0].pNext.data;
+ let pData = &(buffer[0].pNext.data);
+ // CHECK: OpPtrAccessChain
+ int* pData1 = pData + 1;
+ *pData1 = 3;
+ *(int2*)pData = int2(1, 2);
+ pData1[-1] = 2;
+ buffer[0].pNext[1] = {5};
+ // CHECK: OpConvertPtrToU
+ // CHECK: OpINotEqual
+ if (pData1)
+ {
+ *(funcThatReturnsPointer(buffer[0].pNext)) = 4;
+ }
+ if (pData1 > pData)
+ {
+ funcThatTakesPointer(buffer[0].pNext);
+ }
+}