summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--prelude/slang-cuda-prelude.h2
-rw-r--r--source/slang/hlsl.meta.slang19
-rw-r--r--source/slang/slang-emit-base.cpp7
-rw-r--r--source/slang/slang-emit-spirv-ops.h4
-rw-r--r--source/slang/slang-emit-spirv.cpp462
-rw-r--r--source/slang/slang-ir-any-value-marshalling.cpp35
-rw-r--r--source/slang/slang-ir-bind-existentials.cpp3
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp18
-rw-r--r--source/slang/slang-ir-inline.cpp4
-rw-r--r--source/slang/slang-ir-insts.h4
-rw-r--r--source/slang/slang-ir-layout.cpp5
-rw-r--r--source/slang/slang-ir-legalize-types.cpp7
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp57
-rw-r--r--source/slang/slang-ir-peephole.cpp77
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp68
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp175
-rw-r--r--source/slang/slang-ir-spirv-snippet.cpp4
-rw-r--r--source/slang/slang-ir-spirv-snippet.h5
-rw-r--r--source/slang/slang-ir-util.cpp7
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-ir-validate.cpp10
-rw-r--r--source/slang/slang-ir.cpp26
-rw-r--r--source/slang/slang-language-server-completion.cpp1
-rw-r--r--tests/expected-failure.txt42
-rw-r--r--tests/vkray/raygen.slang.glsl52
25 files changed, 893 insertions, 203 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index ad757bdbb..77ed2d51f 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -1339,7 +1339,7 @@ struct RWByteAddressBuffer
template <typename T>
SLANG_CUDA_CALL T* _getPtrAt(size_t index)
{
- SLANG_BOUND_CHECK_BYTE_ADDRESS(index, 4, sizeInBytes);
+ SLANG_BOUND_CHECK_BYTE_ADDRESS(index, sizeof(T), sizeInBytes);
return (T*)(((char*)data) + index);
}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 73c20f2d0..a3900826e 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -1048,6 +1048,9 @@ __generic<T : __BuiltinType>
__target_intrinsic(cpp, "bool($0)")
__target_intrinsic(cuda, "bool($0)")
__target_intrinsic(glsl, "bool($0)")
+__target_intrinsic(spirv, boolean(T), "OpCopyObject resultType resultId _0")
+__target_intrinsic(spirv, integral(T), "OpINotEqual resultType resultId _0 const(,0)")
+__target_intrinsic(spirv, floating(T), "OpFUnordNotEqual resultType resultId _0 const(,0)")
[__readNone]
bool any(T x);
@@ -1375,6 +1378,7 @@ matrix<uint,N,M> asuint(matrix<uint,N,M> x)
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "uint16_t(packHalf2x16(vec2($0, 0.0)))")
__target_intrinsic(cuda, "__half_as_ushort")
+__target_intrinsic(spirv, "OpBitcast resultType resultId _0")
[__readNone]
uint16_t asuint16(float16_t value);
@@ -1391,6 +1395,7 @@ matrix<uint16_t,R,C> asuint16<let R : int, let C : int>(matrix<float16_t,R,C> va
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "float16_t(unpackHalf2x16($0).x)")
__target_intrinsic(cuda, "__ushort_as_half")
+__target_intrinsic(spirv, "OpBitcast resultType resultId _0")
[__readNone]
float16_t asfloat16(uint16_t value);
@@ -1406,12 +1411,14 @@ matrix<float16_t,R,C> asfloat16<let R : int, let C : int>(matrix<uint16_t,R,C> v
__target_intrinsic(hlsl)
__target_intrinsic(cuda, "__half_as_short")
+__target_intrinsic(spirv, "OpBitcast resultType resultId _0")
[__unsafeForceInlineEarly][__readNone] int16_t asint16(float16_t value) { return asuint16(value); }
__target_intrinsic(hlsl) [__unsafeForceInlineEarly][__readNone] vector<int16_t,N> asint16<let N : int>(vector<float16_t,N> value) { return asuint16(value); }
__target_intrinsic(hlsl) [__unsafeForceInlineEarly][__readNone] matrix<int16_t,R,C> asint16<let R : int, let C : int>(matrix<float16_t,R,C> value) { return asuint16(value); }
__target_intrinsic(hlsl)
__target_intrinsic(cuda, "__short_as_half")
+__target_intrinsic(spirv, "OpBitcast resultType resultId _0")
[__readNone]
[__unsafeForceInlineEarly] float16_t asfloat16(int16_t value) { return asfloat16(asuint16(value)); }
@@ -2086,6 +2093,10 @@ __glsl_version(420)
__target_intrinsic(hlsl)
__cuda_sm_version(6.0)
__target_intrinsic(cuda, "__half2float(__ushort_as_half($0))")
+__target_intrinsic(spirv, R"(
+ %lowBits = OpUConvert _type(uint16_t) resultId _0;
+ %half = OpBitcast _type(half) resultId %lowBits;
+ OpFConvert resultType resultId %half)")
[__readNone]
float f16tof32(uint value);
@@ -2105,6 +2116,10 @@ __glsl_version(420)
__target_intrinsic(hlsl)
__cuda_sm_version(6.0)
__target_intrinsic(cuda, "__half_as_ushort(__float2half($0))")
+__target_intrinsic(spirv, R"(
+ %half = OpFConvert _type(half) resultId _0;
+ %lowBits = OpBitcast _type(uint16_t) resultId %half;
+ OpUConvert resultType resultId %lowBits)")
[__readNone]
uint f32tof16(float value);
@@ -2123,6 +2138,7 @@ vector<uint, N> f32tof16(vector<float, N> value)
__target_intrinsic(glsl, "unpackHalf2x16($0).x")
__target_intrinsic(cuda, "__half2float")
+__target_intrinsic(spirv, "OpFConvert resultType resultId _0")
__glsl_version(420)
[__readNone]
float f16tof32(float16_t value);
@@ -2130,6 +2146,7 @@ float f16tof32(float16_t value);
__generic<let N : int>
__target_intrinsic(hlsl)
__target_intrinsic(cuda, "__half2float")
+__target_intrinsic(spirv, "OpFConvert resultType resultId _0")
[__readNone]
vector<float, N> f16tof32(vector<float16_t, N> value)
{
@@ -2140,11 +2157,13 @@ vector<float, N> f16tof32(vector<float16_t, N> value)
__target_intrinsic(glsl, "packHalf2x16(vec2($0,0.0))")
__glsl_version(420)
__target_intrinsic(cuda, "__float2half")
+__target_intrinsic(spirv, "OpFConvert resultType resultId _0")
[__readNone]
float16_t f32tof16_(float value);
__generic<let N : int>
__target_intrinsic(cuda, "__float2half")
+__target_intrinsic(spirv, "OpFConvert resultType resultId _0")
[__readNone]
vector<float16_t, N> f32tof16_(vector<float, N> value)
{
diff --git a/source/slang/slang-emit-base.cpp b/source/slang/slang-emit-base.cpp
index d565edb24..7ee3ea9ca 100644
--- a/source/slang/slang-emit-base.cpp
+++ b/source/slang/slang-emit-base.cpp
@@ -1,4 +1,5 @@
#include "slang-emit-base.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -45,11 +46,7 @@ void SourceEmitterBase::handleRequiredCapabilities(IRInst* inst)
IRVarLayout* SourceEmitterBase::getVarLayout(IRInst* var)
{
- auto decoration = var->findDecoration<IRLayoutDecoration>();
- if (!decoration)
- return nullptr;
-
- return as<IRVarLayout>(decoration->getLayout());
+ return findVarLayout(var);
}
BaseType SourceEmitterBase::extractBaseType(IRType* inType)
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 60431b76c..57c074751 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -1012,14 +1012,14 @@ SpvInst* emitOpMemberDecorateUserSemantic(
}
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpVectorShuffle
-template<typename T1, typename T2, typename T3, Index N>
+template<typename T1, typename T2, typename T3>
SpvInst* emitOpVectorShuffle(
SpvInstParent* parent,
IRInst* inst,
const T1& idResultType,
const T2& vector1,
const T3& vector2,
- const Array<SpvLiteralInteger, N>& components
+ ArrayView<SpvLiteralInteger> components
)
{
static_assert(isSingular<T1>);
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 3febbd210..0878b3494 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1168,16 +1168,22 @@ struct SPIRVEmitContext
// > OpTypeInt
- case kIROp_UInt8Type:
case kIROp_UInt16Type:
+ case kIROp_Int16Type:
+ case kIROp_UInt8Type:
case kIROp_UIntType:
case kIROp_UInt64Type:
case kIROp_Int8Type:
- case kIROp_Int16Type:
case kIROp_IntType:
case kIROp_Int64Type:
{
const IntInfo i = getIntTypeInfo(as<IRType>(inst));
+ if (i.width == 16)
+ requireSPIRVCapability(SpvCapabilityInt16);
+ else if (i.width == 64)
+ requireSPIRVCapability(SpvCapabilityInt64);
+ else if (i.width == 8)
+ requireSPIRVCapability(SpvCapabilityInt8);
return emitOpTypeInt(
inst,
SpvLiteralInteger::from32(int32_t(i.width)),
@@ -1395,6 +1401,8 @@ struct SPIRVEmitContext
return emitGlobalParam(as<IRGlobalParam>(inst));
case kIROp_GlobalVar:
return emitGlobalVar(as<IRGlobalVar>(inst));
+ case kIROp_Var:
+ return emitVar(getSection(SpvLogicalSectionID::GlobalVariables), inst);
// ...
case kIROp_Specialize:
@@ -1436,6 +1444,8 @@ struct SPIRVEmitContext
}
return result;
}
+ case kIROp_GetStringHash:
+ return emitGetStringHash(inst);
default:
{
String e = "Unhandled global inst in spirv-emit:\n"
@@ -1468,6 +1478,8 @@ struct SPIRVEmitContext
void emitVarLayout(SpvInst* varInst, IRVarLayout* layout)
{
+ bool needDefaultSetBindingDecoration = false;
+ bool hasExplicitSetBinding = false;
for (auto rr : layout->getOffsetAttrs())
{
UInt index = rr->getOffset();
@@ -1498,15 +1510,6 @@ struct SPIRVEmitContext
varInst,
SpvLiteralInteger::from32(int32_t(index))
);
- if (space)
- {
- emitOpDecorateIndex(
- getSection(SpvLogicalSectionID::Annotations),
- nullptr,
- varInst,
- SpvLiteralInteger::from32(int32_t(space))
- );
- }
break;
case LayoutResourceKind::SpecializationConstant:
@@ -1527,24 +1530,44 @@ struct SPIRVEmitContext
getSection(SpvLogicalSectionID::Annotations),
nullptr,
varInst,
- SpvLiteralInteger::from32(int32_t(index))
- );
+ SpvLiteralInteger::from32(int32_t(index)));
+ if (space)
+ {
+ emitOpDecorateDescriptorSet(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ varInst,
+ SpvLiteralInteger::from32(int32_t(space)));
+ }
+ else
+ {
+ needDefaultSetBindingDecoration = true;
+ }
+ break;
+ case LayoutResourceKind::RegisterSpace:
emitOpDecorateDescriptorSet(
getSection(SpvLogicalSectionID::Annotations),
nullptr,
varInst,
- SpvLiteralInteger::from32(int32_t(space))
- );
+ SpvLiteralInteger::from32(int32_t(index)));
+ hasExplicitSetBinding = true;
break;
default:
break;
}
}
+ if (needDefaultSetBindingDecoration && !hasExplicitSetBinding)
+ {
+ emitOpDecorateDescriptorSet(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ varInst,
+ SpvLiteralInteger::from32(int32_t(0)));
+ }
}
/// Emit a global parameter definition.
SpvInst* emitGlobalParam(IRGlobalParam* param)
{
- auto layout = getVarLayout(param);
auto storageClass = SpvStorageClassUniform;
if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
{
@@ -1562,7 +1585,8 @@ struct SPIRVEmitContext
param->getDataType(),
storageClass
);
- emitVarLayout(varInst, layout);
+ if (auto layout = getVarLayout(param))
+ emitVarLayout(varInst, layout);
return varInst;
}
@@ -1615,6 +1639,8 @@ struct SPIRVEmitContext
/// Emit a declaration for the given `irFunc`
SpvInst* emitFuncDeclaration(IRFunc* irFunc)
{
+ if (irFunc->findDecorationImpl(kIROp_SPIRVOpDecoration))
+ return nullptr;
// For now we aren't handling function declarations;
// we expect to deal only with fully linked modules.
//
@@ -1852,6 +1878,10 @@ struct SPIRVEmitContext
return emitLoad(parent, as<IRLoad>(inst));
case kIROp_Store:
return emitStore(parent, as<IRStore>(inst));
+ case kIROp_SwizzledStore:
+ return emitSwizzledStore(parent, as<IRSwizzledStore>(inst));
+ case kIROp_swizzleSet:
+ return emitSwizzleSet(parent, as<IRSwizzleSet>(inst));
case kIROp_RWStructuredBufferGetElementPtr:
return emitStructuredBufferGetElementPtr(parent, inst);
case kIROp_StructuredBufferGetDimensions:
@@ -1866,10 +1896,6 @@ struct SPIRVEmitContext
return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst));
case kIROp_CastFloatToInt:
return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst));
- case kIROp_MatrixReshape:
- case kIROp_VectorReshape:
- // TODO: break emitConstruct into separate functions for each opcode.
- return emitConstruct(parent, inst);
case kIROp_BitCast:
return emitOpBitcast(
parent,
@@ -1991,6 +2017,28 @@ struct SPIRVEmitContext
return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst));
case kIROp_DebugLine:
return emitDebugLine(parent, as<IRDebugLine>(inst));
+ case kIROp_GetStringHash:
+ return emitGetStringHash(inst);
+
+ }
+ }
+
+ SpvInst* emitGetStringHash(IRInst* inst)
+ {
+ auto getStringHashInst = as<IRGetStringHash>(inst);
+ auto stringLit = getStringHashInst->getStringLit();
+
+ if (stringLit)
+ {
+ auto slice = stringLit->getStringSlice();
+ return emitIntConstant(getStableHashCode32(slice.begin(), slice.getLength()).hash, inst->getDataType());
+ }
+ else
+ {
+ // Couldn't handle
+ String e = "Unhandled local inst in spirv-emit:\n"
+ + dumpIRToString(inst, { IRDumpOptions::Mode::Detailed, 0 });
+ SLANG_UNIMPLEMENTED_X(e.getBuffer());
}
}
@@ -2236,7 +2284,14 @@ struct SPIRVEmitContext
for (auto field : structType->getFields())
{
IRIntegerValue offset = 0;
- getOffset(IRTypeLayoutRules::get(layoutRuleName), field, &offset);
+ if (auto offsetDecor = field->getKey()->findDecoration<IRPackOffsetDecoration>())
+ {
+ offset = (getIntVal(offsetDecor->getRegisterOffset()) * 4 + getIntVal(offsetDecor->getComponentOffset())) * 4;
+ }
+ else
+ {
+ getOffset(IRTypeLayoutRules::get(layoutRuleName), field, &offset);
+ }
emitOpMemberDecorateOffset(
getSection(SpvLogicalSectionID::Annotations),
nullptr,
@@ -2368,14 +2423,168 @@ struct SPIRVEmitContext
{
String semanticName = systemValueAttr->getName();
semanticName = semanticName.toLower();
- if (semanticName == "sv_dispatchthreadid")
+ if (semanticName == "sv_position")
+ {
+ auto importDecor = inst->findDecoration<IRImportDecoration>();
+ if (importDecor->getMangledName() == "gl_FragCoord")
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragCoord);
+ else
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPosition);
+ }
+ else if (semanticName == "sv_target")
+ {
+ // Note: we do *not* need to generate some kind of `gl_`
+ // builtin for fragment-shader outputs: they are just
+ // ordinary `out` variables, with ordinary `location`s,
+ // as far as GLSL is concerned.
+ return nullptr;
+ }
+ else if (semanticName == "sv_clipdistance")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInClipDistance);
+ }
+ else if (semanticName == "sv_culldistance")
+ {
+ requireSPIRVCapability(SpvCapabilityCullDistance);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullDistance);
+ }
+ else if (semanticName == "sv_coverage")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleMask);
+ }
+ else if (semanticName == "sv_innercoverage")
+ {
+ requireSPIRVCapability(SpvCapabilityFragmentFullyCoveredEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFullyCoveredEXT);
+ }
+ else if (semanticName == "sv_depth")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ }
+ else if (semanticName == "sv_depthgreaterequal")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ }
+ else if (semanticName == "sv_depthlessequal")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ }
+ else if (semanticName == "sv_dispatchthreadid")
{
return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId);
}
+ else if (semanticName == "sv_domainlocation")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessCoord);
+ }
+ else if (semanticName == "sv_groupid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInWorkgroupId);
+ }
else if (semanticName == "sv_groupindex")
{
return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationIndex);
}
+ else if (semanticName == "sv_groupthreadid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationId);
+ }
+ else if (semanticName == "sv_gsinstanceid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId);
+ }
+ else if (semanticName == "sv_instanceid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInstanceIndex);
+ }
+ else if (semanticName == "sv_isfrontface")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFrontFacing);
+ }
+ else if (semanticName == "sv_outputcontrolpointid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId);
+ }
+ else if (semanticName == "sv_pointsize")
+ {
+ // float in hlsl & glsl
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPointSize);
+ }
+ else if (semanticName == "sv_primitiveid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveId);
+ }
+ else if (semanticName == "sv_rendertargetarrayindex")
+ {
+ requireSPIRVCapability(SpvCapabilityShaderLayer);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLayer);
+ }
+ else if (semanticName == "sv_sampleindex")
+ {
+ requireSPIRVCapability(SpvCapabilitySampleRateShading);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleId);
+ }
+ else if (semanticName == "sv_stencilref")
+ {
+ requireSPIRVCapability(SpvCapabilityStencilExportEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragStencilRefEXT);
+ }
+ else if (semanticName == "sv_tessfactor")
+ {
+ requireSPIRVCapability(SpvCapabilityTessellation);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessLevelOuter);
+ }
+ else if (semanticName == "sv_vertexid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInVertexId);
+ }
+ else if (semanticName == "sv_viewid")
+ {
+ requireSPIRVCapability(SpvCapabilityMultiView);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewIndex);
+ }
+ else if (semanticName == "sv_viewportarrayindex")
+ {
+ requireSPIRVCapability(SpvCapabilityShaderViewportIndex);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportIndex);
+ }
+ else if (semanticName == "nv_x_right")
+ {
+ SLANG_UNIMPLEMENTED_X("spirv emit for nv_x_right");
+ }
+ else if (semanticName == "nv_viewport_mask")
+ {
+ requireSPIRVCapability(SpvCapabilityPerViewAttributesNV);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportMaskPerViewNV);
+ }
+ else if (semanticName == "sv_barycentrics")
+ {
+ if (m_targetRequest->getTargetCaps().implies(CapabilityAtom::GL_NV_fragment_shader_barycentric))
+ {
+ requireSPIRVCapability(SpvCapabilityFragmentBarycentricNV);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordNV);
+ }
+ else
+ {
+ requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR);
+ }
+
+ // TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which
+ // we ought to use if the `noperspective` modifier has been
+ // applied to this varying input.
+ }
+ else if (semanticName == "sv_cullprimitive")
+ {
+ requireSPIRVCapability(SpvCapabilityMeshShadingEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullPrimitiveEXT);
+ }
+ else if (semanticName == "sv_shadingrate")
+ {
+ requireSPIRVCapability(SpvCapabilityFragmentShadingRateKHR);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveShadingRateKHR);
+ }
+ SLANG_UNREACHABLE("Unimplemented system value in spirv emit.");
}
}
return nullptr;
@@ -2571,6 +2780,14 @@ struct SPIRVEmitContext
{
return emitIntrinsicCallExpr(parent, static_cast<IRCall*>(inst), targetIntrinsic);
}
+ else if (auto spvOpDecor = funcValue->findDecorationImpl(kIROp_SPIRVOpDecoration))
+ {
+ SpvOp op = (SpvOp)getIntVal(spvOpDecor->getOperand(0));
+ List<IRInst*> args;
+ for (UInt i = 0; i < inst->getArgCount(); i++)
+ args.add(inst->getArg(i));
+ return emitInst(parent, inst, op, inst->getFullType(), kResultID, args);
+ }
else
{
return emitOpFunctionCall(
@@ -2658,6 +2875,9 @@ struct SPIRVEmitContext
case SpvSnippet::ASMType::Int:
result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getIntType());
break;
+ case SpvSnippet::ASMType::UInt16:
+ result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getType(kIROp_UInt16Type));
+ break;
case SpvSnippet::ASMType::UInt2:
{
auto uintType = builder.getType(kIROp_UIntType);
@@ -2686,12 +2906,18 @@ struct SPIRVEmitContext
case SpvSnippet::ASMType::Float:
irType = builder.getType(kIROp_FloatType);
break;
+ case SpvSnippet::ASMType::Half:
+ irType = builder.getType(kIROp_HalfType);
+ break;
case SpvSnippet::ASMType::Int:
irType = builder.getIntType();
break;
case SpvSnippet::ASMType::UInt:
irType = builder.getUIntType();
break;
+ case SpvSnippet::ASMType::UInt16:
+ irType = builder.getType(kIROp_UInt16Type);
+ break;
case SpvSnippet::ASMType::Float2:
irType = builder.getVectorType(
builder.getType(kIROp_FloatType), builder.getIntValue(builder.getIntType(), 2));
@@ -2969,6 +3195,46 @@ struct SPIRVEmitContext
return emitOpStore(parent, inst, inst->getPtr(), inst->getVal());
}
+ SpvInst* emitSwizzledStore(SpvInstParent* parent, IRSwizzledStore* inst)
+ {
+ auto sourceVectorType = as<IRVectorType>(inst->getSource()->getDataType());
+ SLANG_ASSERT(sourceVectorType);
+ auto sourceElementType = sourceVectorType->getElementType();
+ SLANG_ASSERT(getIntVal(sourceVectorType->getElementCount()) == (IRIntegerValue)inst->getElementCount());
+ SpvInst* result = nullptr;
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto destPtrType = as<IRPtrTypeBase>(inst->getDest()->getDataType());
+ SpvStorageClass addrSpace = SpvStorageClassFunction;
+ if (destPtrType->hasAddressSpace())
+ addrSpace = (SpvStorageClass)destPtrType->getAddressSpace();
+ auto ptrElementType = builder.getPtrType(kIROp_PtrType, sourceElementType, addrSpace);
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto index = inst->getElementIndex(i);
+ auto addr = emitOpAccessChain(parent, nullptr, ptrElementType, inst->getDest(), makeArray(index));
+ auto val = emitOpCompositeExtract(parent, nullptr, sourceElementType, inst->getSource(), makeArray(SpvLiteralInteger::from32((int32_t)i)));
+ result = emitOpStore(parent, (i == inst->getElementCount() - 1 ? inst : nullptr), addr, val);
+ }
+ return result;
+ }
+
+ SpvInst* emitSwizzleSet(SpvInstParent* parent, IRSwizzleSet* inst)
+ {
+ auto resultVectorType = as<IRVectorType>(inst->getDataType());
+ List<SpvLiteralInteger> shuffleIndices;
+ shuffleIndices.setCount((Index)getIntVal(resultVectorType->getElementCount()));
+ for (Index i = 0; i < shuffleIndices.getCount(); i++)
+ shuffleIndices[i] = SpvLiteralInteger::from32((int32_t)i);
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto destIndex = (int32_t)getIntVal(inst->getElementIndex(i));
+ SLANG_ASSERT(destIndex < shuffleIndices.getCount());
+ shuffleIndices[destIndex] = SpvLiteralInteger::from32((int32_t)(i + shuffleIndices.getCount()));
+ }
+ return emitOpVectorShuffle(parent, inst, inst->getFullType(), inst->getBase(), inst->getSource(), shuffleIndices.getArrayView());
+ }
+
SpvInst* emitStructuredBufferGetElementPtr(SpvInstParent* parent, IRInst* inst)
{
//"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1;"
@@ -3035,6 +3301,40 @@ struct SPIRVEmitContext
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
const auto fromType = dropVector(fromTypeV);
const auto toType = dropVector(toTypeV);
+
+ if (as<IRBoolType>(fromType))
+ {
+ // Cast from bool to int.
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto zero = builder.getIntValue(toType, 0);
+ auto one = builder.getIntValue(toType, 1);
+ if (auto vecType = as<IRVectorType>(toTypeV))
+ {
+ auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount()));
+ auto oneV = emitSplat(parent, nullptr, one, getIntVal(vecType->getElementCount()));
+ return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, inst->getOperand(0),
+ oneV, zeroV);
+ }
+ return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, inst->getOperand(0), one, zero);
+ }
+ else if (as<IRBoolType>(toType))
+ {
+ // Cast from int to bool.
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto zero = builder.getIntValue(fromType, 0);
+ if (auto vecType = as<IRVectorType>(toTypeV))
+ {
+ auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount()));
+ return emitOpINotEqual(parent, inst, inst->getFullType(), inst->getOperand(0), zeroV);
+ }
+ else
+ {
+ return emitOpINotEqual(parent, inst, inst->getFullType(), inst->getOperand(0), zero);
+ }
+ }
+
SLANG_ASSERT(isIntegralType(fromType));
SLANG_ASSERT(isIntegralType(toType));
@@ -3100,6 +3400,24 @@ struct SPIRVEmitContext
const auto fromType = dropVector(fromTypeV);
const auto toType = dropVector(toTypeV);
SLANG_ASSERT(isFloatingType(fromType));
+
+ if (as<IRBoolType>(toType))
+ {
+ // Float to bool cast.
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto zero = builder.getIntValue(fromType, 0);
+ if (auto vecType = as<IRVectorType>(toTypeV))
+ {
+ auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount()));
+ return emitInst(parent, inst, SpvOpFUnordNotEqual, inst->getFullType(), kResultID, inst->getOperand(0), zeroV);
+ }
+ else
+ {
+ return emitInst(parent, inst, SpvOpFUnordNotEqual, inst->getFullType(), kResultID, inst->getOperand(0), zero);
+ }
+ }
+
SLANG_ASSERT(isIntegralType(toType));
const auto toInfo = getIntTypeInfo(toType);
@@ -3279,14 +3597,9 @@ struct SPIRVEmitContext
}
}
- SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
+ SpvInst* emitVectorOrScalarArithmetic(SpvInstParent* parent, IRInst* instToRegister, IRInst* type, IROp op, UInt operandCount, ArrayView<IRInst*> operands)
{
- IRType* elementType = dropVector(inst->getOperand(0)->getDataType());
- if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
- {
- //TODO: implement.
- SLANG_ASSERT(!"unimplemented: matrix arithemetic");
- }
+ IRType* elementType = dropVector(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
bool isFloatingPoint = false;
bool isBool = false;
@@ -3305,7 +3618,7 @@ struct SPIRVEmitContext
}
SpvOp opCode = SpvOpUndef;
bool isSigned = isSignedType(basicType);
- switch (inst->getOp())
+ switch (op)
{
case kIROp_Add:
opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
@@ -3323,30 +3636,30 @@ struct SPIRVEmitContext
opCode = isSigned ? SpvOpSRem : SpvOpUMod;
break;
case kIROp_FRem:
- opCode = SpvOpFRem;
+ opCode = SpvOpFMod;
break;
case kIROp_Less:
opCode = isFloatingPoint ? SpvOpFOrdLessThan
- : isSigned ? SpvOpSLessThan : SpvOpULessThan;
+ : isSigned ? SpvOpSLessThan : SpvOpULessThan;
break;
case kIROp_Leq:
opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
- : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual;
+ : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual;
break;
case kIROp_Eql:
opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
break;
case kIROp_Neq:
opCode = isFloatingPoint ? SpvOpFOrdNotEqual
- : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual;
+ : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual;
break;
case kIROp_Geq:
opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
- : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual;
+ : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual;
break;
case kIROp_Greater:
opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
- : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan;
+ : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan;
break;
case kIROp_Neg:
opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
@@ -3361,16 +3674,28 @@ struct SPIRVEmitContext
opCode = SpvOpLogicalNot;
break;
case kIROp_BitAnd:
- opCode = SpvOpBitwiseAnd;
+ if (isBool)
+ opCode = SpvOpLogicalAnd;
+ else
+ opCode = SpvOpBitwiseAnd;
break;
case kIROp_BitOr:
- opCode = SpvOpBitwiseOr;
+ if (isBool)
+ opCode = SpvOpLogicalOr;
+ else
+ opCode = SpvOpBitwiseOr;
break;
case kIROp_BitXor:
- opCode = SpvOpBitwiseXor;
+ if (isBool)
+ opCode = SpvOpLogicalNotEqual;
+ else
+ opCode = SpvOpBitwiseXor;
break;
case kIROp_BitNot:
- opCode = SpvOpBitReverse;
+ if (isBool)
+ opCode = SpvOpLogicalNot;
+ else
+ opCode = SpvOpBitReverse;
break;
case kIROp_Rsh:
opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
@@ -3382,20 +3707,20 @@ struct SPIRVEmitContext
SLANG_ASSERT(!"unknown arithmetic opcode");
break;
}
- if(inst->getOperandCount() == 1)
+ if (operandCount == 1)
{
- return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst));
+ return emitInst(parent, instToRegister, opCode, type, kResultID, operands);
}
- else if(inst->getOperandCount() == 2)
+ else if (operandCount == 2)
{
- auto l = inst->getOperand(0);
+ auto l = operands[0];
const auto lVec = as<IRVectorType>(l->getDataType());
- auto r = inst->getOperand(1);
+ auto r = operands[1];
const auto rVec = as<IRVectorType>(r->getDataType());
- const auto go = [&](const auto l, const auto r){
- return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, l, r);
+ const auto go = [&](const auto l, const auto r) {
+ return emitInst(parent, instToRegister, opCode, type, kResultID, l, r);
};
- if(lVec && !rVec)
+ if (lVec && !rVec)
{
const auto len = as<IRIntLit>(lVec->getElementCount());
SLANG_ASSERT(len);
@@ -3412,6 +3737,45 @@ struct SPIRVEmitContext
SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands");
}
+
+ SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
+ {
+ if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
+ {
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
+ List<SpvInst*> rows;
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ List<IRInst*> operands;
+ for (UInt j = 0; j < inst->getOperandCount(); j++)
+ {
+ auto originalOperand = inst->getOperand(j);
+ if (as<IRMatrixType>(originalOperand->getDataType()))
+ {
+ auto operand = builder.emitElementExtract(originalOperand, i);
+ emitLocalInst(parent, operand);
+ operands.add(operand);
+ }
+ else
+ {
+ operands.add(originalOperand);
+ }
+ }
+ rows.add(emitVectorOrScalarArithmetic(parent, nullptr, rowVectorType, inst->getOp(), inst->getOperandCount(), operands.getArrayView()));
+ }
+ return emitCompositeConstruct(parent, inst, inst->getDataType(), rows);
+ }
+
+ Array<IRInst*, 4> operands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ operands.add(inst->getOperand(i));
+ return emitVectorOrScalarArithmetic(parent, inst, inst->getDataType(), inst->getOp(), inst->getOperandCount(), operands.getView());
+ }
+
SpvInst* emitDebugLine(SpvInstParent* parent, IRDebugLine* debugLine)
{
auto scope = findDebugScope(debugLine);
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index ed7818dbf..200f8bc55 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -241,7 +241,6 @@ namespace Slang
{
case kIROp_IntType:
case kIROp_FloatType:
- case kIROp_BoolType:
#if SLANG_PTR_IS_32
case kIROp_IntPtrType:
#endif
@@ -260,6 +259,23 @@ namespace Slang
advanceOffset(4);
break;
}
+ case kIROp_BoolType:
+ {
+ ensureOffsetAt4ByteBoundary();
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
+ {
+ auto srcVal = builder->emitLoad(concreteVar);
+ IRInst* args[] = {srcVal, builder->getIntValue(builder->getUIntType(), 1), builder->getIntValue(builder->getUIntType(), 0) };
+ auto dstVal = builder->emitIntrinsicInst(builder->getUIntType(), kIROp_Select, 3, args);
+ auto dstAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ builder->emitStore(dstAddr, dstVal);
+ }
+ advanceOffset(4);
+ break;
+ }
case kIROp_UIntType:
#if SLANG_PTR_IS_32
case kIROp_UIntPtrType:
@@ -416,7 +432,6 @@ namespace Slang
{
case kIROp_IntType:
case kIROp_FloatType:
- case kIROp_BoolType:
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
@@ -432,6 +447,22 @@ namespace Slang
advanceOffset(4);
break;
}
+ case kIROp_BoolType:
+ {
+ ensureOffsetAt4ByteBoundary();
+ if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
+ {
+ auto srcAddr = builder->emitFieldAddress(
+ uintPtrType,
+ anyValueVar,
+ anyValInfo->fieldKeys[fieldOffset]);
+ auto srcVal = builder->emitLoad(srcAddr);
+ srcVal = builder->emitNeq(srcVal, builder->getIntValue(builder->getUIntType(), 0));
+ builder->emitStore(concreteVar, srcVal);
+ }
+ advanceOffset(4);
+ break;
+ }
case kIROp_UIntType:
{
ensureOffsetAt4ByteBoundary();
diff --git a/source/slang/slang-ir-bind-existentials.cpp b/source/slang/slang-ir-bind-existentials.cpp
index 35c07452a..f4fbf6e20 100644
--- a/source/slang/slang-ir-bind-existentials.cpp
+++ b/source/slang/slang-ir-bind-existentials.cpp
@@ -351,8 +351,7 @@ struct BindExistentialSlots
inst,
slotOperandCount,
slotOperands.getBuffer());
-
- use->set(newVal);
+ builder.replaceOperand(use, newVal);
}
}
};
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 5c4e1d037..4c5b58882 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -1488,8 +1488,24 @@ ScalarizedVal adaptType(
IRBuilder* builder,
IRInst* val,
IRType* toType,
- IRType* /*fromType*/)
+ IRType* fromType)
{
+ if (auto fromVector = as<IRVectorType>(fromType))
+ {
+ if (auto toVector = as<IRVectorType>(toType))
+ {
+ if (fromVector->getElementCount() != toVector->getElementCount())
+ {
+ fromType = builder->getVectorType(fromVector->getElementType(), toVector->getElementCount());
+ val = builder->emitVectorReshape(fromType, val);
+ }
+ }
+ else if (auto toBasicType = as<IRBasicType>(toType))
+ {
+ UInt index = 0;
+ val = builder->emitSwizzle(fromVector->getElementType(), val, 1, &index);
+ }
+ }
// TODO: actually consider what needs to go on here...
return ScalarizedVal::value(builder->emitCast(
toType,
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index e878c56f7..df245f555 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -84,6 +84,10 @@ struct InliningPassBase
{
changed = considerCallSiteInFunc(func);
}
+ else if (auto call = as<IRCall>(inst))
+ {
+ considerCallSite(call);
+ }
// Recursively consider the children of inst.
for (auto child : inst->getModifiableChildren())
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 22d8c5538..0b2cc790d 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -99,7 +99,9 @@ struct IRTargetSpecificDecoration : IRDecoration
IRType* getTypeScrutinee()
{
SLANG_ASSERT(getOperandCount() == 4);
- const auto t = as<IRType>(getOperand(3));
+ // Note: cannot use as<IRType> here because the operand can be
+ // an `IRParam` representing a generic type.
+ const auto t = (IRType*)(getOperand(3));
SLANG_ASSERT(t);
return t;
}
diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp
index 244ff2039..cba6894a9 100644
--- a/source/slang/slang-ir-layout.cpp
+++ b/source/slang/slang-ir-layout.cpp
@@ -404,9 +404,10 @@ struct Std430LayoutRules : IRTypeLayoutRules
}
virtual IRSizeAndAlignment getVectorSizeAndAlignment(IRSizeAndAlignment element, IRIntegerValue count)
{
+ IRIntegerValue countForAlignment = count;
if (count == 3)
- count = 4;
- return IRSizeAndAlignment((int)(element.size * count), (int)(element.size * count));
+ countForAlignment = 4;
+ return IRSizeAndAlignment((int)(element.size * count), (int)(element.size * countForAlignment));
}
};
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index c2d7ba95f..cb1cf3db3 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -1954,13 +1954,6 @@ static LegalVal legalizeInst(
}
}
-IRVarLayout* findVarLayout(IRInst* value)
-{
- if (auto layoutDecoration = value->findDecoration<IRLayoutDecoration>())
- return as<IRVarLayout>(layoutDecoration->getLayout());
- return nullptr;
-}
-
static UnownedStringSlice findNameHint(IRInst* inst)
{
if( auto nameHintDecoration = inst->findDecoration<IRNameHintDecoration>() )
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 909ffea83..bf87d72fe 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -224,10 +224,18 @@ namespace Slang
if (auto matrixType = as<IRMatrixType>(type))
{
- if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout)
+ // For spirv, we always want to lower all matrix types, because matrix types
+ // are considered abstract types.
+ if (!target->shouldEmitSPIRVDirectly())
{
- info.loweredType = type;
- return info;
+ // For other targets, we only lower the matrix types if they differ from the default
+ // matrix layout.
+ if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
+ rules->ruleName == IRTypeLayoutRuleName::Natural)
+ {
+ info.loweredType = type;
+ return info;
+ }
}
auto loweredType = builder.createStructType();
@@ -264,7 +272,7 @@ namespace Slang
else if (auto arrayType = as<IRArrayType>(type))
{
auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules);
- if (!loweredInnerTypeInfo.convertLoweredToOriginal && rules->ruleName == IRTypeLayoutRuleName::Natural)
+ if (!loweredInnerTypeInfo.convertLoweredToOriginal)
{
info.loweredType = type;
return info;
@@ -378,6 +386,45 @@ namespace Slang
return info;
}
+ switch (target->getTarget())
+ {
+ case CodeGenTarget::SPIRV:
+ case CodeGenTarget::SPIRVAssembly:
+ if (as<IRBoolType>(type))
+ {
+ // Bool is an abstract type in SPIRV, so we need to lower them into an int.
+ info.loweredType = builder.getIntType();
+ // Create unpack func.
+ {
+ builder.setInsertAfter(type);
+ info.convertLoweredToOriginal = builder.createFunc();
+ builder.setInsertInto(info.convertLoweredToOriginal);
+ builder.addNameHintDecoration(info.convertLoweredToOriginal, UnownedStringSlice("unpackStorage"));
+ info.convertLoweredToOriginal->setFullType(builder.getFuncType(1, (IRType**)&info.loweredType, type));
+ builder.emitBlock();
+ auto loweredParam = builder.emitParam(info.loweredType);
+ auto result = builder.emitCast(type, loweredParam);
+ builder.emitReturn(result);
+ }
+
+ // Create pack func.
+ {
+ builder.setInsertAfter(info.convertLoweredToOriginal);
+ info.convertOriginalToLowered = builder.createFunc();
+ builder.setInsertInto(info.convertOriginalToLowered);
+ builder.addNameHintDecoration(info.convertOriginalToLowered, UnownedStringSlice("packStorage"));
+ info.convertOriginalToLowered->setFullType(builder.getFuncType(1, (IRType**)&type, info.loweredType));
+ builder.emitBlock();
+ auto param = builder.emitParam(type);
+ auto result = builder.emitCast(info.loweredType, param);
+ builder.emitReturn(result);
+ }
+ }
+ break;
+ default:
+ break;
+ }
+
info.loweredType = type;
return info;
}
@@ -506,7 +553,7 @@ namespace Slang
builder.setInsertBefore(user);
auto newLoad = cloneInst(&cloneEnv, &builder, user);
newLoad->setFullType(loweredElementTypeInfo.loweredType);
- auto unpackedVal = builder.emitCallInst(elementType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad);
+ auto unpackedVal = builder.emitCallInst((IRType*)originalElementType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad);
user->replaceUsesWith(unpackedVal);
user->removeAndDeallocate();
break;
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 600361b2f..d83c6dccd 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -683,6 +683,83 @@ struct PeepholeContext : InstPassBase
}
}
break;
+ case kIROp_VectorReshape:
+ {
+ auto fromType = as<IRVectorType>(inst->getOperand(0)->getDataType());
+ auto resultType = as<IRVectorType>(inst->getDataType());
+ if (!resultType)
+ {
+ if (!fromType)
+ {
+ inst->replaceUsesWith(inst->getOperand(0));
+ maybeRemoveOldInst(inst);
+ changed = true;
+ break;
+ }
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ UInt index = 0;
+ auto newInst = builder.emitSwizzle(resultType, inst->getOperand(0), 1, &index);
+ inst->replaceUsesWith(newInst);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ break;
+ }
+ auto fromCount = as<IRIntLit>(fromType->getElementCount());
+ if (!fromCount)
+ break;
+ auto toCount = as<IRIntLit>(resultType->getElementCount());
+ if (!toCount)
+ break;
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newInst = builder.emitVectorReshape(resultType, inst->getOperand(0));
+ if (newInst != inst)
+ {
+ inst->replaceUsesWith(newInst);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ }
+ break;
+ case kIROp_MatrixReshape:
+ {
+ auto fromType = as<IRMatrixType>(inst->getOperand(0)->getDataType());
+ auto resultType = as<IRMatrixType>(inst->getDataType());
+ SLANG_ASSERT(fromType && resultType);
+ auto fromRows = as<IRIntLit>(fromType->getRowCount());
+ if (!fromRows) break;
+ auto fromCols = as<IRIntLit>(fromType->getColumnCount());
+ if (!fromCols) break;
+ auto toRows = as<IRIntLit>(resultType->getRowCount());
+ if (!toRows) break;
+ auto toCols = as<IRIntLit>(resultType->getColumnCount());
+ if (!toCols) break;
+ List<IRInst*> rows;
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto toRowType = builder.getVectorType(resultType->getElementType(), resultType->getColumnCount());
+ for (IRIntegerValue i = 0; i < toRows->getValue(); i++)
+ {
+ if (i < fromRows->getValue())
+ {
+ auto originalRow = builder.emitElementExtract(inst->getOperand(0), i);
+ auto resizedRow = builder.emitVectorReshape(toRowType, originalRow);
+ rows.add(resizedRow);
+ }
+ else
+ {
+ auto zero = builder.emitDefaultConstruct(resultType->getElementType());
+ auto row = builder.emitMakeVectorFromScalar(toRowType, zero);
+ rows.add(row);
+ }
+ }
+ auto newInst = builder.emitMakeMatrix(resultType, (UInt)rows.getCount(), rows.getBuffer());
+ inst->replaceUsesWith(newInst);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ break;
case kIROp_Add:
case kIROp_Mul:
case kIROp_Sub:
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index 63c40b16f..e848d11c1 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -473,17 +473,74 @@ static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst)
static bool trySimplifySwitch(IRBuilder& builder, IRSwitch* switchInst)
{
+ // First, we fuse switch case blocks that is a trivial branch.
+ // If we see:
+ // ```
+ // someBlock:
+ // switch(..., case_block_A, ...)
+ // case_block_A:
+ // branch blockB;
+ // ```
+ // Then we fold blockB into the switch case operand:
+ // ```
+ // someBlock:
+ // switch(..., blockB, ...)
+ // ```
+ // We can do this if `blockB` is not a merge block.
+ //
+ bool changed = false;
+ auto fuseSwitchCaseBlock = [&](IRUse* targetUse)
+ {
+ for (;;)
+ {
+ auto block = as<IRBlock>(targetUse->get());
+ if (block->getFirstInst()->getOp() != kIROp_unconditionalBranch)
+ return;
+ auto branch = as<IRUnconditionalBranch>(block->getFirstInst());
+ // We can't fuse the block if there are phi arguments.
+ if (branch->getArgCount() != 0)
+ return;
+ auto target = branch->getTargetBlock();
+ if (target == switchInst->getBreakLabel())
+ return;
+ // target must not be used as a merge block of other control flow constructs.
+ for (auto use = target->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser() == switchInst || use->getUser() == branch)
+ continue;
+ switch (use->getUser()->getOp())
+ {
+ case kIROp_loop:
+ case kIROp_ifElse:
+ case kIROp_Switch:
+ // If the target block is used by a special control flow inst,
+ // it is likely a merge block and we can't fuse it.
+ return;
+ default:
+ break;
+ }
+ }
+ targetUse->set(target);
+ changed = true;
+ }
+ };
+
+ fuseSwitchCaseBlock(&switchInst->defaultLabel);
+ for (UInt i = 0; i < switchInst->getCaseCount(); i++)
+ fuseSwitchCaseBlock(switchInst->getCaseLabelUse(i));
+
+ // Next, we check if all switch cases are jumping to the same target.
if (!isTrivialSwitch(switchInst))
- return false;
+ return changed;
if (switchInst->getCaseCount() == 0)
- return false;
+ return changed;
auto termInst = as<IRUnconditionalBranch>(switchInst->getCaseLabel(0)->getTerminator());
if (!termInst)
- return false;
+ return changed;
if (!arePhiArgsEquivalentInBranches(switchInst))
- return false;
+ return changed;
List<IRInst*> args;
for (UInt i = 0; i < termInst->getArgCount(); i++)
@@ -663,7 +720,7 @@ static bool removeTrivialPhiParams(IRBlock* block)
}
if (targetVal)
{
- params[i]->replaceUsesWith(args[i].knownValue);
+ params[i]->replaceUsesWith(targetVal);
params[i]->removeAndDeallocate();
for (auto termInst : termInsts)
termInst->removeArgument((UInt)i);
@@ -709,6 +766,7 @@ static bool processFunc(IRGlobalValueWithCode* func)
{
loop->continueBlock.set(loop->getTargetBlock());
continueBlock->removeAndDeallocate();
+ changed = true;
}
// If there isn't any actual back jumps into loop target and there is a trivial
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 36fdbd56a..fd23d1b5e 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -97,23 +97,15 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Figure out storage class based on var layout.
if (auto layout = getVarLayout(inst))
{
- if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>())
+ auto cls = getGlobalParamStorageClass(layout);
+ if (cls != SpvStorageClassMax)
+ storageClass = cls;
+ else if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>())
{
String semanticName = systemValueAttr->getName();
semanticName = semanticName.toLower();
- if (semanticName == "sv_dispatchthreadid")
- {
- storageClass = SpvStorageClassInput;
- }
- else if (semanticName == "sv_groupindex")
- {
+ if (semanticName == "sv_pointsize")
storageClass = SpvStorageClassInput;
- }
- }
- else if(const auto parameterGroupTypeLayout =
- as<IRParameterGroupTypeLayout>(layout->getTypeLayout()))
- {
- storageClass = SpvStorageClassUniform;
}
}
@@ -121,10 +113,13 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(m_sharedContext->m_irModule);
bool needLoad = true;
auto innerType = inst->getFullType();
- if (as<IRConstantBufferType>(innerType) || as<IRParameterBlockType>(innerType))
+ auto cbufferType = as<IRConstantBufferType>(innerType);
+ auto paramBlockType = as<IRParameterBlockType>(innerType);
+ if (cbufferType || paramBlockType)
{
innerType = as<IRUniformParameterGroupType>(innerType)->getElementType();
- storageClass = SpvStorageClassUniform;
+ if (storageClass == SpvStorageClassPrivate)
+ storageClass = SpvStorageClassUniform;
// Constant buffer is already treated like a pointer type, and
// we are not adding another layer of indirection when replacing it
// with a pointer type. Therefore we don't need to insert a load at
@@ -137,6 +132,38 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
innerType = wrapConstantBufferElement(inst);
}
builder.addDecoration(innerType, kIROp_SPIRVBlockDecoration);
+
+ auto varLayoutInst = inst->findDecoration<IRLayoutDecoration>();
+ if (paramBlockType)
+ {
+ // A parameter block typed global parameter will have a VarLayout
+ // that contains an OffsetAttr(RegisterSpace, spaceId).
+ // We need to turn this VarLayout into a standard cbuffer VarLayout
+ // in the form of OffsetAttr(ConstantBuffer, 0, spaceId).
+ builder.setInsertBefore(inst);
+ IRVarLayout* varLayout = nullptr;
+ if (varLayoutInst)
+ varLayout = as<IRVarLayout>(varLayoutInst->getLayout());
+ if (varLayout)
+ {
+ auto registerSpaceOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::RegisterSpace);
+ if (registerSpaceOffsetAttr)
+ {
+ List<IRInst*> operands;
+ for (UInt i = 0; i < varLayout->getOperandCount(); i++)
+ operands.add(varLayout->getOperand(i));
+ operands.add(builder.getVarOffsetAttr(LayoutResourceKind::ConstantBuffer, 0, registerSpaceOffsetAttr->getOffset()));
+ auto newLayout = builder.getVarLayout(operands);
+ varLayoutInst->setOperand(0, newLayout);
+ varLayout->removeAndDeallocate();
+ }
+ }
+ }
+ else if (storageClass == SpvStorageClassPushConstant)
+ {
+ // Push constant params does not need a VarLayout.
+ varLayoutInst->removeAndDeallocate();
+ }
}
// Make a pointer type of storageClass.
@@ -162,6 +189,70 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
processGlobalVar(inst);
}
+ SpvStorageClass getStorageClassFromGlobalParamResourceKind(LayoutResourceKind kind)
+ {
+ SpvStorageClass storageClass = SpvStorageClassMax;
+ switch (kind)
+ {
+ case LayoutResourceKind::Uniform:
+ case LayoutResourceKind::DescriptorTableSlot:
+ case LayoutResourceKind::ConstantBuffer:
+ storageClass = SpvStorageClassUniform;
+ break;
+ case LayoutResourceKind::VaryingInput:
+ storageClass = SpvStorageClassInput;
+ break;
+ case LayoutResourceKind::VaryingOutput:
+ storageClass = SpvStorageClassOutput;
+ break;
+ case LayoutResourceKind::ShaderResource:
+ case LayoutResourceKind::UnorderedAccess:
+ storageClass = SpvStorageClassStorageBuffer;
+ break;
+ case LayoutResourceKind::PushConstantBuffer:
+ storageClass = SpvStorageClassPushConstant;
+ break;
+ case LayoutResourceKind::RayPayload:
+ storageClass = SpvStorageClassRayPayloadKHR;
+ break;
+ case LayoutResourceKind::CallablePayload:
+ storageClass = SpvStorageClassCallableDataKHR;
+ break;
+ case LayoutResourceKind::HitAttributes:
+ storageClass = SpvStorageClassHitAttributeKHR;
+ break;
+ case LayoutResourceKind::ShaderRecord:
+ storageClass = SpvStorageClassShaderRecordBufferKHR;
+ break;
+ default:
+ break;
+ }
+ return storageClass;
+ }
+
+ SpvStorageClass getGlobalParamStorageClass(IRVarLayout* varLayout)
+ {
+ SpvStorageClass result = SpvStorageClassMax;
+ for (auto rr : varLayout->getOffsetAttrs())
+ {
+ auto storageClass = getStorageClassFromGlobalParamResourceKind(rr->getResourceKind());
+ // If we haven't inferred a storage class yet, use the one we just found.
+ if (result == SpvStorageClassMax)
+ result = storageClass;
+ else if (result != storageClass)
+ {
+ // If we have inferred a storage class, and it is different from the one we just found,
+ // then we have conflicting uses of the resource, and we cannot infer a storage class.
+ // An exception is that a uniform storage class can be further specialized by PushConstants.
+ if (result == SpvStorageClassUniform)
+ result = storageClass;
+ else
+ SLANG_UNEXPECTED("Var layout contains conflicting resource uses, cannot resolve a storage class.");
+ }
+ }
+ return result;
+ }
+
void processGlobalVar(IRInst* inst)
{
auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType());
@@ -184,31 +275,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
else if (const auto varLayout = getVarLayout(inst))
{
- for (auto rr : varLayout->getOffsetAttrs())
- {
- switch (rr->getResourceKind())
- {
- case LayoutResourceKind::Uniform:
- case LayoutResourceKind::ShaderResource:
- case LayoutResourceKind::DescriptorTableSlot:
- storageClass = SpvStorageClassUniform;
- break;
- case LayoutResourceKind::VaryingInput:
- storageClass = SpvStorageClassInput;
- break;
- case LayoutResourceKind::VaryingOutput:
- storageClass = SpvStorageClassOutput;
- break;
- case LayoutResourceKind::UnorderedAccess:
- storageClass = SpvStorageClassStorageBuffer;
- break;
- case LayoutResourceKind::PushConstantBuffer:
- storageClass = SpvStorageClassPushConstant;
- break;
- default:
- break;
- }
- }
+ auto cls = getGlobalParamStorageClass(varLayout);
+ if (cls != SpvStorageClassMax)
+ storageClass = cls;
}
IRBuilder builder(m_sharedContext->m_irModule);
@@ -312,15 +381,18 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
writeBacks.add(WriteBackPair{ arg, tempVar });
}
SLANG_ASSERT((UInt)newArgs.getCount() == inst->getArgCount());
- auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs);
- for (auto wb : writeBacks)
+ if (writeBacks.getCount())
{
- auto newVal = builder.emitLoad(wb.tempVar);
- builder.emitStore(wb.originalAddrArg, newVal);
+ auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs);
+ for (auto wb : writeBacks)
+ {
+ auto newVal = builder.emitLoad(wb.tempVar);
+ builder.emitStore(wb.originalAddrArg, newVal);
+ }
+ inst->replaceUsesWith(newCall);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newCall);
}
- inst->replaceUsesWith(newCall);
- inst->removeAndDeallocate();
- addUsersToWorkList(newCall);
}
Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar;
@@ -346,12 +418,17 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(m_sharedContext->m_irModule);
IRInst* y = nullptr;
+ builder.setInsertBefore(inst);
if (!m_mapArrayValueToVar.tryGetValue(x, y))
{
- setInsertAfterOrdinaryInst(&builder, x);
+ if (x->getParent()->getOp() == kIROp_Module)
+ builder.setInsertBefore(inst);
+ else
+ setInsertAfterOrdinaryInst(&builder, x);
y = builder.emitVar(x->getDataType(), SpvStorageClassFunction);
builder.emitStore(y, x);
- m_mapArrayValueToVar.set(x, y);
+ if (x->getParent()->getOp() != kIROp_Module)
+ m_mapArrayValueToVar.set(x, y);
}
builder.setInsertBefore(inst);
for(Index i = indices.getCount() - 1; i >= 0; --i)
diff --git a/source/slang/slang-ir-spirv-snippet.cpp b/source/slang/slang-ir-spirv-snippet.cpp
index c19d50ece..98466e8ff 100644
--- a/source/slang/slang-ir-spirv-snippet.cpp
+++ b/source/slang/slang-ir-spirv-snippet.cpp
@@ -28,6 +28,8 @@ SpvSnippet::ASMType parseASMType(Slang::Misc::TokenReader& tokenReader)
return SpvSnippet::ASMType::Double;
else if (word == "uint2")
return SpvSnippet::ASMType::UInt2;
+ else if (word == "uint16_t")
+ return SpvSnippet::ASMType::UInt16;
else if (word == "float2")
return SpvSnippet::ASMType::Float2;
else if (word == "int")
@@ -36,6 +38,8 @@ SpvSnippet::ASMType parseASMType(Slang::Misc::TokenReader& tokenReader)
return SpvSnippet::ASMType::UInt;
else if (word == "_p")
return SpvSnippet::ASMType::FloatOrDouble;
+ else if (word == "half")
+ return SpvSnippet::ASMType::Half;
return SpvSnippet::ASMType::None;
}
diff --git a/source/slang/slang-ir-spirv-snippet.h b/source/slang/slang-ir-spirv-snippet.h
index c1497496d..4a509a230 100644
--- a/source/slang/slang-ir-spirv-snippet.h
+++ b/source/slang/slang-ir-spirv-snippet.h
@@ -64,6 +64,8 @@ struct SpvSnippet : public RefObject
None,
Int,
UInt,
+ UInt16,
+ Half,
Float,
Double,
FloatOrDouble, // Float or double type, depending on the result type of the intrinsic.
@@ -83,6 +85,7 @@ struct SpvSnippet : public RefObject
{
switch (type)
{
+ case ASMType::Half:
case ASMType::Float:
case ASMType::Double:
case ASMType::Float2:
@@ -102,6 +105,7 @@ struct SpvSnippet : public RefObject
return false;
switch (type)
{
+ case ASMType::Half:
case ASMType::Float:
case ASMType::Double:
case ASMType::FloatOrDouble:
@@ -112,6 +116,7 @@ struct SpvSnippet : public RefObject
case ASMType::Int:
return intValues[0] == other.intValues[0];
case ASMType::UInt:
+ case ASMType::UInt16:
return intValues[0] == other.intValues[0];
case ASMType::UInt2:
return intValues[0] == other.intValues[0] && intValues[1] == other.intValues[1];
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index f96cc174c..467580c83 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -998,6 +998,13 @@ void resetScratchDataBit(IRInst* inst, int bitIndex)
}
}
+IRVarLayout* findVarLayout(IRInst* value)
+{
+ if (auto layoutDecoration = value->findDecoration<IRLayoutDecoration>())
+ return as<IRVarLayout>(layoutDecoration->getLayout());
+ return nullptr;
+}
+
UnownedStringSlice getBasicTypeNameHint(IRType* basicType)
{
switch (basicType->getOp())
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 57a6c7c92..c107ec24a 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -224,6 +224,8 @@ bool isOne(IRInst* inst);
void initializeScratchData(IRInst* inst);
void resetScratchDataBit(IRInst* inst, int bitIndex);
+IRVarLayout* findVarLayout(IRInst* value);
+
// Run an operation over every block in a module
template<typename F>
static void overAllBlocks(IRModule* module, F f)
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index bf1ce1956..d0948fbb1 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -336,16 +336,18 @@ namespace Slang
validateIRInstOperands(context, inst);
context->seenInsts.add(inst);
+ if (auto code = as<IRGlobalValueWithCode>(inst))
+ {
+ context->domTree = computeDominatorTree(code);
+ validateCodeBody(context, code);
+ }
+
// If `inst` is itself a parent instruction, then we need to recursively
// validate its children.
validateIRInstChildren(context, inst);
if (auto code = as<IRGlobalValueWithCode>(inst))
- {
- context->domTree = computeDominatorTree(code);
- validateCodeBody(context, code);
context->domTree = nullptr;
- }
}
void validateIRInst(IRInst* inst)
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 8d36c2e86..181970632 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3771,6 +3771,32 @@ namespace Slang
}
if (targetVectorType->getElementCount() != sourceVectorType->getElementCount())
{
+ auto fromCount = as<IRIntLit>(sourceVectorType->getElementCount());
+ auto toCount = as<IRIntLit>(targetVectorType->getElementCount());
+ if (fromCount && toCount)
+ {
+ if (toCount->getValue() < fromCount->getValue())
+ {
+ List<UInt> indices;
+ for (UInt i = 0; i < (UInt)toCount->getValue(); i++)
+ indices.add(i);
+ return emitSwizzle(targetVectorType, value, (UInt)indices.getCount(), indices.getBuffer());
+ }
+ else if (toCount->getValue() > fromCount->getValue())
+ {
+ List<IRInst*> args;
+ for (UInt i = 0; i < (UInt)fromCount->getValue(); i++)
+ {
+ auto element = emitSwizzle(sourceVectorType->getElementType(), value , 1, &i);
+ args.add(element);
+ }
+ for (IRIntegerValue i = fromCount->getValue(); i < toCount->getValue(); i++)
+ {
+ args.add(emitDefaultConstruct(targetVectorType->getElementType()));
+ }
+ return emitMakeVector(targetVectorType, args);
+ }
+ }
auto reshape = emitIntrinsicInst(
getVectorType(
sourceVectorType->getElementType(), targetVectorType->getElementCount()),
diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp
index 4ec0ac64f..ae1a0e9c2 100644
--- a/source/slang/slang-language-server-completion.cpp
+++ b/source/slang/slang-language-server-completion.cpp
@@ -60,6 +60,7 @@ static const char* hlslSemanticNames[] = {
"SV_IsFrontFace",
"SV_OutputControlPointID",
"SV_Position",
+ "SV_PointSize",
"SV_PrimitiveID",
"SV_RenderTargetArrayIndex",
"SV_SampleIndex",
diff --git a/tests/expected-failure.txt b/tests/expected-failure.txt
index aa00e810b..d5934a3c5 100644
--- a/tests/expected-failure.txt
+++ b/tests/expected-failure.txt
@@ -1,62 +1,29 @@
tests/autodiff/global-param-hoisting.slang.1 (vk)
-tests/autodiff/high-order-builtins-2.slang.2 (vk)
-tests/autodiff/matrix-arithmetic-fwd.slang.1 (vk)
-tests/autodiff/reverse-switch-case.slang.1 (vk)
-tests/autodiff/bsdf/bsdf-sample.slang (vk)
tests/bugs/atomic-coerce.slang.1 (vk)
tests/bugs/bool-op.slang.1 (vk)
tests/bugs/buffer-swizzle-store.slang.1 (vk)
tests/bugs/byte-address-buffer-interlocked-add-f32.slang (vk)
tests/bugs/gh-3075.slang.2 (vk)
-tests/bugs/glsl-static-const-array.slang (vk)
-tests/bugs/matrix-reshape.slang.1 (vk)
-tests/bugs/nested-switch.slang.1 (vk)
-tests/bugs/parameter-block-load.slang (vk)
-tests/bugs/parens-cast-issue.slang.1 (vk)
tests/bugs/ray-query-in-generic.slang.1 (vk)
-tests/bugs/string-inline.slang.4 (vk)
tests/bugs/vec-compare.slang.2 (vk)
-tests/bugs/vec-init.slang.2 (vk)
-tests/bugs/inlining/global-const-inline.slang.1 (vk)
-tests/compute/buffer-layout.slang.2 (vk)
-tests/compute/dynamic-dispatch-16.slang (vk)
-tests/compute/dynamic-dispatch-17.slang (vk)
-tests/compute/dynamic-dispatch-18.slang.2 (vk)
-tests/compute/entry-point-uniform-params.slang.2 (vk)
-tests/compute/frem.slang.2 (vk)
-tests/compute/func-cbuffer-param.slang.2 (vk)
tests/compute/half-rw-texture-convert.slang.4 (vk)
tests/compute/half-rw-texture-convert2.slang.4 (vk)
tests/compute/half-vector-compare.slang.1 (vk)
-tests/compute/interface-shader-param-in-struct.slang.2 (vk)
tests/compute/loop-unroll.slang.5 (vk)
-tests/compute/pack-any-value-16bit.slang (vk)
-tests/compute/parameter-block.slang.2 (vk)
tests/compute/ray-tracing-inline.slang.1 (vk)
tests/compute/rw-texture-simple.slang.4 (vk)
-tests/compute/semantic.slang.3 (vk)
-tests/compute/static-const-array.slang.1 (vk)
tests/compute/static-const-matrix-array.slang.1 (vk)
-tests/compute/static-const-vector-array.slang.1 (vk)
tests/compute/texture-sample-grad-offset-clamp.slang (vk)
tests/compute/texture-simple.slang.4 (vk)
tests/compute/texture-simpler.slang (vk)
-tests/compute/vector-scalar-compare.slang.1 (vk)
-tests/cross-compile/glsl-bool-ops.slang.1 (vk)
tests/hlsl/glsl-matrix-layout.slang (vk)
-tests/hlsl/packoffset.slang.1 (vk)
-tests/hlsl-intrinsic/asfloat16.slang.3 (vk)
tests/hlsl-intrinsic/bit-cast-double.slang.3 (vk)
tests/hlsl-intrinsic/classify-double.slang.3 (vk)
tests/hlsl-intrinsic/classify-float.slang.3 (vk)
-tests/hlsl-intrinsic/f16tof32.slang.3 (vk)
-tests/hlsl-intrinsic/f32tof16.slang.3 (vk)
-tests/hlsl-intrinsic/literal-int64.slang.4 (vk)
tests/hlsl-intrinsic/scalar-double-d3d-intrinsic.slang.4 (vk)
tests/hlsl-intrinsic/scalar-double-simple.slang.4 (vk)
tests/hlsl-intrinsic/scalar-double-vk-intrinsic.slang.1 (vk)
tests/hlsl-intrinsic/scalar-float.slang.3 (vk)
-tests/hlsl-intrinsic/scalar-int64.slang.4 (vk)
tests/hlsl-intrinsic/scalar-uint.slang.3 (vk)
tests/hlsl-intrinsic/scalar-uint64.slang.4 (vk)
tests/hlsl-intrinsic/vector-double-reduced-intrinsic.slang.3 (vk)
@@ -76,9 +43,6 @@ tests/hlsl-intrinsic/wave-shuffle-vk.slang.3 (vk)
tests/hlsl-intrinsic/wave-vector.slang.3 (vk)
tests/hlsl-intrinsic/wave.slang.3 (vk)
tests/hlsl-intrinsic/active-mask/switch.slang.3 (vk)
-tests/hlsl-intrinsic/bit-cast/bit-cast-16-bit.slang.1 (vk)
-tests/hlsl-intrinsic/byte-address-buffer/byte-address-16bit-vector.slang.2 (vk)
-tests/hlsl-intrinsic/byte-address-buffer/byte-address-16bit.slang.2 (vk)
tests/hlsl-intrinsic/size-of/align-of-3.slang.3 (vk)
tests/hlsl-intrinsic/size-of/size-of-3.slang.3 (vk)
tests/hlsl-intrinsic/wave-mask/wave-active-product.slang.3 (vk)
@@ -94,12 +58,7 @@ tests/hlsl-intrinsic/wave-mask/wave-read-lane-at-vk.slang.1 (vk)
tests/hlsl-intrinsic/wave-mask/wave-shuffle-vk.slang.3 (vk)
tests/hlsl-intrinsic/wave-mask/wave-vector.slang.3 (vk)
tests/hlsl-intrinsic/wave-mask/wave.slang.3 (vk)
-tests/ir/string-literal-hash.slang.1 (vk)
tests/language-feature/constants/constexpr-loop.slang.1 (vk)
-tests/language-feature/initializer-lists/default-init-16bit-types.slang (vk)
-tests/language-feature/shader-params/interface-shader-param-ordinary.slang.2 (vk)
-tests/language-feature/swizzles/matrix-swizzle-write-array.slang.1 (vk)
-tests/language-feature/swizzles/matrix-swizzle-write-swizzle.slang.1 (vk)
tests/optimization/func-resource-result/func-resource-result-complex.slang.1 (vk)
tests/slang-extension/atomic-float-byte-address-buffer.slang.2 (vk)
tests/slang-extension/atomic-int64-byte-address-buffer.slang.4 (vk)
@@ -107,5 +66,4 @@ tests/slang-extension/atomic-min-max-u64-byte-address-buffer.slang.4 (vk)
tests/slang-extension/cas-int64-byte-address-buffer.slang.4 (vk)
tests/slang-extension/exchange-int64-byte-address-buffer.slang.4 (vk)
tests/slang-extension/realtime-clock.slang.2 (vk)
-tests/spirv/spirv-instruction.slang (vk)
tests/type/texture-sampler/texture-sampler-2d.slang (vk) \ No newline at end of file
diff --git a/tests/vkray/raygen.slang.glsl b/tests/vkray/raygen.slang.glsl
index 28bd5956b..69dc74c53 100644
--- a/tests/vkray/raygen.slang.glsl
+++ b/tests/vkray/raygen.slang.glsl
@@ -2,24 +2,6 @@
#extension GL_EXT_ray_tracing : require
layout(row_major) uniform;
layout(row_major) buffer;
-struct ReflectionRay_0
-{
- float color_1;
-};
-
-layout(location = 0)
-rayPayloadEXT
-ReflectionRay_0 p_0;
-
-struct ShadowRay_0
-{
- float hitDistance_0;
-};
-
-layout(location = 1)
-rayPayloadEXT
-ShadowRay_0 p_1;
-
layout(binding = 0)
uniform texture2D samplerPosition_0;
@@ -49,7 +31,32 @@ layout(std140) uniform _S1
vec4 viewPos_0;
mat4x4 view_0;
mat4x4 model_0;
-} ubo_0;
+}ubo_0;
+layout(binding = 5)
+uniform accelerationStructureEXT as_0;
+
+layout(rgba32f)
+layout(binding = 4)
+uniform image2D outputImage_0;
+
+struct ReflectionRay_0
+{
+ float color_1;
+};
+
+layout(location = 0)
+rayPayloadEXT
+ReflectionRay_0 p_0;
+
+struct ShadowRay_0
+{
+ float hitDistance_0;
+};
+
+layout(location = 1)
+rayPayloadEXT
+ShadowRay_0 p_1;
+
struct RayDesc_0
{
vec3 Origin_0;
@@ -74,18 +81,11 @@ void TraceRay_1(accelerationStructureEXT AccelerationStructure_1, uint RayFlags_
return;
}
-layout(binding = 5)
-uniform accelerationStructureEXT as_0;
-
float saturate_0(float x_0)
{
return clamp(x_0, 0.0, 1.0);
}
-layout(rgba32f)
-layout(binding = 4)
-uniform image2D outputImage_0;
-
void main()
{
uvec3 _S2 = ((gl_LaunchIDEXT));