diff options
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 32 | ||||
| -rw-r--r-- | source/slang/slang-emit-glsl.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-array-reg-to-mem.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 79 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 6 | ||||
| -rw-r--r-- | tests/spirv/mutating-method-syn.slang | 38 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-api.h | 3 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-device.cpp | 14 |
12 files changed, 208 insertions, 27 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 351d5a9cc..1e3c6a361 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1836,18 +1836,20 @@ namespace Slang RefPtr<WitnessTable> witnessTable) { if(satisfyingMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>() - && !requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>()) + != requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>()) { - // A `[mutating]` method can't satisfy a non-`[mutating]` requirement, - // but vice-versa is okay. + // A `[mutating]` method can't satisfy a non-`[mutating]` requirement. + // The opposite direction is okay, but we will need to synthesize a wrapper + // to ensure type matches, so we will return false here either way. return false; } if (satisfyingMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>() - && !requiredMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>()) + != requiredMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>()) { - // A `[constref]` method can't satisfy a non-`[constref]` requirement, - // but vice-versa is okay. + // A `[constref]` method can't satisfy a non-`[constref]` requirement. + // The opposite direction is okay, but we will need to synthesize a wrapper + // to ensure type matches, so we will return false here either way. return false; } @@ -2677,11 +2679,21 @@ namespace Slang synArg->type = paramType; synArgs.add(synArg); - if (paramDeclRef.getDecl()->findModifier<NoDiffModifier>()) + // Add modifiers + for (auto modifier : paramDeclRef.getDecl()->modifiers) { - auto noDiffModifier = m_astBuilder->create<NoDiffModifier>(); - noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); - addModifier(synParamDecl, noDiffModifier); + if (as<NoDiffModifier>(modifier)) + { + auto noDiffModifier = m_astBuilder->create<NoDiffModifier>(); + noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); + addModifier(synParamDecl, noDiffModifier); + } + else if (as<InOutModifier>(modifier) || as<OutModifier>(modifier) || as<ConstRefModifier>(modifier) || as<RefModifier>(modifier)) + { + auto clonedModifier = (Modifier*)m_astBuilder->createByNodeType(modifier->astNodeType); + clonedModifier->keywordName = modifier->keywordName; + addModifier(synParamDecl, clonedModifier); + } } } diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index e9472bc96..ccb296e84 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1893,6 +1893,18 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu return true; } + case kIROp_GetVulkanRayTracingPayloadLocation: + { + auto payloadVar = inst->getOperand(0); + IRInst* location = getVulkanPayloadLocation(payloadVar); + if (!location) + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), inst, "no payload location assigned."); + m_writer->emit("0"); + } + m_writer->emit(getIntVal(location)); + return true; + } case kIROp_ImageLoad: { m_writer->emit("imageLoad("); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 104fab339..62e7d428b 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2274,6 +2274,20 @@ struct SPIRVEmitContext registerInst(inst, inner); return inner; } + case kIROp_GetVulkanRayTracingPayloadLocation: + { + IRInst* location = getVulkanPayloadLocation(inst->getOperand(0)); + if (!location) + { + SLANG_DIAGNOSE_UNEXPECTED(m_sink, inst, "no payload location assigned."); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + location = builder.getIntValue(builder.getIntType(), 0); + } + auto inner = ensureInst(location); + registerInst(inst, inner); + return inner; + } case kIROp_Return: if (as<IRReturn>(inst)->getVal()->getOp() == kIROp_VoidLit) return emitOpReturn(parent, inst); diff --git a/source/slang/slang-ir-array-reg-to-mem.cpp b/source/slang/slang-ir-array-reg-to-mem.cpp index 6f749f242..34bd5b148 100644 --- a/source/slang/slang-ir-array-reg-to-mem.cpp +++ b/source/slang/slang-ir-array-reg-to-mem.cpp @@ -33,9 +33,9 @@ namespace Slang if (auto arrayType = as<IRArrayTypeBase>(param->getFullType())) { changed = true; - builder.setInsertBefore(param); auto ptrArrayType = builder.getPtrType(arrayType); - auto newParam = builder.emitParam(ptrArrayType); + auto newParam = builder.createParam(ptrArrayType); + newParam->insertBefore(param); setInsertAfterOrdinaryInst(&builder, param); auto regVal = builder.emitLoad(newParam); param->replaceUsesWith(regVal); @@ -62,6 +62,7 @@ namespace Slang for (auto paramId : arrayParamIds) { auto arg = call->getArg(paramId); + SLANG_ASSERT(as<IRPtrTypeBase>(paramTypes[paramId])); auto var = builder.emitVar(as<IRPtrTypeBase>(paramTypes[paramId])->getValueType()); builder.emitStore(var, arg); call->setArg(paramId, var); diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index c4a71a3e9..956685d85 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -2693,18 +2693,6 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module) } } end:; - if (location) - { - traverseUses(globalVar, [&](IRUse* use) - { - auto user = use->getUser(); - if (user->getOp() == kIROp_GetVulkanRayTracingPayloadLocation) - { - user->replaceUsesWith(location); - user->removeAndDeallocate(); - } - }); - } } } diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 5d4673981..b92f5e910 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -265,7 +265,38 @@ struct SPIRVLegalizationContext : public SourceEmitterBase return false; } - void inferTextureFormat(IRInst* textureInst, IRTextureTypeBase* textureType) + static IRType* replaceImageElementType(IRInst* originalType, IRInst* newElementType) + { + switch(originalType->getOp()) + { + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + case kIROp_PtrType: + case kIROp_OutType: + case kIROp_RefType: + case kIROp_ConstRefType: + case kIROp_InOutType: + { + auto newInnerType = replaceImageElementType(originalType->getOperand(0), newElementType); + if (newInnerType != originalType->getOperand(0)) + { + IRBuilder builder(originalType); + builder.setInsertBefore(originalType); + IRCloneEnv cloneEnv; + cloneEnv.mapOldValToNew.add(originalType->getOperand(0), newInnerType); + return (IRType*)cloneInst(&cloneEnv, &builder, originalType); + } + return (IRType*)originalType; + } + + default: + if (as<IRResourceTypeBase>(originalType)) + return (IRType*)newElementType; + return (IRType*)originalType; + } + } + + static void inferTextureFormat(IRInst* textureInst, IRTextureTypeBase* textureType) { ImageFormat format = ImageFormat::unknown; if (auto decor = textureInst->findDecoration<IRFormatDecoration>()) @@ -368,14 +399,52 @@ struct SPIRVLegalizationContext : public SourceEmitterBase args.add(builder.getIntValue(builder.getUIntType(), IRIntegerValue(format))); auto newType = (IRType*)builder.emitIntrinsicInst(builder.getTypeKind(), textureType->getOp(), 3, args.getBuffer()); - textureInst->setFullType(newType); + if (textureInst->getFullType() == textureType) + { + // Simple texture typed global param. + textureInst->setFullType(newType); + } + else + { + // Array typed global param. We need to replace the type and the types of all getElement insts. + auto newInstType = (IRType*)replaceImageElementType(textureInst->getFullType(), newType); + textureInst->setFullType(newInstType); + List<IRUse*> typeReplacementWorkList; + HashSet<IRUse*> typeReplacementWorkListSet; + for (auto use = textureInst->firstUse; use; use = use->nextUse) + { + if (typeReplacementWorkListSet.add(use)) + typeReplacementWorkList.add(use); + } + for (Index i = 0; i < typeReplacementWorkList.getCount(); i++) + { + auto use = typeReplacementWorkList[i]; + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_GetElement: + case kIROp_Load: + { + auto newUserType = (IRType*)replaceImageElementType(user->getFullType(), newType); + user->setFullType(newUserType); + for (auto u = user->firstUse; u; u = u->nextUse) + { + if (typeReplacementWorkListSet.add(u)) + typeReplacementWorkList.add(u); + } + break; + }; + } + } + } } } void processGlobalParam(IRGlobalParam* inst) { // If the param is a texture, infer its format. - if (auto textureType = as<IRTextureTypeBase>(inst->getDataType())) + if (auto textureType = as<IRTextureTypeBase>(unwrapArray(inst->getDataType()))) { inferTextureFormat(inst, textureType); } @@ -1447,6 +1516,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void processModule() { +#if 0 + eliminateArrayTypeSSARegisters(m_module); +#endif + // Process global params before anything else, so we don't generate inefficient // array marhalling code for array-typed global params. for (auto globalInst : m_module->getGlobalInsts()) diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index b4a41f8a5..5ecbc8121 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -814,6 +814,24 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) return nullptr; } +IRInst* getVulkanPayloadLocation(IRInst* payloadGlobalVar) +{ + IRInst* location = nullptr; + for (auto decor : payloadGlobalVar->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_VulkanRayPayloadDecoration: + case kIROp_VulkanCallablePayloadDecoration: + case kIROp_VulkanHitObjectAttributesDecoration: + return decor->getOperand(0); + default: + continue; + } + } + return location; +} + void moveParams(IRBlock* dest, IRBlock* src) { for (auto param = src->getFirstChild(); param;) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 82ce1344c..c13ce1931 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -209,6 +209,8 @@ IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key); IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key); +IRInst* getVulkanPayloadLocation(IRInst* payloadGlobalVar); + void moveParams(IRBlock* dest, IRBlock* src); void removePhiArgs(IRInst* phiParam); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 065263a59..f62d4b24a 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9903,7 +9903,13 @@ static void lowerFrontEndEntryPointToIR( instToDecorate = findGenericReturnVal(irGeneric); } + // If the entry-point decorations has already been created (because the user + // specified duplicate entries in the entry point list), we can stop now. + if (instToDecorate->findDecoration<IREntryPointDecoration>()) + return; + { + Name* entryPointName = entryPoint->getFuncDecl()->getName(); builder->addEntryPointDecoration(instToDecorate, entryPoint->getProfile(), entryPointName->text.getUnownedSlice(), moduleName.getUnownedSlice()); } diff --git a/tests/spirv/mutating-method-syn.slang b/tests/spirv/mutating-method-syn.slang new file mode 100644 index 000000000..14adc001f --- /dev/null +++ b/tests/spirv/mutating-method-syn.slang @@ -0,0 +1,38 @@ +// mutating-method-syn.slang +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute +// Test ability to directly output SPIR-V + +interface IFoo +{ + [mutating] + int bar(inout int y); +} + +struct Val : IFoo +{ + int x; + + int bar(int y) + { + return x + y; + } +} + +int test<T:IFoo>(inout T f, inout int y) +{ + return f.bar(y); +} + +//TEST_INPUT:set result = out ubuffer(data=[0 0 0 0], stride=4) + +RWStructuredBuffer<int> result; +[numthreads(1,1,1)] +void computeMain() +{ + Val v; + int y = 0; + v.x = 1; + + // CHECK: 1 + result[0] = test(v, y); +} diff --git a/tools/gfx/vulkan/vk-api.h b/tools/gfx/vulkan/vk-api.h index d20cd555c..b7cbf13de 100644 --- a/tools/gfx/vulkan/vk-api.h +++ b/tools/gfx/vulkan/vk-api.h @@ -261,6 +261,9 @@ struct VulkanExtendedFeatureProperties VkPhysicalDeviceRobustness2FeaturesEXT robustness2Features = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ROBUSTNESS_2_FEATURES_EXT}; + VkPhysicalDeviceRayTracingInvocationReorderFeaturesNV rayTracingInvocationReorderFeatures = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_TRACING_INVOCATION_REORDER_FEATURES_NV}; + // Clock features VkPhysicalDeviceShaderClockFeaturesKHR clockFeatures = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_CLOCK_FEATURES_KHR diff --git a/tools/gfx/vulkan/vk-device.cpp b/tools/gfx/vulkan/vk-device.cpp index 4fbf987e9..bed9c038a 100644 --- a/tools/gfx/vulkan/vk-device.cpp +++ b/tools/gfx/vulkan/vk-device.cpp @@ -405,6 +405,10 @@ Result DeviceImpl::initVulkanInstanceAndDevice( extendedFeatures.rayTracingPipelineFeatures.pNext = deviceFeatures2.pNext; deviceFeatures2.pNext = &extendedFeatures.rayTracingPipelineFeatures; + // SER features. + extendedFeatures.rayTracingInvocationReorderFeatures.pNext = deviceFeatures2.pNext; + deviceFeatures2.pNext = &extendedFeatures.rayTracingInvocationReorderFeatures; + // Acceleration structure features extendedFeatures.accelerationStructureFeatures.pNext = deviceFeatures2.pNext; deviceFeatures2.pNext = &extendedFeatures.accelerationStructureFeatures; @@ -582,6 +586,16 @@ Result DeviceImpl::initVulkanInstanceAndDevice( m_features.add("mesh-shader"); } + if (extendedFeatures.rayTracingInvocationReorderFeatures.rayTracingInvocationReorder) + { + deviceExtensions.add(VK_NV_RAY_TRACING_INVOCATION_REORDER_EXTENSION_NAME); + + extendedFeatures.rayTracingInvocationReorderFeatures.pNext = (void*)deviceCreateInfo.pNext; + deviceCreateInfo.pNext = &extendedFeatures.rayTracingInvocationReorderFeatures; + + m_features.add("shader-execution-reorder"); + } + if (_hasAnySetBits( extendedFeatures.vulkan12Features, offsetof(VkPhysicalDeviceVulkan12Features, pNext) + sizeof(void*))) |
