summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang95
-rw-r--r--source/slang/hlsl.meta.slang16
-rw-r--r--source/slang/slang-ast-builder.cpp14
-rw-r--r--source/slang/slang-ast-builder.h7
-rw-r--r--source/slang/slang-ast-type.cpp60
-rw-r--r--source/slang/slang-ast-type.h5
-rw-r--r--source/slang/slang-capabilities.capdef1
-rw-r--r--source/slang/slang-check-conversion.cpp1
-rw-r--r--source/slang/slang-check-expr.cpp2
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-c-like.cpp6
-rw-r--r--source/slang/slang-emit-c-like.h2
-rw-r--r--source/slang/slang-emit-cuda.cpp2
-rw-r--r--source/slang/slang-emit-cuda.h2
-rw-r--r--source/slang/slang-emit-glsl.cpp4
-rw-r--r--source/slang/slang-emit-glsl.h2
-rw-r--r--source/slang/slang-emit-hlsl.cpp2
-rw-r--r--source/slang/slang-emit-hlsl.h2
-rw-r--r--source/slang/slang-emit-metal.cpp4
-rw-r--r--source/slang/slang-emit-metal.h2
-rw-r--r--source/slang/slang-emit-spirv.cpp63
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-ir-any-value-marshalling.cpp9
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp9
-rw-r--r--source/slang/slang-ir-composite-reg-to-mem.cpp1
-rw-r--r--source/slang/slang-ir-explicit-global-context.cpp2
-rw-r--r--source/slang/slang-ir-glsl-liveness.cpp2
-rw-r--r--source/slang/slang-ir-insts.h3
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp2
-rw-r--r--source/slang/slang-ir-simplify-for-emit.cpp2
-rw-r--r--source/slang/slang-ir-specialize-address-space.cpp9
-rw-r--r--source/slang/slang-ir-specialize-address-space.h4
-rw-r--r--source/slang/slang-ir-specialize.cpp6
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp126
-rw-r--r--source/slang/slang-ir.cpp54
-rw-r--r--source/slang/slang-ir.h17
-rw-r--r--source/slang/slang-legalize-types.cpp4
-rw-r--r--source/slang/slang-lower-to-ir.cpp15
-rw-r--r--source/slang/slang-syntax.cpp2
-rw-r--r--source/slang/slang-type-system-shared.h14
40 files changed, 326 insertions, 255 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 45b3435eb..c75d4735b 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -679,14 +679,14 @@ struct __none_t
{
};
-__generic<T>
+__generic<T, let addrSpace : uint64_t = $( (uint64_t)AddressSpace::UserPointer)ULL>
__magic_type(PtrType)
__intrinsic_type($(kIROp_PtrType))
struct Ptr
{
__generic<U>
__intrinsic_op($(kIROp_BitCast))
- __init(Ptr<U> ptr);
+ __init(Ptr<U, addrSpace> ptr);
__intrinsic_op($(kIROp_CastIntToPtr))
__init(uint64_t val);
@@ -703,53 +703,47 @@ struct Ptr
};
__intrinsic_op($(kIROp_Load))
-T __load<T>(Ptr<T> ptr);
+T __load<T, let addrSpace : uint64_t>(Ptr<T, addrSpace> ptr);
__intrinsic_op($(kIROp_Store))
-void __store<T>(Ptr<T> ptr, T val);
-
-__intrinsic_op($(kIROp_GetElementPtr))
-Ptr<T> __getElementPtr<T>(Ptr<T> ptr, int index);
+void __store<T, let addrSpace : uint64_t>(Ptr<T, addrSpace> ptr, T val);
__intrinsic_op($(kIROp_GetElementPtr))
-Ptr<T> __getElementPtr<T>(Ptr<T> ptr, int64_t index);
+Ptr<T, addrSpace> __getElementPtr<T, let addrSpace : uint64_t, TIndex : __BuiltinIntegerType>(Ptr<T, addrSpace> ptr, TIndex index);
__intrinsic_op($(kIROp_GetOffsetPtr))
-Ptr<T> __getOffsetPtr<T>(Ptr<T> ptr, int index);
+Ptr<T, addrSpace> __getOffsetPtr<T, let addrSpace : uint64_t, TIndex : __BuiltinIntegerType>(Ptr<T, addrSpace> ptr, TIndex index);
-__intrinsic_op($(kIROp_GetOffsetPtr))
-Ptr<T> __getOffsetPtr<T>(Ptr<T> ptr, int64_t index);
-
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_Less))
-bool operator<(Ptr<T> p1, Ptr<T> p2);
+bool operator <(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_Leq))
-bool operator<=(Ptr<T> p1, Ptr<T> p2);
+bool operator <=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_Greater))
-bool operator>(Ptr<T> p1, Ptr<T> p2);
+bool operator>(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_Geq))
-bool operator>=(Ptr<T> p1, Ptr<T> p2);
+bool operator >=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_Neq))
-bool operator!=(Ptr<T> p1, Ptr<T> p2);
+bool operator !=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_Eql))
-bool operator==(Ptr<T> p1, Ptr<T> p2);
+bool operator ==(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
extension bool : IRangedValue
{
- __generic<T>
+ __generic<T, let addrSpace : uint64_t>
__implicit_conversion($(kConversionCost_PtrToBool))
__intrinsic_op($(kIROp_CastPtrToBool))
- __init(Ptr<T> ptr);
+ __init(Ptr<T, addrSpace> ptr);
__generic<T : __EnumType>
__implicit_conversion($(kConversionCost_IntegerTruncate))
@@ -765,9 +759,9 @@ extension bool : IRangedValue
extension uint64_t : IRangedValue
{
- __generic<T>
+ __generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_CastPtrToInt))
- __init(Ptr<T> ptr);
+ __init(Ptr<T, addrSpace> ptr);
static const uint64_t maxValue = 0xFFFFFFFFFFFFFFFFULL;
static const uint64_t minValue = 0;
@@ -775,9 +769,9 @@ extension uint64_t : IRangedValue
extension int64_t : IRangedValue
{
- __generic<T>
+ __generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_CastPtrToInt))
- __init(Ptr<T> ptr);
+ __init(Ptr<T, addrSpace> ptr);
static const int64_t maxValue = 0x7FFFFFFFFFFFFFFFLL;
static const int64_t minValue = -0x8000000000000000LL;
@@ -785,9 +779,9 @@ extension int64_t : IRangedValue
extension intptr_t : IRangedValue
{
- __generic<T>
+ __generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_CastPtrToInt))
- __init(Ptr<T> ptr);
+ __init(Ptr<T, addrSpace> ptr);
static const intptr_t maxValue = $(SLANG_PROCESSOR_X86_64?"0x7FFFFFFFFFFFFFFFz":"0x7FFFFFFFz");
static const intptr_t minValue = $(SLANG_PROCESSOR_X86_64?"0x8000000000000000z":"0x80000000z");
static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4");
@@ -795,9 +789,9 @@ extension intptr_t : IRangedValue
extension uintptr_t : IRangedValue
{
- __generic<T>
+ __generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_CastPtrToInt))
- __init(Ptr<T> ptr);
+ __init(Ptr<T, addrSpace> ptr);
static const uintptr_t maxValue = $(SLANG_PROCESSOR_X86_64?"0xFFFFFFFFFFFFFFFFz":"0xFFFFFFFFz");
static const uintptr_t minValue = 0z;
static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4");
@@ -827,6 +821,8 @@ __intrinsic_type($(kIROp_ConstRefType))
struct ConstRef
{};
+typealias __Addr<T> = Ptr<T, $( (uint64_t)AddressSpace::Generic)ULL>;
+
__generic<T>
__magic_type(OptionalType)
__intrinsic_type($(kIROp_OptionalType))
@@ -969,10 +965,10 @@ extension Ptr<void>
[__unsafeForceInlineEarly]
__init(NativeString nativeStr) { this = nativeStr.getBuffer(); }
- __generic<T>
+ __generic<T, let addrSpace : uint64_t>
__intrinsic_op(0)
__implicit_conversion($(kConversionCost_PtrToVoidPtr))
- __init(Ptr<T> ptr);
+ __init(Ptr<T, addrSpace> ptr);
__generic<T>
__intrinsic_op(0)
@@ -1628,21 +1624,28 @@ for (auto op : intrinsicUnaryOps)
}}}}
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
__intrinsic_op(0)
-__prefix Ref<T> operator*(Ptr<T> value);
+[require(cpp_cuda_spirv)]
+__prefix Ref<T> operator*(Ptr<T, addrSpace> value);
__generic<T>
__intrinsic_op(0)
-__prefix Ptr<T> operator&(__ref T value);
+[require(cpp_cuda_spirv)]
+__prefix Ptr<T, $( (uint64_t)AddressSpace::UserPointer)ULL> operator&(__ref T value);
__generic<T>
+__intrinsic_op(0)
+[require(cpp_cuda_spirv)]
+__Addr<T> __get_addr( __ref T value);
+
+__generic<T, let addrSpace : uint64_t>
__intrinsic_op($(kIROp_GetOffsetPtr))
-Ptr<T> operator+(Ptr<T> value, int64_t offset);
+Ptr<T, addrSpace> operator+(Ptr<T, addrSpace> value, int64_t offset);
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
[__unsafeForceInlineEarly]
-Ptr<T> operator-(Ptr<T> value, int64_t offset)
+Ptr<T, addrSpace> operator -(Ptr<T, addrSpace> value, int64_t offset)
{
return __getOffsetPtr(value, -offset);
}
@@ -1707,9 +1710,9 @@ matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value)
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
$(fixity.qual)
-__generic<T>
+__generic<T, let addrSpace : uint64_t>
[__unsafeForceInlineEarly]
-Ptr<T> operator$(op.name)(in out Ptr<T> value)
+Ptr<T, addrSpace> operator$(op.name)(in out Ptr<T, addrSpace> value)
{$(fixity.bodyPrefix) value = value $(op.binOp) 1; return $(fixity.returnVal); }
${{{{
@@ -2500,8 +2503,8 @@ __intrinsic_op($(kIROp_TreatAsDynamicUniform))
T asDynamicUniform<T>(T v);
__generic<T>
-__intrinsic_op($(kIROp_GetLegalizedSPIRVGlobalParamAddr))
-Ptr<T> __getLegalizedSPIRVGlobalParamAddr(T val);
+__intrinsic_op( $(kIROp_GetLegalizedSPIRVGlobalParamAddr))
+__Addr<T> __getLegalizedSPIRVGlobalParamAddr(T val);
__intrinsic_op($(kIROp_RequireComputeDerivative))
void __requireComputeDerivative();
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 1a112e1a9..0539097eb 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -17568,11 +17568,11 @@ Ref<T> __hitObjectAttributes<T>()
return t;
}
[ForceInline]
-Ptr<T> __allocHitObjectAttributes<T>()
+__Addr<T> __allocHitObjectAttributes<T>()
{
- [__vulkanHitObjectAttributes]
+ [__vulkanHitObjectAttributes]
static T t;
- return &t;
+ return __get_addr(t);
}
// Next is the custom intrinsic that will compute the hitObjectAttributes location
@@ -17840,7 +17840,7 @@ struct HitObject
case spirv:
{
// Save the attributes
- Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
+ __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
*attr = attributes;
@@ -17914,7 +17914,7 @@ struct HitObject
case spirv:
{
// Save the attributes
- Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
+ __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
*attr = attributes;
@@ -18004,7 +18004,7 @@ struct HitObject
case spirv:
{
// Save the attributes
- Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
+ __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
*attr = attributes;
let origin = Ray.Origin;
let direction = Ray.Direction;
@@ -18071,7 +18071,7 @@ struct HitObject
case spirv:
{
// Save the attributes
- Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
+ __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
*attr = attributes;
let origin = Ray.Origin;
let direction = Ray.Direction;
@@ -18618,7 +18618,7 @@ struct HitObject
}
case spirv:
{
- Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
+ __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
spirv_asm
{
OpExtension "SPV_NV_shader_invocation_reorder";
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index a672c1b7e..ce4c32c3a 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -290,9 +290,9 @@ Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const c
return rsType;
}
-PtrType* ASTBuilder::getPtrType(Type* valueType)
+PtrType* ASTBuilder::getPtrType(Type* valueType, AddressSpace addrSpace)
{
- return dynamicCast<PtrType>(getPtrType(valueType, "PtrType"));
+ return dynamicCast<PtrType>(getPtrType(valueType, addrSpace, "PtrType"));
}
// Construct the type `Out<valueType>`
@@ -306,9 +306,9 @@ InOutType* ASTBuilder::getInOutType(Type* valueType)
return dynamicCast<InOutType>(getPtrType(valueType, "InOutType"));
}
-RefType* ASTBuilder::getRefType(Type* valueType)
+RefType* ASTBuilder::getRefType(Type* valueType, AddressSpace addrSpace)
{
- return dynamicCast<RefType>(getPtrType(valueType, "RefType"));
+ return dynamicCast<RefType>(getPtrType(valueType, addrSpace, "RefType"));
}
ConstRefType* ASTBuilder::getConstRefType(Type* valueType)
@@ -327,6 +327,12 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
return as<PtrTypeBase>(getSpecializedBuiltinType(valueType, ptrTypeName));
}
+PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName)
+{
+ Val* args[] = { valueType, getIntVal(getUInt64Type(), (IntegerLiteralValue)addrSpace) };
+ return as<PtrTypeBase>(getSpecializedBuiltinType(makeArrayView(args), ptrTypeName));
+}
+
ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount)
{
if (!elementCount)
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 5b4ec5538..52858a6b1 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -6,7 +6,7 @@
#include "slang-ast-support-types.h"
#include "slang-ast-all.h"
-
+#include "slang-ir.h"
#include "../core/slang-type-traits.h"
#include "../core/slang-memory-arena.h"
@@ -439,7 +439,7 @@ public:
Type* getDiffInterfaceType() { return m_sharedASTBuilder->getDiffInterfaceType(); }
// Construct the type `Ptr<valueType>`, where `Ptr`
// is looked up as a builtin type.
- PtrType* getPtrType(Type* valueType);
+ PtrType* getPtrType(Type* valueType, AddressSpace addrSpace);
// Construct the type `Out<valueType>`
OutType* getOutType(Type* valueType);
@@ -448,7 +448,7 @@ public:
InOutType* getInOutType(Type* valueType);
// Construct the type `Ref<valueType>`
- RefType* getRefType(Type* valueType);
+ RefType* getRefType(Type* valueType, AddressSpace addrSpace);
// Construct the type `ConstRef<valueType>`
ConstRefType* getConstRefType(Type* valueType);
@@ -459,6 +459,7 @@ public:
// Construct a pointer type like `Ptr<valueType>`, but where
// the actual type name for the pointer type is given by `ptrTypeName`
PtrTypeBase* getPtrType(Type* valueType, char const* ptrTypeName);
+ PtrTypeBase* getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName);
ArrayExpressionType* getArrayType(Type* elementType, IntVal* elementCount);
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 47cd68b9e..44585ee30 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -352,6 +352,66 @@ Type* NativeRefType::getValueType()
return as<Type>(_getGenericTypeArg(this, 0));
}
+Val* PtrTypeBase::getAddressSpace()
+{
+ return _getGenericTypeArg(this, 1);
+}
+
+AddressSpace tryGetAddressSpaceValue(Val* addrSpaceVal)
+{
+ AddressSpace addrSpace = AddressSpace::Generic;
+
+ if (auto cintVal = as<ConstantIntVal>(addrSpaceVal))
+ {
+ addrSpace = (AddressSpace)(cintVal->getValue());
+ }
+ return addrSpace;
+}
+
+void maybePrintAddrSpaceOperand(StringBuilder& out, AddressSpace addrSpace)
+{
+ switch (addrSpace)
+ {
+ case AddressSpace::Generic:
+ case AddressSpace::UserPointer:
+ break;
+ case AddressSpace::GroupShared:
+ out << toSlice(", groupshared");
+ break;
+ case AddressSpace::Global:
+ out << toSlice(", global");
+ break;
+ case AddressSpace::ThreadLocal:
+ out << toSlice(", threadlocal");
+ break;
+ case AddressSpace::Uniform:
+ out << toSlice(", uniform");
+ break;
+ default:
+ break;
+ }
+}
+
+void PtrType::_toTextOverride(StringBuilder& out)
+{
+ auto addrSpace = tryGetAddressSpaceValue(getAddressSpace());
+ if (addrSpace == AddressSpace::Generic)
+ out << toSlice("Addr<") << getValueType();
+ else
+ out << toSlice("Ptr<") << getValueType();
+ maybePrintAddrSpaceOperand(out, addrSpace);
+ out << toSlice(">");
+}
+
+void RefType::_toTextOverride(StringBuilder& out)
+{
+ out << toSlice("Ref<") << getValueType();
+ auto addressSpaceVal = getAddressSpace();
+ maybePrintAddrSpaceOperand(out, tryGetAddressSpaceValue(addressSpaceVal));
+ out << toSlice(">");
+}
+
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void NamedExpressionType::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 0d1a89860..945d051ba 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -539,6 +539,8 @@ class PtrTypeBase : public BuiltinType
// Get the type of the pointed-to value.
Type* getValueType();
+
+ Val* getAddressSpace();
};
class NoneType : public BuiltinType
@@ -555,6 +557,8 @@ class NullPtrType : public BuiltinType
class PtrType : public PtrTypeBase
{
SLANG_AST_CLASS(PtrType)
+
+ void _toTextOverride(StringBuilder& out);
};
// A GPU pointer type into global memory.
@@ -599,6 +603,7 @@ class RefTypeBase : public ParamDirectionType
class RefType : public RefTypeBase
{
SLANG_AST_CLASS(RefType)
+ void _toTextOverride(StringBuilder& out);
};
// The type for an `constref` parameter, e.g., `constref T`
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index b13571d18..59e79a07a 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -137,6 +137,7 @@ alias any_gfx_target = hlsl | metal | glsl | spirv;
alias any_cpp_target = cpp | cuda;
alias cpp_cuda = cpp | cuda;
+alias cpp_cuda_spirv = cpp | cuda | spirv;
alias cpp_cuda_glsl_spirv = cpp | cuda | glsl | spirv;
alias cpp_cuda_glsl_hlsl = cpp | cuda | glsl | hlsl;
alias cpp_cuda_glsl_hlsl_spirv = cpp | cuda | glsl | hlsl | spirv;
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 111f4e465..00625d5f4 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -1024,7 +1024,6 @@ namespace Slang
return false;
if (as<RefType>(toType) && !fromExpr->type.isLeftValue)
return false;
-
ConversionCost subCost = kConversionCost_GetRef;
MakeRefExpr* refExpr = nullptr;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index fd03e5e87..96e0a95d0 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -4425,7 +4425,7 @@ namespace Slang
expr->base = CheckProperType(expr->base);
if (as<ErrorType>(expr->base.type))
expr->type = expr->base.type;
- auto ptrType = m_astBuilder->getPtrType(expr->base.type);
+ auto ptrType = m_astBuilder->getPtrType(expr->base.type, AddressSpace::UserPointer);
expr->type = m_astBuilder->getTypeType(ptrType);
return expr;
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 5c508a541..bf478a17d 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -885,7 +885,7 @@ DIAGNOSTIC(99999, Internal, internalCompilerError, "Slang internal compiler erro
DIAGNOSTIC(99999, Error, compilationAborted, "Slang compilation aborted due to internal error")
DIAGNOSTIC(99999, Error, compilationAbortedDueToException, "Slang compilation aborted due to an exception of $0: $1")
DIAGNOSTIC(99999, Internal, serialDebugVerificationFailed, "Verification of serial debug information failed.")
-DIAGNOSTIC(99999, Internal, spirvValidationFailed, "Validation of generated SPIR-V failed.")
+DIAGNOSTIC(99999, Internal, spirvValidationFailed, "Validation of generated SPIR-V failed. SPIRV generated: \n$0")
DIAGNOSTIC(99999, Internal, noBlocksOrIntrinsic, "no blocks found for function definition, is there a '$0' intrinsic missing?")
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 7ce2c7900..dd5ef8c51 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1814,7 +1814,7 @@ void CLikeSourceEmitter::emitRateQualifiers(IRInst* value)
const auto rate = value->getRate();
if (rate)
{
- emitRateQualifiersAndAddressSpaceImpl(rate, -1);
+ emitRateQualifiersAndAddressSpaceImpl(rate, AddressSpace::Generic);
}
}
@@ -1822,8 +1822,8 @@ void CLikeSourceEmitter::emitRateQualifiersAndAddressSpace(IRInst* value)
{
const auto rate = value->getRate();
const auto ptrTy = composeGetters<IRPtrTypeBase>(value, &IRInst::getDataType);
- const auto addressSpace = ptrTy ? ptrTy->getAddressSpace() : -1;
- if (rate || addressSpace != -1)
+ const auto addressSpace = ptrTy ? ptrTy->getAddressSpace() : AddressSpace::Generic;
+ if (rate || addressSpace != AddressSpace::Generic)
{
emitRateQualifiersAndAddressSpaceImpl(rate, addressSpace);
}
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index 07cb4a0bc..c866456a2 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -488,7 +488,7 @@ public:
virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); };
- virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); }
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); }
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) { SLANG_UNUSED(inst); SLANG_UNUSED(allowOffsetLayout); }
virtual void emitSimpleFuncParamImpl(IRParam* param);
virtual void emitSimpleFuncParamsImpl(IRFunc* func);
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index f4b45b7aa..ab7ee8f58 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -738,7 +738,7 @@ void CUDASourceEmitter::emitSimpleTypeImpl(IRType* type)
}
}
-void CUDASourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace)
+void CUDASourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] AddressSpace addressSpace)
{
if (as<IRGroupSharedRate>(rate))
{
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index 5d76c4ef3..ac7a33302 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -67,7 +67,7 @@ protected:
virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE;
virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
- virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE;
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) SLANG_OVERRIDE;
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) SLANG_OVERRIDE;
virtual void emitSimpleFuncImpl(IRFunc* func) SLANG_OVERRIDE;
virtual void emitSimpleFuncParamsImpl(IRFunc* func) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index eea672938..0fb868342 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -2638,9 +2638,9 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled type");
}
-void GLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace)
+void GLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace)
{
- if(addressSpace == SpvStorageClassTaskPayloadWorkgroupEXT)
+ if(addressSpace == (AddressSpace)SpvStorageClassTaskPayloadWorkgroupEXT)
{
m_writer->emit("taskPayloadSharedEXT ");
}
diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h
index efd3ded75..1911749ec 100644
--- a/source/slang/slang-emit-glsl.h
+++ b/source/slang/slang-emit-glsl.h
@@ -33,7 +33,7 @@ protected:
virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
- virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE;
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) SLANG_OVERRIDE;
virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE;
virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index 4d3830969..d3757f534 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -1096,7 +1096,7 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
}
}
-void HLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace)
+void HLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] AddressSpace addressSpace)
{
if (as<IRGroupSharedRate>(rate))
{
diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h
index abc673a3d..2137a6d3d 100644
--- a/source/slang/slang-emit-hlsl.h
+++ b/source/slang/slang-emit-hlsl.h
@@ -35,7 +35,7 @@ protected:
virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
- virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE;
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) SLANG_OVERRIDE;
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE;
virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE;
virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
index e7df29e0c..302e2b6b7 100644
--- a/source/slang/slang-emit-metal.cpp
+++ b/source/slang/slang-emit-metal.cpp
@@ -984,7 +984,7 @@ void MetalSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTy
// We emit packoffset as a semantic in `emitSemantic`, so nothing to do here.
}
-void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace)
+void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace)
{
if (as<IRGroupSharedRate>(rate))
{
@@ -992,7 +992,7 @@ void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRI
return;
}
- switch ((AddressSpace)addressSpace)
+ switch (addressSpace)
{
case AddressSpace::GroupShared:
m_writer->emit("threadgroup ");
diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h
index 32557bf27..6460b076a 100644
--- a/source/slang/slang-emit-metal.h
+++ b/source/slang/slang-emit-metal.h
@@ -31,7 +31,7 @@ protected:
virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
- virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE;
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) SLANG_OVERRIDE;
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE;
virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE;
virtual void emitPostDeclarationAttributesForType(IRInst* type) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index f5bb21c00..cb9eb0ae9 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1108,7 +1108,17 @@ struct SPIRVEmitContext
// If we have seen this before, return the memoized instruction
if (SpvInst** memoized = m_spvTypeInsts.tryGetValue(key))
+ {
+ // There could be another different slang IR inst that translates to
+ // the same spir-v inst.
+ // For example, both Ptr<T> and Ref<T> translates to the same pointer
+ // type in spirv.
+ // In this case we need to make sure we also
+ // register `inst` to map it to the memoized spir-v inst.
+ if (irInst)
+ m_mapIRInstToSpvInst.addIfNotExists(irInst, *memoized);
return *memoized;
+ }
// Otherwise, we can construct our instruction and record the result
InstConstructScope scopeInst(this, opcode, irInst);
@@ -1213,6 +1223,19 @@ struct SPIRVEmitContext
return m_NonSemanticDebugPrintfExtInst;
}
+ SpvStorageClass addressSpaceToStorageClass(AddressSpace addrSpace)
+ {
+ switch (addrSpace)
+ {
+ case AddressSpace::Generic:
+ return SpvStorageClassMax;
+ case AddressSpace::UserPointer:
+ return SpvStorageClassPhysicalStorageBuffer;
+ default:
+ return (SpvStorageClass)addrSpace;
+ }
+ }
+
// Now that we've gotten the core infrastructure out of the way,
// let's start looking at emitting some instructions that make
// up a SPIR-V module.
@@ -1398,7 +1421,7 @@ struct SPIRVEmitContext
auto ptrType = as<IRPtrTypeBase>(inst);
SLANG_ASSERT(ptrType);
if (ptrType->hasAddressSpace())
- storageClass = (SpvStorageClass)ptrType->getAddressSpace();
+ storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
if (storageClass == SpvStorageClassStorageBuffer)
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_storage_buffer_storage_class"));
if (storageClass == SpvStorageClassPhysicalStorageBuffer)
@@ -1411,14 +1434,24 @@ struct SPIRVEmitContext
&& as<IRStructType>(valueType)
&& storageClass == SpvStorageClassPhysicalStorageBuffer);
SpvId valueTypeId;
- if (useForwardDeclaration)
+ if (as<IRVoidType>(valueType))
{
- valueTypeId = getIRInstSpvID(valueType);
+ // Emit void* as uint*.
+ IRBuilder builder(valueType);
+ builder.setInsertBefore(valueType);
+ valueTypeId = getID(ensureInst(builder.getUIntType()));
}
else
{
- auto spvValueType = ensureInst(valueType);
- valueTypeId = getID(spvValueType);
+ if (useForwardDeclaration)
+ {
+ valueTypeId = getIRInstSpvID(valueType);
+ }
+ else
+ {
+ auto spvValueType = ensureInst(valueType);
+ valueTypeId = getID(spvValueType);
+ }
}
auto resultSpvType = emitOpTypePointer(
@@ -3339,7 +3372,7 @@ struct SPIRVEmitContext
if (auto ptrType = as<IRPtrTypeBase>(globalInst->getDataType()))
{
auto addrSpace = ptrType->getAddressSpace();
- if (addrSpace != SpvStorageClassInput && addrSpace != SpvStorageClassOutput)
+ if (addrSpace != AddressSpace(SpvStorageClassInput) && addrSpace != AddressSpace(SpvStorageClassOutput))
continue;
}
}
@@ -4042,7 +4075,7 @@ struct SPIRVEmitContext
if (!ptrType)
return;
auto addrSpace = ptrType->getAddressSpace();
- if (addrSpace == SpvStorageClassInput)
+ if (addrSpace == AddressSpace(SpvStorageClassInput))
{
if (isIntegralScalarOrCompositeType(ptrType->getValueType()))
{
@@ -4333,10 +4366,10 @@ struct SPIRVEmitContext
void maybeEmitPointerDecoration(SpvInst* varInst, IRInst* inst)
{
- auto ptrType = as<IRPtrType>(inst->getDataType());
+ auto ptrType = as<IRPtrType>(unwrapArray(inst->getDataType()));
if (!ptrType)
return;
- if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ if (addressSpaceToStorageClass(ptrType->getAddressSpace()) == SpvStorageClassPhysicalStorageBuffer)
{
// If inst has a pointer type with PhysicalStorageBuffer address space,
// emit AliasedPointer decoration.
@@ -4351,10 +4384,10 @@ struct SPIRVEmitContext
{
// If the pointee type is a pointer with StorageBuffer address space,
// we also want to emit AliasedPointer decoration.
- ptrType = as<IRPtrType>(ptrType->getValueType());
+ ptrType = as<IRPtrType>(unwrapArray(ptrType->getValueType()));
if (!ptrType)
return;
- if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ if (addressSpaceToStorageClass(ptrType->getAddressSpace()) == SpvStorageClassPhysicalStorageBuffer)
{
emitOpDecorate(
getSection(SpvLogicalSectionID::Annotations),
@@ -4996,7 +5029,7 @@ struct SPIRVEmitContext
SpvInst* emitLoad(SpvInstParent* parent, IRLoad* inst)
{
auto ptrType = as<IRPtrTypeBase>(inst->getPtr()->getDataType());
- if (ptrType && ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ if (ptrType && addressSpaceToStorageClass(ptrType->getAddressSpace()) == SpvStorageClassPhysicalStorageBuffer)
{
IRSizeAndAlignment sizeAndAlignment;
getNaturalSizeAndAlignment(m_targetProgram->getOptionSet(), ptrType->getValueType(), &sizeAndAlignment);
@@ -5011,7 +5044,7 @@ struct SPIRVEmitContext
SpvInst* emitStore(SpvInstParent* parent, IRStore* inst)
{
auto ptrType = as<IRPtrTypeBase>(inst->getPtr()->getDataType());
- if (ptrType && ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ if (ptrType && addressSpaceToStorageClass(ptrType->getAddressSpace()) == SpvStorageClassPhysicalStorageBuffer)
{
IRSizeAndAlignment sizeAndAlignment;
getNaturalSizeAndAlignment(m_targetProgram->getOptionSet(), ptrType->getValueType(), &sizeAndAlignment);
@@ -5077,7 +5110,7 @@ struct SPIRVEmitContext
return emitOpVectorShuffle(parent, inst, inst->getFullType(), inst->getBase(), inst->getSource(), shuffleIndices.getArrayView());
}
- IRPtrTypeBase* getPtrTypeWithAddressSpace(IRPtrTypeBase* ptrTypeWithNoAddressSpace, IRIntegerValue addressSpace)
+ IRPtrTypeBase* getPtrTypeWithAddressSpace(IRPtrTypeBase* ptrTypeWithNoAddressSpace, AddressSpace addressSpace)
{
// If it's already ok, return as is
if(ptrTypeWithNoAddressSpace->getAddressSpace() == addressSpace)
@@ -5104,7 +5137,7 @@ struct SPIRVEmitContext
parent,
inst,
// Make sure the resulting pointer has the correct storage class
- getPtrTypeWithAddressSpace(cast<IRPtrTypeBase>(inst->getDataType()), storageClass),
+ getPtrTypeWithAddressSpace(cast<IRPtrTypeBase>(inst->getDataType()), AddressSpace(storageClass)),
inst->getOperand(0),
makeArray(emitIntConstant(0, builder.getIntType()), ensureInst(inst->getOperand(1)))
);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 9e21ccbfd..03d1b932c 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1718,9 +1718,13 @@ SlangResult emitSPIRVForEntryPointsDirectly(
{
if (SLANG_FAILED(compiler->validate((uint32_t*)spirv.getBuffer(), int(spirv.getCount()/4))))
{
+ String err;
+ String dis;
+ disassembleSPIRV(spirv, err, dis);
codeGenContext->getSink()->diagnoseWithoutSourceView(
SourceLoc{},
- Diagnostics::spirvValidationFailed
+ Diagnostics::spirvValidationFailed,
+ dis
);
}
}
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index 13059a84a..f09294aa9 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -145,13 +145,10 @@ namespace Slang
case kIROp_VectorType:
{
auto vectorType = static_cast<IRVectorType*>(dataType);
- auto elementType = vectorType->getElementType();
auto elementCount = getIntVal(vectorType->getElementCount());
- auto elementPtrType = builder->getPtrType(elementType);
for (IRIntegerValue i = 0; i < elementCount; i++)
{
auto elementAddr = builder->emitElementAddress(
- elementPtrType,
concreteTypedVar,
builder->getIntValue(builder->getIntType(), i));
emitMarshallingCode(builder, context, elementAddr);
@@ -161,20 +158,16 @@ namespace Slang
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(dataType);
- auto elementType = matrixType->getElementType();
auto colCount = getIntVal(matrixType->getColumnCount());
auto rowCount = getIntVal(matrixType->getRowCount());
- auto rowVecType = builder->getVectorType(elementType, matrixType->getRowCount());
for (IRIntegerValue i = 0; i < colCount; i++)
{
auto col = builder->emitElementAddress(
- builder->getPtrType(rowVecType),
concreteTypedVar,
builder->getIntValue(builder->getIntType(), i));
for (IRIntegerValue j = 0; j < rowCount; j++)
{
auto element = builder->emitElementAddress(
- builder->getPtrType(elementType),
col,
builder->getIntValue(builder->getIntType(), j));
emitMarshallingCode(builder, context, element);
@@ -198,11 +191,9 @@ namespace Slang
case kIROp_ArrayType:
{
auto arrayType = cast<IRArrayType>(dataType);
- auto elementPtrType = builder->getPtrType(arrayType->getElementType());
for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++)
{
auto fieldAddr = builder->emitElementAddress(
- elementPtrType,
concreteTypedVar,
builder->getIntValue(builder->getIntType(), i));
emitMarshallingCode(builder, context, fieldAddr);
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 511055713..9fe4ec70b 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -1117,14 +1117,9 @@ IRInst* emitIndexedStoreAddressForVar(
const List<IndexTrackingInfo>& defBlockIndices)
{
IRInst* storeAddr = localVar;
- IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType();
-
for (auto& index : defBlockIndices)
{
- currType = as<IRArrayType>(currType)->getElementType();
-
storeAddr = builder->emitElementAddress(
- builder->getPtrType(currType),
storeAddr,
index.primalCountParam);
}
@@ -1141,11 +1136,9 @@ IRInst* emitIndexedLoadAddressForVar(
const List<IndexTrackingInfo>& useBlockIndices)
{
IRInst* loadAddr = localVar;
- IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType();
for (auto index : defBlockIndices)
{
- currType = as<IRArrayType>(currType)->getElementType();
if (useBlockIndices.contains(index))
{
// If the use-block is under the same region, use the
@@ -1154,7 +1147,6 @@ IRInst* emitIndexedLoadAddressForVar(
auto diffCounterCurrValue = index.diffCountParam;
loadAddr = builder->emitElementAddress(
- builder->getPtrType(currType),
loadAddr,
diffCounterCurrValue);
}
@@ -1173,7 +1165,6 @@ IRInst* emitIndexedLoadAddressForVar(
builder->getIntValue(builder->getIntType(), 1));
loadAddr = builder->emitElementAddress(
- builder->getPtrType(currType),
loadAddr,
primalCounterLastValue);
}
diff --git a/source/slang/slang-ir-composite-reg-to-mem.cpp b/source/slang/slang-ir-composite-reg-to-mem.cpp
index 243a0e2b0..512683938 100644
--- a/source/slang/slang-ir-composite-reg-to-mem.cpp
+++ b/source/slang/slang-ir-composite-reg-to-mem.cpp
@@ -36,7 +36,6 @@ namespace Slang
if (getElementUser->getOperands() == use)
{
newAddr = builder.emitElementAddress(
- builder.getPtrType(user->getFullType()),
addr,
getElementUser->getIndex());
}
diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp
index 9d3f09712..56fd62883 100644
--- a/source/slang/slang-ir-explicit-global-context.cpp
+++ b/source/slang/slang-ir-explicit-global-context.cpp
@@ -258,7 +258,7 @@ struct IntroduceExplicitGlobalContextPass
if (kind == GlobalObjectKind::GlobalVar)
{
auto ptrType = as<IRPtrTypeBase>(type);
- if (ptrType->getAddressSpace() == (IRIntegerValue)AddressSpace::GroupShared)
+ if (ptrType->getAddressSpace() == AddressSpace::GroupShared)
{
fieldDataType = ptrType;
needDereference = true;
diff --git a/source/slang/slang-ir-glsl-liveness.cpp b/source/slang/slang-ir-glsl-liveness.cpp
index af64df4f4..64d41490a 100644
--- a/source/slang/slang-ir-glsl-liveness.cpp
+++ b/source/slang/slang-ir-glsl-liveness.cpp
@@ -133,7 +133,7 @@ void GLSLLivenessContext::_replaceMarker(IRLiveRangeMarker* markerInst)
IRType* paramTypes[] =
{
- m_builder.getRefType(referencedType), ///< Use a reference to the referenced type
+ m_builder.getRefType(referencedType, AddressSpace::Generic), ///< Use a reference to the referenced type
m_spirvIntLiteralType, ///< The size type
};
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index a5493d6ea..3aa7d1f64 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3490,10 +3490,11 @@ public:
IRPtrType* getPtrType(IRType* valueType);
IROutType* getOutType(IRType* valueType);
IRInOutType* getInOutType(IRType* valueType);
- IRRefType* getRefType(IRType* valueType);
+ IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace);
IRConstRefType* getConstRefType(IRType* valueType);
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace);
+ IRPtrType* getPtrType(IROp op, IRType* valueType, IRInst* addressSpace);
IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) { return getPtrType(op, valueType, (IRIntegerValue)addressSpace); }
IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, (IRIntegerValue)addressSpace); }
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 5eefc121e..cc24a0e81 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -578,7 +578,7 @@ namespace Slang
{
if (auto ptrType = as<IRPtrType>(globalInst))
{
- if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ if (ptrType->getAddressSpace() == AddressSpace::UserPointer)
elementType = ptrType->getValueType();
}
}
diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp
index 7fd609011..953d9f68a 100644
--- a/source/slang/slang-ir-simplify-for-emit.cpp
+++ b/source/slang/slang-ir-simplify-for-emit.cpp
@@ -73,7 +73,6 @@ struct SimplifyForEmitContext : public InstPassBase
for (UInt i = 0; i < makeArray->getOperandCount(); i++)
{
auto elementAddr = builder.emitElementAddress(
- builder.getPtrType(arrayType->getElementType()),
store->getPtr(),
builder.getIntValue(builder.getIntType(), (IRIntegerValue)i));
builder.emitStore(elementAddr, makeArray->getOperand(i));
@@ -107,7 +106,6 @@ struct SimplifyForEmitContext : public InstPassBase
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
auto elementAddr = builder.emitElementAddress(
- builder.getPtrType(arrayType->getElementType()),
store->getPtr(),
builder.getIntValue(builder.getIntType(), i));
builder.emitStore(elementAddr, makeArray->getOperand(0));
diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp
index 1d899e240..24eee3d76 100644
--- a/source/slang/slang-ir-specialize-address-space.cpp
+++ b/source/slang/slang-ir-specialize-address-space.cpp
@@ -331,9 +331,12 @@ namespace Slang
auto ptrType = as<IRPtrTypeBase>(inst->getDataType());
if (ptrType)
{
- IRBuilder builder(inst);
- auto newType = builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), addrSpace);
- setDataType(inst, newType);
+ if (ptrType->getAddressSpace() != addrSpace)
+ {
+ IRBuilder builder(inst);
+ auto newType = builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), addrSpace);
+ setDataType(inst, newType);
+ }
}
}
}
diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h
index 300b6129c..61b191cf2 100644
--- a/source/slang/slang-ir-specialize-address-space.h
+++ b/source/slang/slang-ir-specialize-address-space.h
@@ -1,11 +1,13 @@
// slang-ir-specialize-address-space.h
#pragma once
+#include <cinttypes>
+
namespace Slang
{
struct IRModule;
struct IRInst;
- enum class AddressSpace;
+ enum class AddressSpace : uint64_t;
struct AddressSpaceSpecializationContext
{
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 5713b9639..5c9e1ad24 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -1996,9 +1996,6 @@ struct SpecializationContext
auto index = inst->getIndex();
auto val = wrapInst->getWrappedValue();
- auto ptrType = cast<IRPtrTypeBase>(val->getDataType());
- auto arrayType = cast<IRArrayTypeBase>(ptrType->getValueType());
- auto elementType = arrayType->getElementType();
auto resultType = inst->getFullType();
@@ -2013,8 +2010,7 @@ struct SpecializationContext
slotOperands.add(wrapInst->getSlotOperand(ii));
}
- auto elementPtrType = builder.getPtrType(ptrType->getOp(), elementType);
- auto newElementAddr = builder.emitElementAddress(elementPtrType, val, index);
+ auto newElementAddr = builder.emitElementAddress(val, index);
auto newWrapExistentialInst = builder.emitWrapExistential(
resultType, newElementAddr, slotOperandCount, slotOperands.getArrayView().getBuffer());
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 27a186ee9..528fe2331 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -783,7 +783,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newPtrType = builder.getPtrType(
- oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), SpvStorageClassFunction);
+ oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassFunction);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
}
@@ -800,12 +800,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
return;
if (!oldPtrType->hasAddressSpace())
{
- SpvStorageClass addressSpace = (SpvStorageClass)-1;
+ AddressSpace addressSpace = AddressSpace::Generic;
if (block == func->getFirstBlock())
{
// A pointer typed function parameter should always be in the storage buffer address space.
- addressSpace = SpvStorageClassPhysicalStorageBuffer;
+ addressSpace = AddressSpace::UserPointer;
}
else
{
@@ -816,19 +816,19 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto argPtrType = as<IRPtrType>(arg->getDataType());
if (argPtrType->hasAddressSpace())
{
- if (addressSpace == (SpvStorageClass)-1)
- addressSpace = (SpvStorageClass)argPtrType->getAddressSpace();
+ if (addressSpace == AddressSpace::Generic)
+ addressSpace = argPtrType->getAddressSpace();
else if (addressSpace != argPtrType->getAddressSpace())
m_sharedContext->m_sink->diagnose(inst, Diagnostics::inconsistentPointerAddressSpace, inst);
}
}
}
- if (addressSpace != (SpvStorageClass)-1)
+ if (addressSpace != AddressSpace::Generic)
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newPtrType = builder.getPtrType(
- oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), SpvStorageClassPhysicalStorageBuffer);
+ oldPtrType->getOp(), oldPtrType->getValueType(), AddressSpace::UserPointer);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
}
@@ -842,7 +842,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
return;
// Update the pointer value type with storage-buffer-address-space-decorated types.
- auto newPtrValueType = translateToStorageBufferPointer(oldPtrType->getValueType());
+ auto newPtrValueType = oldPtrType->getValueType();
if (newPtrValueType != oldPtrType->getValueType())
{
IRBuilder builder(inst);
@@ -900,7 +900,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
auto newPtrType =
- builder.getPtrType(oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), storageClass);
+ builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), storageClass);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
return;
@@ -923,7 +923,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
auto qualPtrType = builder.getPtrType(
- ptrType->getOp(), translateToStorageBufferPointer(ptrType->getValueType()), snippet->resultStorageClass);
+ ptrType->getOp(), ptrType->getValueType(), snippet->resultStorageClass);
List<IRInst*> args;
for (UInt i = 0; i < inst->getArgCount(); i++)
args.add(inst->getArg(i));
@@ -1011,7 +1011,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
// If we reach here, we need to allocate a temp var.
- auto tempVar = builder.emitVar(translateToStorageBufferPointer(ptrType->getValueType()));
+ auto tempVar = builder.emitVar(ptrType->getValueType());
auto load = builder.emitLoad(arg);
builder.emitStore(tempVar, load);
newArgs.add(tempVar);
@@ -1021,7 +1021,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (writeBacks.getCount())
{
auto newCall = builder.emitCallInst(
- translateToStorageBufferPointer(inst->getFullType()),
+ inst->getFullType(),
inst->getCallee(),
newArgs);
for (auto wb : writeBacks)
@@ -1033,15 +1033,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
inst->removeAndDeallocate();
addUsersToWorkList(newCall);
}
- else
- {
- // If we reach here, we have determined that all arguments passed as a pointer
- // are actual memory objects, so they can be passed in as-is.
- // We still need to make sure the callee is specialized to the address-space
- // of the arguments, this is done in a separate specialization pass.
-
- translatePtrResultType(inst);
- }
}
Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar;
@@ -1074,7 +1065,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(inst);
else
setInsertAfterOrdinaryInst(&builder, x);
- y = builder.emitVar(translateToStorageBufferPointer(x->getDataType()), SpvStorageClassFunction);
+ y = builder.emitVar(x->getDataType(), SpvStorageClassFunction);
builder.emitStore(y, x);
if (x->getParent()->getOp() != kIROp_Module)
m_mapArrayValueToVar.set(x, y);
@@ -1101,7 +1092,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(gepInst);
auto newPtrType = builder.getPtrType(
oldResultType->getOp(),
- translateToStorageBufferPointer(oldResultType->getValueType()),
+ oldResultType->getValueType(),
ptrType->getAddressSpace());
IRInst* args[2] = { base, index };
auto newInst =
@@ -1154,7 +1145,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(offsetPtrInst);
builder.setInsertBefore(offsetPtrInst);
auto newResultType = builder.getPtrType(resultPtrType->getOp(),
- translateToStorageBufferPointer(resultPtrType->getValueType()),
+ resultPtrType->getValueType(),
ptrOperandType->getAddressSpace());
auto newInst = builder.replaceOperand(&offsetPtrInst->typeUse, newResultType);
addUsersToWorkList(newInst);
@@ -1174,7 +1165,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(loadInst);
IRInst* args[] = { sb, index };
auto addrInst = builder.emitIntrinsicInst(
- builder.getPtrType(kIROp_PtrType, translateToStorageBufferPointer(loadInst->getFullType()), getStorageBufferStorageClass()),
+ builder.getPtrType(kIROp_PtrType, loadInst->getFullType(), getStorageBufferStorageClass()),
kIROp_RWStructuredBufferGetElementPtr,
2,
args);
@@ -1357,7 +1348,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
return;
auto oldResultType = as<IRPtrTypeBase>(inst->getDataType());
auto oldValueType = oldResultType->getValueType();
- auto newValueType = translateToStorageBufferPointer(oldValueType);
+ auto newValueType = oldValueType;
if (oldValueType != newValueType || oldResultType->getAddressSpace() != ptrType->getAddressSpace())
{
@@ -1381,7 +1372,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto ptrType = as<IRPtrType>(inst->getDataType());
if (!ptrType)
return;
- auto newPtrType = translateToStorageBufferPointer(ptrType);
+ auto newPtrType = ptrType;
if (newPtrType == ptrType)
return;
IRBuilder builder(inst);
@@ -1890,79 +1881,19 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
addToWorkList(branch->getOperand(0));
}
- // If type is pointer type and does not have an address space, make it a
- // storage buffer pointer.
- IRType* translateToStorageBufferPointer(IRType* type)
- {
- if (auto ptrType = as<IRPtrType>(type))
- {
- auto oldValueType = ptrType->getValueType();
- auto newValueType = translateToStorageBufferPointer(oldValueType);
- if (oldValueType != newValueType || !ptrType->hasAddressSpace())
- {
- IRBuilder builder(m_module);
- IRIntegerValue addressSpace = (ptrType->hasAddressSpace() ? ptrType->getAddressSpace() : IRIntegerValue(SpvStorageClassPhysicalStorageBuffer));
- return builder.getPtrType(ptrType->getOp(), newValueType, addressSpace);
- }
- return ptrType;
- }
- else if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
- {
- auto oldValueType = arrayTypeBase->getElementType();
- auto newValueType = translateToStorageBufferPointer(oldValueType);
- if (oldValueType != newValueType)
- {
- IRBuilder builder(m_module);
- return builder.getArrayTypeBase(arrayTypeBase->getOp(), newValueType, arrayTypeBase->getElementCount());
- }
- return arrayTypeBase;
- }
- return type;
- }
-
- void translatePtrResultType(IRInst* inst)
- {
- auto ptrType = as<IRPtrType>(inst->getDataType());
- if (!ptrType)
- {
- if (auto refType = as<IRRefType>(inst->getDataType()))
- {
- // Functions that return ref type should be treated as returning a pointer.
- IRBuilder builder(inst);
- ptrType = builder.getPtrType(refType->getValueType());
- }
- }
- auto newPtrType = translateToStorageBufferPointer(ptrType);
- if (newPtrType == ptrType)
- return;
- IRBuilder builder(inst);
- auto newInst = builder.replaceOperand(&inst->typeUse, newPtrType);
- addUsersToWorkList(newInst);
- }
-
void processPtrLit(IRInst* inst)
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
- auto newPtrType = translateToStorageBufferPointer(as<IRPtrType>(inst->getFullType()));
+ auto newPtrType = as<IRPtrType>(inst->getFullType());
auto newInst = builder.emitCastIntToPtr(newPtrType, builder.getIntValue(builder.getUInt64Type(), 0));
inst->replaceUsesWith(newInst);
addUsersToWorkList(newInst);
}
- void processPtrCast(IRInst* cast)
- {
- translatePtrResultType(cast);
- }
-
- void processLoad(IRInst* inst)
- {
- translatePtrResultType(inst);
- }
-
void processStructField(IRStructField* field)
{
- auto newFieldType = translateToStorageBufferPointer(field->getFieldType());
+ auto newFieldType = field->getFieldType();
if (newFieldType != field->getFieldType())
field->setFieldType(newFieldType);
}
@@ -2095,17 +2026,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_MakeOptionalNone:
processConstructor(inst);
break;
- case kIROp_BitCast:
- case kIROp_PtrCast:
- case kIROp_CastIntToPtr:
- processPtrCast(inst);
- break;
case kIROp_PtrLit:
processPtrLit(inst);
break;
- case kIROp_Load:
- processLoad(inst);
- break;
case kIROp_unconditionalBranch:
processBranch(inst);
break;
@@ -2261,7 +2184,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
{
// Don't assign address space to additional insts, since we should have
// already assigned address space to them in earlier stages of legalization.
- auto type = unwrapAttributedType(inst->getDataType());
+ auto type = inst->getDataType();
+ for (;;)
+ {
+ auto newType = (IRType*)unwrapAttributedType(type);
+ newType = unwrapArray(newType);
+ if (newType == type) break;
+ type = newType;
+ }
if (!type)
return AddressSpace::Generic;
return getAddressSpaceFromVarType(type);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index ba7376c46..bb05a1c29 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2813,9 +2813,9 @@ namespace Slang
return (IRInOutType*) getPtrType(kIROp_InOutType, valueType);
}
- IRRefType* IRBuilder::getRefType(IRType* valueType)
+ IRRefType* IRBuilder::getRefType(IRType* valueType, AddressSpace addrSpace)
{
- return (IRRefType*) getPtrType(kIROp_RefType, valueType);
+ return (IRRefType*) getPtrType(kIROp_RefType, valueType, addrSpace);
}
IRConstRefType* IRBuilder::getConstRefType(IRType* valueType)
@@ -2840,8 +2840,13 @@ namespace Slang
IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace)
{
- IRInst* operands[] = {valueType, getIntValue(getIntType(), addressSpace)};
- return (IRPtrType*)getType(op, 2, operands);
+ return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), addressSpace));
+ }
+
+ IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRInst* addressSpace)
+ {
+ IRInst* operands[] = { valueType, addressSpace };
+ return (IRPtrType*)getType(op, addressSpace ? 2 : 1, operands);
}
IRTextureTypeBase* IRBuilder::getTextureType(IRType* elementType, IRInst* shape, IRInst* isArray, IRInst* isMS, IRInst* sampleCount, IRInst* access, IRInst* isShadow, IRInst* isCombined, IRInst* format)
@@ -4881,11 +4886,30 @@ namespace Slang
return inst;
}
+ IRType* maybePropagateAddressSpace(IRBuilder* builder, IRInst* basePtr, IRType* type)
+ {
+ if (auto basePtrType = as<IRPtrTypeBase>(basePtr->getDataType()))
+ {
+ if (auto resultPtrType = as<IRPtrTypeBase>(type))
+ {
+ if (basePtrType->getAddressSpace() != resultPtrType->getAddressSpace())
+ {
+ type = builder->getPtrType(
+ resultPtrType->getOp(), resultPtrType->getValueType(), basePtrType->getAddressSpace());
+ }
+ }
+ }
+ return type;
+ }
+
IRInst* IRBuilder::emitFieldAddress(
IRType* type,
IRInst* base,
IRInst* field)
{
+ // Propagate pointer address space if it is available on base.
+ type = maybePropagateAddressSpace(this, base, type);
+
auto inst = createInst<IRFieldAddress>(
this,
kIROp_FieldAddress,
@@ -4982,6 +5006,9 @@ namespace Slang
IRInst* basePtr,
IRInst* index)
{
+ // Propagate pointer address space if it is available on base.
+ type = maybePropagateAddressSpace(this, basePtr, type);
+
auto inst = createInst<IRFieldAddress>(
this,
kIROp_GetElementPtr,
@@ -5004,9 +5031,20 @@ namespace Slang
IRInst* basePtr,
IRInst* index)
{
+ AddressSpace addrSpace = AddressSpace::Generic;
+ IRInst* valueType = nullptr;
+ auto basePtrType = unwrapAttributedType(basePtr->getDataType());
+ if (auto ptrType = as<IRPtrTypeBase>(basePtrType))
+ {
+ addrSpace = ptrType->getAddressSpace();
+ valueType = ptrType->getValueType();
+ }
+ else if (auto ptrLikeType = as<IRPointerLikeType>(basePtrType))
+ {
+ valueType = ptrLikeType->getElementType();
+ }
IRType* type = nullptr;
- auto basePtrType = as<IRPtrTypeBase>(basePtr->getDataType());
- auto valueType = unwrapAttributedType(basePtrType->getValueType());
+ valueType = unwrapAttributedType(valueType);
if (auto arrayType = as<IRArrayTypeBase>(valueType))
{
type = arrayType->getElementType();
@@ -5028,7 +5066,7 @@ namespace Slang
auto inst = createInst<IRGetElementPtr>(
this,
kIROp_GetElementPtr,
- getPtrType(type),
+ getPtrType(kIROp_PtrType, type, addrSpace),
basePtr,
index);
@@ -5058,7 +5096,7 @@ namespace Slang
}
}
SLANG_RELEASE_ASSERT(resultType);
- basePtr = emitFieldAddress(getPtrType(resultType), basePtr, structKey);
+ basePtr = emitFieldAddress(getPtrType(kIROp_PtrType, resultType, basePtrType->getAddressSpace()), basePtr, structKey);
}
else
{
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 9a773a891..cb8e83df8 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -39,17 +39,6 @@ struct IRModule;
struct IRStructField;
struct IRStructKey;
-enum class AddressSpace
-{
- Generic = 0x7fffffff,
- ThreadLocal = 1,
- Global = 2,
- GroupShared = 3,
- Uniform = 4,
- // specific address space for payload data in metal
- MetalObjectData = 5,
-};
-
typedef unsigned int IROpFlags;
enum : IROpFlags
{
@@ -1710,11 +1699,11 @@ struct IRPtrTypeBase : IRType
{
IRType* getValueType() { return (IRType*)getOperand(0); }
- bool hasAddressSpace() { return getOperandCount() > 1; }
+ bool hasAddressSpace() { return getOperandCount() > 1 && getAddressSpace() != AddressSpace::Generic; }
- IRIntegerValue getAddressSpace()
+ AddressSpace getAddressSpace()
{
- return getOperandCount() > 1 ? static_cast<IRIntLit*>(getOperand(1))->getValue() : -1;
+ return getOperandCount() > 1 ? (AddressSpace)static_cast<IRIntLit*>(getOperand(1))->getValue() : AddressSpace::Generic;
}
IR_PARENT_ISA(PtrTypeBase)
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index 393b34e68..66c0044b6 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -209,8 +209,8 @@ bool isPointerToResourceType(IRType* type)
{
while (auto ptrType = as<IRPtrTypeBase>(type))
{
- if (ptrType->getAddressSpace() == SpvStorageClassStorageBuffer ||
- ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBufferEXT)
+ if (ptrType->getAddressSpace() == AddressSpace(SpvStorageClassStorageBuffer) ||
+ ptrType->getAddressSpace() == AddressSpace::UserPointer)
return true;
type = ptrType->getValueType();
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d3770753c..74aa0a0ee 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1855,8 +1855,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto astValueType = type->getValueType();
IRType* irValueType = lowerType(context, astValueType);
-
- return getBuilder()->getPtrType(irValueType);
+ IRInst* addrSpace = nullptr;
+ if (auto astAddrSpace = type->getAddressSpace())
+ {
+ addrSpace = getSimpleVal(context, lowerVal(context, astAddrSpace));
+ }
+ else
+ {
+ addrSpace = getBuilder()->getIntValue(getBuilder()->getUInt64Type(), (IRIntegerValue)AddressSpace::Generic);
+ }
+ return getBuilder()->getPtrType(kIROp_PtrType, irValueType, addrSpace);
}
IRType* visitDeclRefType(DeclRefType* type)
@@ -3138,7 +3146,7 @@ void _lowerFuncDeclBaseTypeInfo(
irParamType = builder->getInOutType(irParamType);
break;
case kParameterDirection_Ref:
- irParamType = builder->getRefType(irParamType);
+ irParamType = builder->getRefType(irParamType, AddressSpace::Generic);
break;
case kParameterDirection_ConstRef:
irParamType = builder->getConstRefType(irParamType);
@@ -4972,7 +4980,6 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
case LoweredValInfo::Flavor::Ptr:
return LoweredValInfo::ptr(
builder->emitElementAddress(
- context->irBuilder->getPtrType(type),
baseVal.val,
indexVal));
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 0b92d07af..d6efb47b3 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -706,7 +706,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
auto paramType = getParamType(astBuilder, paramDeclRef);
if( paramDecl->findModifier<RefModifier>() )
{
- paramType = astBuilder->getRefType(paramType);
+ paramType = astBuilder->getRefType(paramType, AddressSpace::Generic);
}
else if (paramDecl->findModifier<ConstRefModifier>())
{
diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h
index 2f467a05a..404c84cf4 100644
--- a/source/slang/slang-type-system-shared.h
+++ b/source/slang/slang-type-system-shared.h
@@ -58,6 +58,20 @@ FOREACH_BASE_TYPE(DEFINE_BASE_TYPE)
const int kStdlibTextureIsShadowParameterIndex = 6;
const int kStdlibTextureIsCombinedParameterIndex = 7;
const int kStdlibTextureFormatParameterIndex = 8;
+
+ enum class AddressSpace : uint64_t
+ {
+ Generic = 0x7fffffff,
+ ThreadLocal = 1,
+ Global = 2,
+ GroupShared = 3,
+ Uniform = 4,
+ // specific address space for payload data in metal
+ MetalObjectData = 5,
+
+ // Default address space for a user-defined pointer
+ UserPointer = 0x100000001ULL,
+ };
}
#endif