diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2018-12-14 11:00:02 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-12-14 11:00:02 -0800 |
| commit | ec745c032a8dc16c3e689458c20541a4e7aa64d6 (patch) | |
| tree | 69b4455cbaedf0ab4c887798aada1929962a7a53 | |
| parent | 11793edf25a4907fe396d69fd3cdddaee3d421d5 (diff) | |
Represent global shader parameters explicitly in the IR (#756)
Before this change, global shader parameters were represented in the IR as just being ordinary global variables.
The only indication that a particular global represented a parameter was when it got a layout attached to it (as part of back-end processing), and we've had a number of bugs related to layouts being dropped so that what should have been a shader parameter turned into an ordinary global variable in the output.
This change is more strongly motivated by the fact that making shader parameters look like globals means that we cannot easily reason about their value when doing IR transformations.
If we see two `load`s from the same global variable can we assume they yield the same value?
In the general case we cannot, and this means that any transformation that wants to rely on the fact that an input `Texture2D` shader parameter can't actually change over the life of the program needs to do extra work.
The fix here is to introduce a new kind of IR instruction that represents a global shader parameter directly (not a pointer to it as a global would), at which point there isn't even such a notion as a "load" from the parameter, since it represents the value directly.
In several cases logic that used to apply to global variables in case they were shader parameters (by looking for a layout) is now moved to apply to these global parameters.
The biggest source of issues in this change was that switching from pointers to plain values to represent these shader parameters stresses different cases in type legalization. I also had to deal with the case of legalization for GLSL where we actually *do* need global shader parameters that are writable (since varying output goes in the global scope), but in that case I borrowed the use of pointer-like `Out<...>` and `InOut<...>` types to represent that intent, which we were already using for function parameters representing outputs.
A few tests started failing because the changes led to a slightly different order of code emission, which in some HLSL tests resulted in a function parameter named `s` getting emitted before a global parameter named `s`, leading to the latter getting the name `s_1` instead of `s_0`.
A few SPIR-V tests started failing because the new approach means that we no longer end up performing a load from all varying input parameters at the start of `main` and instead reference the varying inputs directly. The resulting code is more idiomatic, but it differed from the baselines for those tests.
| -rw-r--r-- | source/slang/check.h | 7 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 110 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 8 | ||||
| -rw-r--r-- | source/slang/ir-legalize-types.cpp | 292 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 180 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 1 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 3 | ||||
| -rw-r--r-- | tests/bindings/array-of-struct-of-resource.hlsl | 2 | ||||
| -rw-r--r-- | tests/bindings/binding0.hlsl | 2 | ||||
| -rw-r--r-- | tests/cross-compile/non-uniform-indexing.slang.glsl | 8 | ||||
| -rw-r--r-- | tests/vkray/anyhit.slang.glsl | 4 | ||||
| -rw-r--r-- | tests/vkray/closesthit.slang.glsl | 19 |
14 files changed, 496 insertions, 176 deletions
diff --git a/source/slang/check.h b/source/slang/check.h new file mode 100644 index 000000000..1f378ec7b --- /dev/null +++ b/source/slang/check.h @@ -0,0 +1,7 @@ +// check.h +#pragma once + +namespace Slang +{ + bool isGlobalShaderParameter(VarDeclBase* decl); +}
\ No newline at end of file diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 930520892..39616b3df 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -2355,6 +2355,7 @@ struct EmitVisitor case kIROp_Var: case kIROp_GlobalVar: case kIROp_GlobalConstant: + case kIROp_GlobalParam: case kIROp_Param: return false; @@ -5558,7 +5559,7 @@ struct EmitVisitor void emitHLSLParameterGroup( EmitContext* ctx, - IRGlobalVar* varDecl, + IRGlobalParam* varDecl, IRUniformParameterGroupType* type) { if(as<IRTextureBufferType>(type)) @@ -5623,7 +5624,7 @@ struct EmitVisitor void emitGLSLParameterGroup( EmitContext* ctx, - IRGlobalVar* varDecl, + IRGlobalParam* varDecl, IRUniformParameterGroupType* type) { auto varLayout = getVarLayout(ctx, varDecl); @@ -5680,14 +5681,14 @@ struct EmitVisitor // If the underlying variable was an array (or array of arrays, etc.) // we need to emit all those array brackets here. - emitArrayBrackets(ctx, varDecl->getDataType()->getValueType()); + emitArrayBrackets(ctx, varDecl->getDataType()); emit(";\n"); } void emitIRParameterGroup( EmitContext* ctx, - IRGlobalVar* varDecl, + IRGlobalParam* varDecl, IRUniformParameterGroupType* type) { switch (ctx->shared->target) @@ -5763,7 +5764,7 @@ struct EmitVisitor void emitIRStructuredBuffer_GLSL( EmitContext* ctx, - IRGlobalVar* varDecl, + IRGlobalParam* varDecl, IRHLSLStructuredBufferTypeBase* structuredBufferType) { // Shader storage buffer is an OpenGL 430 feature @@ -5809,14 +5810,14 @@ struct EmitVisitor emit("} "); emit(getIRName(varDecl)); - emitArrayBrackets(ctx, varDecl->getDataType()->getValueType()); + emitArrayBrackets(ctx, varDecl->getDataType()); emit(";\n"); } void emitIRByteAddressBuffer_GLSL( EmitContext* ctx, - IRGlobalVar* varDecl, + IRGlobalParam* varDecl, IRByteAddressBufferTypeBase* /* byteAddressBufferType */) { // TODO: A lot of this logic is copy-pasted from `emitIRStructuredBuffer_GLSL`. 
@@ -5862,7 +5863,7 @@ struct EmitVisitor emit("} "); emit(getIRName(varDecl)); - emitArrayBrackets(ctx, varDecl->getDataType()->getValueType()); + emitArrayBrackets(ctx, varDecl->getDataType()); emit(";\n"); } @@ -5894,6 +5895,63 @@ struct EmitVisitor Emit("}\n"); } + // An ordinary global variable won't have a layout + // associated with it, since it is not a shader + // parameter. + // + SLANG_ASSERT(!getVarLayout(ctx, varDecl)); + VarLayout* layout = nullptr; + + // An ordinary global variable (which is not a + // shader parameter) may need special + // modifiers to indicate it as such. + // + switch (getTarget(ctx)) + { + case CodeGenTarget::HLSL: + // HLSL requires the `static` modifier on any + // global variables; otherwise they are assumed + // to be uniforms. + Emit("static "); + break; + + default: + break; + } + + emitIRVarModifiers(ctx, layout, varDecl, varType); + + emitIRRateQualifiers(ctx, varDecl); + emitIRType(ctx, varType, getIRName(varDecl)); + + // TODO: These shouldn't be needed for ordinary + // global variables. + // + emitIRSemantics(ctx, varDecl); + emitIRLayoutSemantics(ctx, varDecl); + + if (varDecl->getFirstBlock()) + { + Emit(" = "); + emit(initFuncName); + Emit("()"); + } + + emit(";\n\n"); + } + + void emitIRGlobalParam( + EmitContext* ctx, + IRGlobalParam* varDecl) + { + auto rawType = varDecl->getDataType(); + + auto varType = rawType; + if( auto outType = as<IROutTypeBase>(varType) ) + { + varType = outType->getValueType(); + } + // When a global shader parameter represents a "parameter group" // (either a constant buffer or a parameter block with non-resource // data in it), we will prefer to emit it as an ordinary `cbuffer` @@ -5985,26 +6043,11 @@ struct EmitVisitor // Need to emit appropriate modifiers here. + // We expect/require all shader parameters to + // have some kind of layout information associted with them. 
+ // auto layout = getVarLayout(ctx, varDecl); - - if (!layout) - { - // A global variable without a layout is just an - // ordinary global variable, and may need special - // modifiers to indicate it as such. - switch (getTarget(ctx)) - { - case CodeGenTarget::HLSL: - // HLSL requires the `static` modifier on any - // global variables; otherwise they are assumed - // to be uniforms. - Emit("static "); - break; - - default: - break; - } - } + SLANG_ASSERT(layout); emitIRVarModifiers(ctx, layout, varDecl, varType); @@ -6015,16 +6058,13 @@ struct EmitVisitor emitIRLayoutSemantics(ctx, varDecl); - if (varDecl->getFirstBlock()) - { - Emit(" = "); - emit(initFuncName); - Emit("()"); - } + // A shader parameter cannot have an initializer, + // so we do need to consider emitting one here. emit(";\n\n"); } + void emitIRGlobalConstantInitializer( EmitContext* ctx, IRGlobalConstant* valDecl) @@ -6098,6 +6138,10 @@ struct EmitVisitor emitIRGlobalVar(ctx, (IRGlobalVar*) inst); break; + case kIROp_GlobalParam: + emitIRGlobalParam(ctx, (IRGlobalParam*) inst); + break; + case kIROp_GlobalConstant: emitIRGlobalConstant(ctx, (IRGlobalConstant*) inst); break; diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index 35a73c2f1..69940a79d 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -163,6 +163,8 @@ INST_RANGE(Type, VoidType, StructType) INST(GlobalConstant, global_constant, 0, 0) INST_RANGE(GlobalValueWithCode, Func, GlobalConstant) +INST(GlobalParam, global_param, 0, 0) + INST(StructKey, key, 0, 0) INST(GlobalGenericParam, global_generic_param, 0, 0) INST(WitnessTable, witness_table, 0, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 1dfd52d8a..737675d87 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -558,6 +558,12 @@ struct IRGlobalConstant : IRGlobalValueWithCode IR_LEAF_ISA(GlobalConstant) }; +struct IRGlobalParam : IRInst +{ + IR_LEAF_ISA(GlobalParam) +}; + + // An entry in a 
witness table (see below) struct IRWitnessTableEntry : IRInst { @@ -798,6 +804,8 @@ struct IRBuilder IRType* valueType); IRGlobalConstant* createGlobalConstant( IRType* valueType); + IRGlobalParam* createGlobalParam( + IRType* valueType); IRWitnessTable* createWitnessTable(); IRWitnessTableEntry* createWitnessTableEntry( IRWitnessTable* witnessTable, diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 901d8705b..a97cc0393 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -335,31 +335,31 @@ static LegalVal legalizeStore( } } -static LegalVal legalizeFieldAddress( - IRTypeLegalizationContext* context, +static LegalVal legalizeFieldExtract( + IRTypeLegalizationContext* context, LegalType type, - LegalVal legalPtrOperand, + LegalVal legalStructOperand, IRStructKey* fieldKey) { auto builder = context->builder; - switch (legalPtrOperand.flavor) + switch (legalStructOperand.flavor) { case LegalVal::Flavor::none: return LegalVal(); case LegalVal::Flavor::simple: return LegalVal::simple( - builder->emitFieldAddress( + builder->emitFieldExtract( type.getSimple(), - legalPtrOperand.getSimple(), + legalStructOperand.getSimple(), fieldKey)); case LegalVal::Flavor::pair: { // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. 
- auto pairVal = legalPtrOperand.getPair(); + auto pairVal = legalStructOperand.getPair(); auto pairInfo = pairVal->pairInfo; auto pairElement = pairInfo->findElement(fieldKey); if (!pairElement) @@ -387,7 +387,7 @@ static LegalVal legalizeFieldAddress( if (pairElement->flags & PairInfo::kFlag_hasOrdinary) { - ordinaryVal = legalizeFieldAddress( + ordinaryVal = legalizeFieldExtract( context, ordinaryType, pairVal->ordinaryVal, @@ -396,7 +396,7 @@ static LegalVal legalizeFieldAddress( if (pairElement->flags & PairInfo::kFlag_hasSpecial) { - specialVal = legalizeFieldAddress( + specialVal = legalizeFieldExtract( context, specialType, pairVal->specialVal, @@ -413,7 +413,7 @@ static LegalVal legalizeFieldAddress( // corresponding to a field. We will handle // this by simply returning the corresponding // element from the operand. - auto ptrTupleInfo = legalPtrOperand.getTuple(); + auto ptrTupleInfo = legalStructOperand.getTuple(); for (auto ee : ptrTupleInfo->elements) { if (ee.key == fieldKey) @@ -435,7 +435,7 @@ static LegalVal legalizeFieldAddress( } } -static LegalVal legalizeFieldAddress( +static LegalVal legalizeFieldExtract( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, @@ -445,38 +445,38 @@ static LegalVal legalizeFieldAddress( // the "field" argument. 
auto fieldKey = legalFieldOperand.getSimple(); - return legalizeFieldAddress( + return legalizeFieldExtract( context, type, legalPtrOperand, (IRStructKey*) fieldKey); } -static LegalVal legalizeFieldExtract( - IRTypeLegalizationContext* context, +static LegalVal legalizeFieldAddress( + IRTypeLegalizationContext* context, LegalType type, - LegalVal legalStructOperand, + LegalVal legalPtrOperand, IRStructKey* fieldKey) { auto builder = context->builder; - switch (legalStructOperand.flavor) + switch (legalPtrOperand.flavor) { case LegalVal::Flavor::none: return LegalVal(); case LegalVal::Flavor::simple: return LegalVal::simple( - builder->emitFieldExtract( + builder->emitFieldAddress( type.getSimple(), - legalStructOperand.getSimple(), + legalPtrOperand.getSimple(), fieldKey)); case LegalVal::Flavor::pair: { // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. - auto pairVal = legalStructOperand.getPair(); + auto pairVal = legalPtrOperand.getPair(); auto pairInfo = pairVal->pairInfo; auto pairElement = pairInfo->findElement(fieldKey); if (!pairElement) @@ -504,7 +504,7 @@ static LegalVal legalizeFieldExtract( if (pairElement->flags & PairInfo::kFlag_hasOrdinary) { - ordinaryVal = legalizeFieldExtract( + ordinaryVal = legalizeFieldAddress( context, ordinaryType, pairVal->ordinaryVal, @@ -513,7 +513,7 @@ static LegalVal legalizeFieldExtract( if (pairElement->flags & PairInfo::kFlag_hasSpecial) { - specialVal = legalizeFieldExtract( + specialVal = legalizeFieldAddress( context, specialType, pairVal->specialVal, @@ -530,7 +530,7 @@ static LegalVal legalizeFieldExtract( // corresponding to a field. We will handle // this by simply returning the corresponding // element from the operand. 
- auto ptrTupleInfo = legalStructOperand.getTuple(); + auto ptrTupleInfo = legalPtrOperand.getTuple(); for (auto ee : ptrTupleInfo->elements) { if (ee.key == fieldKey) @@ -546,13 +546,27 @@ static LegalVal legalizeFieldExtract( UNREACHABLE_RETURN(LegalVal()); } + case LegalVal::Flavor::implicitDeref: + { + // The original value had a level of indirection + // that is now being removed, so should not be + // able to get at the *address* of the field any + // more, and need to resign ourselves to just + // getting at the field *value* and then + // adding an implicit dereference on top of that. + // + auto implicitDerefVal = legalPtrOperand.getImplicitDeref(); + + return LegalVal::implicitDeref(legalizeFieldExtract(context,type, implicitDerefVal, fieldKey)); + } + default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); } } -static LegalVal legalizeFieldExtract( +static LegalVal legalizeFieldAddress( IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, @@ -562,13 +576,125 @@ static LegalVal legalizeFieldExtract( // the "field" argument. auto fieldKey = legalFieldOperand.getSimple(); - return legalizeFieldExtract( + return legalizeFieldAddress( context, type, legalPtrOperand, (IRStructKey*) fieldKey); } +static LegalVal legalizeGetElement( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + IRInst* indexOperand) +{ + auto builder = context->builder; + + switch (legalPtrOperand.flavor) + { + case LegalVal::Flavor::none: + return LegalVal(); + + case LegalVal::Flavor::simple: + return LegalVal::simple( + builder->emitElementExtract( + type.getSimple(), + legalPtrOperand.getSimple(), + indexOperand)); + + case LegalVal::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. 
+ auto pairVal = legalPtrOperand.getPair(); + auto pairInfo = pairVal->pairInfo; + + LegalType ordinaryType = type; + LegalType specialType = type; + if (type.flavor == LegalType::Flavor::pair) + { + auto pairType = type.getPair(); + ordinaryType = pairType->ordinaryType; + specialType = pairType->specialType; + } + + LegalVal ordinaryVal = legalizeGetElement( + context, + ordinaryType, + pairVal->ordinaryVal, + indexOperand); + + LegalVal specialVal = legalizeGetElement( + context, + specialType, + pairVal->specialVal, + indexOperand); + + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + break; + + case LegalVal::Flavor::tuple: + { + // The operand is a tuple of pointer-like + // values, we want to extract the element + // corresponding to a field. We will handle + // this by simply returning the corresponding + // element from the operand. + auto ptrTupleInfo = legalPtrOperand.getTuple(); + + RefPtr<TuplePseudoVal> resTupleInfo = new TuplePseudoVal(); + + auto tupleType = type.getTuple(); + SLANG_ASSERT(tupleType); + + auto elemCount = ptrTupleInfo->elements.Count(); + SLANG_ASSERT(elemCount == tupleType->elements.Count()); + + for(UInt ee = 0; ee < elemCount; ++ee) + { + auto ptrElem = ptrTupleInfo->elements[ee]; + auto elemType = tupleType->elements[ee].type; + + TuplePseudoVal::Element resElem; + resElem.key = ptrElem.key; + resElem.val = legalizeGetElement( + context, + elemType, + ptrElem.val, + indexOperand); + + resTupleInfo->elements.Add(resElem); + } + + return LegalVal::tuple(resTupleInfo); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + +static LegalVal legalizeGetElement( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + LegalVal legalIndexOperand) +{ + // We don't expect any legalization to affect + // the "index" argument. 
+ auto indexOperand = legalIndexOperand.getSimple(); + + return legalizeGetElement( + context, + type, + legalPtrOperand, + indexOperand); +} + + static LegalVal legalizeGetElementPtr( IRTypeLegalizationContext* context, LegalType type, @@ -657,6 +783,23 @@ static LegalVal legalizeGetElementPtr( return LegalVal::tuple(resTupleInfo); } + case LegalVal::Flavor::implicitDeref: + { + // The original value used to be a pointer to an array, + // and somebody is trying to get at an element pointer. + // Now we just have an array (wrapped with an implicit + // dereference) and need to just fetch the chosen element + // instead (and then wrapp the element value with an + // implicit dereference). + // + auto implicitDerefVal = legalPtrOperand.getImplicitDeref(); + return LegalVal::implicitDeref(legalizeGetElement( + context, + type, + implicitDerefVal, + indexOperand)); + } + default: SLANG_UNEXPECTED("unhandled"); UNREACHABLE_RETURN(LegalVal()); @@ -816,6 +959,9 @@ static LegalVal legalizeInst( case kIROp_FieldExtract: return legalizeFieldExtract(context, type, args[0], args[1]); + case kIROp_getElement: + return legalizeGetElement(context, type, args[0], args[1]); + case kIROp_getElementPtr: return legalizeGetElementPtr(context, type, args[0], args[1]); @@ -965,6 +1111,9 @@ static LegalVal legalizeGlobalConstant( IRTypeLegalizationContext* context, IRGlobalConstant* irGlobalConstant); +static LegalVal legalizeGlobalParam( + IRTypeLegalizationContext* context, + IRGlobalParam* irGlobalParam); static LegalVal legalizeInst( IRTypeLegalizationContext* context, @@ -992,6 +1141,9 @@ static LegalVal legalizeInst( case kIROp_GlobalConstant: return legalizeGlobalConstant(context, cast<IRGlobalConstant>(inst)); + case kIROp_GlobalParam: + return legalizeGlobalParam(context, cast<IRGlobalParam>(inst)); + default: break; } @@ -1184,6 +1336,28 @@ static LegalVal declareSimpleVar( } break; + case kIROp_GlobalConstant: + { + auto globalConst = builder->createGlobalConstant(type); + 
globalConst->removeFromParent(); + globalConst->insertBefore(context->insertBeforeGlobal); + + irVar = globalConst; + legalVarVal = LegalVal::simple(globalConst); + } + break; + + case kIROp_GlobalParam: + { + auto globalParam = builder->createGlobalParam(type); + globalParam->removeFromParent(); + globalParam->insertBefore(context->insertBeforeGlobal); + + irVar = globalParam; + legalVarVal = LegalVal::simple(globalParam); + } + break; + case kIROp_Var: { auto localVar = builder->emitVar(type); @@ -1355,9 +1529,6 @@ static LegalVal legalizeGlobalVar( context, irGlobalVar->getDataType()->getValueType()); - RefPtr<VarLayout> varLayout = findVarLayout(irGlobalVar); - RefPtr<TypeLayout> typeLayout = varLayout ? varLayout->typeLayout : nullptr; - switch (legalValueType.flavor) { case LegalType::Flavor::simple: @@ -1373,21 +1544,12 @@ static LegalVal legalizeGlobalVar( { context->insertBeforeGlobal = irGlobalVar->getNextInst(); - LegalVarChain* varChain = nullptr; - LegalVarChain varChainStorage; - if (varLayout) - { - varChainStorage.next = nullptr; - varChainStorage.varLayout = varLayout; - varChain = &varChainStorage; - } - IRGlobalNameInfo globalNameInfo; globalNameInfo.globalVar = irGlobalVar; globalNameInfo.counter = 0; UnownedStringSlice nameHint = findNameHint(irGlobalVar); - LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, typeLayout, varChain, nameHint, &globalNameInfo); + LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, nullptr, nullptr, nameHint, &globalNameInfo); // Register the new value as the replacement for the old registerLegalizedValue(context, irGlobalVar, newVal); @@ -1445,6 +1607,62 @@ static LegalVal legalizeGlobalConstant( } } +static LegalVal legalizeGlobalParam( + IRTypeLegalizationContext* context, + IRGlobalParam* irGlobalParam) +{ + // Legalize the type for the variable's value + auto legalValueType = legalizeType( + context, + irGlobalParam->getFullType()); + + RefPtr<VarLayout> varLayout = 
findVarLayout(irGlobalParam); + RefPtr<TypeLayout> typeLayout = varLayout ? varLayout->typeLayout : nullptr; + + switch (legalValueType.flavor) + { + case LegalType::Flavor::simple: + // Easy case: the type is usable as-is, and we + // should just do that. + irGlobalParam->setFullType(legalValueType.getSimple()); + return LegalVal::simple(irGlobalParam); + + default: + { + context->insertBeforeGlobal = irGlobalParam->getNextInst(); + + LegalVarChain* varChain = nullptr; + LegalVarChain varChainStorage; + if (varLayout) + { + varChainStorage.next = nullptr; + varChainStorage.varLayout = varLayout; + varChain = &varChainStorage; + } + + IRGlobalNameInfo globalNameInfo; + globalNameInfo.globalVar = irGlobalParam; + globalNameInfo.counter = 0; + + // TODO: need to handle initializer here! + + UnownedStringSlice nameHint = findNameHint(irGlobalParam); + LegalVal newVal = declareVars(context, kIROp_GlobalParam, legalValueType, typeLayout, varChain, nameHint, &globalNameInfo); + + // Register the new value as the replacement for the old + registerLegalizedValue(context, irGlobalParam, newVal); + + // Remove the old global from the module. 
+ irGlobalParam->removeFromParent(); + context->replacedInstructions.Add(irGlobalParam); + + return newVal; + } + break; + } +} + + static void legalizeTypes( IRTypeLegalizationContext* context) { diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index b8892cc02..0d93957c8 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -1966,6 +1966,18 @@ namespace Slang return globalConstant; } + IRGlobalParam* IRBuilder::createGlobalParam( + IRType* valueType) + { + IRGlobalParam* inst = createInst<IRGlobalParam>( + this, + kIROp_GlobalParam, + valueType); + maybeSetSourceLoc(this, inst); + addGlobalValue(this, inst); + return inst; + } + IRWitnessTable* IRBuilder::createWitnessTable() { IRWitnessTable* witnessTable = createInst<IRWitnessTable>( @@ -3730,6 +3742,7 @@ namespace Slang case kIROp_Generic: case kIROp_GlobalVar: case kIROp_GlobalConstant: + case kIROp_GlobalParam: case kIROp_StructKey: case kIROp_GlobalGenericParam: case kIROp_WitnessTable: @@ -3800,7 +3813,7 @@ namespace Slang // Legalization of entry points for GLSL: // - IRGlobalVar* addGlobalVariable( + IRGlobalParam* addGlobalParam( IRModule* module, IRType* valueType) { @@ -3812,7 +3825,7 @@ namespace Slang IRBuilder builder; builder.sharedBuilder = &shared; - return builder.createGlobalVar(valueType); + return builder.createGlobalParam(valueType); } void moveValueBefore( @@ -4277,18 +4290,26 @@ namespace Slang varLayout->stage = inVarLayout->stage; varLayout->AddResourceInfo(kind)->index = bindingIndex; - // Simple case: just create a global variable of the matching type, - // and then use the value of the global as a replacement for the - // value of the original parameter. + // We are going to be creating a global parameter to replace + // the function parameter, but we need to handle the case + // where the parameter represents a varying *output* and not + // just an input. 
+ // + // Our IR global shader parameters are read-only, just + // like our IR function parameters, and need a wrapper + // `Out<...>` type to represent otuputs. // - auto globalVariable = addGlobalVariable(builder->getModule(), type); - moveValueBefore(globalVariable, builder->getFunc()); + bool isOutput = kind == LayoutResourceKind::VaryingOutput; + IRType* paramType = isOutput ? builder->getOutType(type) : type; + + auto globalParam = addGlobalParam(builder->getModule(), paramType); + moveValueBefore(globalParam, builder->getFunc()); - ScalarizedVal val = ScalarizedVal::address(globalVariable); + ScalarizedVal val = isOutput ? ScalarizedVal::address(globalParam) : ScalarizedVal::value(globalParam); if( systemValueInfo ) { - builder->addImportDecoration(globalVariable, UnownedTerminatedStringSlice(systemValueInfo->name)); + builder->addImportDecoration(globalParam, UnownedTerminatedStringSlice(systemValueInfo->name)); if( auto fromType = systemValueInfo->requiredType ) { @@ -4309,11 +4330,11 @@ namespace Slang if(auto outerArrayName = systemValueInfo->outerArrayName) { - builder->addGLSLOuterArrayDecoration(globalVariable, UnownedTerminatedStringSlice(outerArrayName)); + builder->addGLSLOuterArrayDecoration(globalParam, UnownedTerminatedStringSlice(outerArrayName)); } } - builder->addLayoutDecoration(globalVariable, varLayout); + builder->addLayoutDecoration(globalParam, varLayout); return val; } @@ -4865,46 +4886,22 @@ namespace Slang auto builder = context->getBuilder(); auto paramType = pp->getDataType(); - if(auto paramPtrType = as<IROutTypeBase>(paramType) ) - { - // This is either an `out` or `in out` parameter. - // We want to treat `out` parameters the same - // as `in out` for our purposes, since there are - // no pure `out` parameters defined for the ray - // tracing stages. - - // Unlike the default legalization strategy for - // `out` and `in out` entry point parameters, - // we will not introduce an intermediate temporary. 
- // - // Instead we will simply create a global variable - // and replace uses of the parameter with uses - // of that global variable. - - auto valueType = paramPtrType->getValueType(); - - auto globalVariable = addGlobalVariable(builder->getModule(), valueType); - builder->addLayoutDecoration(globalVariable, paramLayout); - moveValueBefore(globalVariable, builder->getFunc()); - - pp->replaceUsesWith(globalVariable); - } - else - { - // This is the `in` parameter case, so that the parameter - // was not a pointer. We will allocate a global variable - // to represent the parameter, and then perform a load - // form it at the start of the function. - // - auto valueType = paramType; - auto globalVariable = addGlobalVariable(builder->getModule(), valueType); - builder->addLayoutDecoration(globalVariable, paramLayout); - moveValueBefore(globalVariable, builder->getFunc()); - - auto irLoad = builder->emitLoad(globalVariable); - pp->replaceUsesWith(irLoad); - } - + // The parameter might be either an `in` parameter, + // or an `out` or `in out` parameter, and in those + // latter cases its IR-level type will include a + // wrapping "pointer-like" type (e.g., `Out<Float>` + // instead of just `Float`). + // + // Because global shader parameters are read-only + // in the same way function types are, we can take + // care of that detail here just by allocating a + // global shader parameter with exactly the type + // of the original function parameter. 
+ // + auto globalParam = addGlobalParam(builder->getModule(), paramType); + builder->addLayoutDecoration(globalParam, paramLayout); + moveValueBefore(globalParam, builder->getFunc()); + pp->replaceUsesWith(globalParam); } void legalizeEntryPointParameterForGLSL( @@ -5629,6 +5626,7 @@ namespace Slang case kIROp_Generic: case kIROp_GlobalVar: case kIROp_GlobalConstant: + case kIROp_GlobalParam: case kIROp_StructKey: case kIROp_GlobalGenericParam: case kIROp_WitnessTable: @@ -5779,21 +5777,6 @@ namespace Slang registerClonedValue(context, clonedVar, originalValues); -#if 0 - auto mangledName = originalVar->mangledName; - clonedVar->mangledName = mangledName; -#endif - - if(auto linkage = originalVar->findDecoration<IRLinkageDecoration>()) - { - auto mangledName = String(linkage->getMangledName()); - VarLayout* layout = nullptr; - if (context->globalVarLayouts.TryGetValue(mangledName, layout)) - { - builder->addLayoutDecoration(clonedVar, layout); - } - } - // Clone any code in the body of the variable, since this // represents the initializer. cloneGlobalValueWithCodeCommon( @@ -5824,25 +5807,6 @@ namespace Slang return clonedVal; } - IRGeneric* cloneGenericImpl( - IRSpecContextBase* context, - IRBuilder* builder, - IRGeneric* originalVal, - IROriginalValuesForClone const& originalValues) - { - auto clonedVal = builder->emitGeneric(); - registerClonedValue(context, clonedVal, originalValues); - - // Clone any code in the body of the generic, since this - // computes its result value. 
- cloneGlobalValueWithCodeCommon( - context, - clonedVal, - originalVal); - - return clonedVal; - } - void cloneSimpleGlobalValueImpl( IRSpecContextBase* context, IRInst* originalInst, @@ -5865,6 +5829,48 @@ namespace Slang } } + IRGlobalParam* cloneGlobalParamImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGlobalParam* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->createGlobalParam( + cloneType(context, originalVal->getFullType())); + cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); + + if(auto linkage = originalVal->findDecoration<IRLinkageDecoration>()) + { + auto mangledName = String(linkage->getMangledName()); + VarLayout* layout = nullptr; + if (context->globalVarLayouts.TryGetValue(mangledName, layout)) + { + builder->addLayoutDecoration(clonedVal, layout); + } + } + + return clonedVal; + } + + IRGeneric* cloneGenericImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRGeneric* originalVal, + IROriginalValuesForClone const& originalValues) + { + auto clonedVal = builder->emitGeneric(); + registerClonedValue(context, clonedVal, originalValues); + + // Clone any code in the body of the generic, since this + // computes its result value. 
+ cloneGlobalValueWithCodeCommon( + context, + clonedVal, + originalVal); + + return clonedVal; + } + IRStructKey* cloneStructKeyImpl( IRSpecContextBase* context, IRBuilder* builder, @@ -6254,6 +6260,7 @@ namespace Slang case kIROp_StructType: case kIROp_GlobalVar: + case kIROp_GlobalParam: return true; default: @@ -6350,6 +6357,9 @@ namespace Slang case kIROp_GlobalConstant: return cloneGlobalConstantImpl(context, builder, cast<IRGlobalConstant>(originalInst), originalValues); + case kIROp_GlobalParam: + return cloneGlobalParamImpl(context, builder, cast<IRGlobalParam>(originalInst), originalValues); + case kIROp_WitnessTable: return cloneWitnessTableImpl(context, builder, cast<IRWitnessTable>(originalInst), originalValues); diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index bf5aeff24..18d42feab 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -3,6 +3,7 @@ #include "../../slang.h" +#include "check.h" #include "ir.h" #include "ir-constexpr.h" #include "ir-insts.h" @@ -3828,8 +3829,41 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return false; } + LoweredValInfo lowerGlobalShaderParam(VarDeclBase* decl) + { + IRType* paramType = lowerType(context, decl->getType()); + + auto builder = getBuilder(); + + auto irParam = builder->createGlobalParam(paramType); + auto paramVal = LoweredValInfo::simple(irParam); + + addLinkageDecoration(context, irParam, decl); + addNameHint(context, irParam, decl); + maybeSetRate(context, irParam, decl); + addVarDecorations(context, irParam, decl); + + if (decl) + { + builder->addHighLevelDeclDecoration(irParam, decl); + } + + // A global variable's SSA value is a *pointer* to + // the underlying storage. 
+ setGlobalValue(context, decl, paramVal); + + irParam->moveToEnd(); + + return paramVal; + } + LoweredValInfo lowerGlobalVarDecl(VarDeclBase* decl) { + if(isGlobalShaderParameter(decl)) + { + return lowerGlobalShaderParam(decl); + } + IRType* varType = lowerType(context, decl->getType()); auto builder = getBuilder(); diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index c502780df..427127c05 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -171,6 +171,7 @@ </ItemDefinitionGroup> <ItemGroup> <ClInclude Include="..\..\slang.h" /> + <ClInclude Include="check.h" /> <ClInclude Include="compiler.h" /> <ClInclude Include="core.meta.slang.h" /> <ClInclude Include="decl-defs.h" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index dc5630504..d72909bc1 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -162,6 +162,9 @@ <ClInclude Include="visitor.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="check.h"> + <Filter>Header Files</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <ClCompile Include="check.cpp"> diff --git a/tests/bindings/array-of-struct-of-resource.hlsl b/tests/bindings/array-of-struct-of-resource.hlsl index b34e0469b..240ffed73 100644 --- a/tests/bindings/array-of-struct-of-resource.hlsl +++ b/tests/bindings/array-of-struct-of-resource.hlsl @@ -4,7 +4,7 @@ // HLSL compiler would already do in the simple case (when // all shader parameters are actually used). 
-float4 use(Texture2D t, SamplerState s) { return t.Sample(s, 0.0); } +float4 use(Texture2D t, SamplerState samp) { return t.Sample(samp, 0.0); } #ifdef __SLANG__ diff --git a/tests/bindings/binding0.hlsl b/tests/bindings/binding0.hlsl index 5516b0135..2ae40ead3 100644 --- a/tests/bindings/binding0.hlsl +++ b/tests/bindings/binding0.hlsl @@ -24,7 +24,7 @@ #endif float4 use(float4 val) { return val; }; -float4 use(Texture2D t, SamplerState s) { return t.Sample(s, 0.0); } +float4 use(Texture2D tex, SamplerState samp) { return tex.Sample(samp, 0.0); } Texture2D t R(: register(t0)); SamplerState s R(: register(s0)); diff --git a/tests/cross-compile/non-uniform-indexing.slang.glsl b/tests/cross-compile/non-uniform-indexing.slang.glsl index 83f63c70d..5ea5ed73a 100644 --- a/tests/cross-compile/non-uniform-indexing.slang.glsl +++ b/tests/cross-compile/non-uniform-indexing.slang.glsl @@ -17,12 +17,10 @@ in vec3 _S2; void main() { - vec3 _S3 = _S2; + int _S3 = nonuniformEXT(int(_S2.z)); - int _S4 = nonuniformEXT(int(_S3.z)); + vec4 _S4 = texture(sampler2D(t_0[_S3],s_0), _S2.xy); - vec4 _S5 = texture(sampler2D(t_0[_S4],s_0), _S3.xy); - - _S1 = _S5; + _S1 = _S4; return; } diff --git a/tests/vkray/anyhit.slang.glsl b/tests/vkray/anyhit.slang.glsl index 07789cdbd..622080399 100644 --- a/tests/vkray/anyhit.slang.glsl +++ b/tests/vkray/anyhit.slang.glsl @@ -33,13 +33,11 @@ rayPayloadInNV ShadowRay_0 _S3; void main() { - SphereHitAttributes_0 _S4 = _S2; - if(bool(gParams_0._data.mode_0)) { float val_0 = textureLod( sampler2D(gParams_alphaMap_0, gParams_sampler_0), - _S4.normal_0.xy, + _S2.normal_0.xy, float(0)).x; diff --git a/tests/vkray/closesthit.slang.glsl b/tests/vkray/closesthit.slang.glsl index 79fd3afbe..46c5ea636 100644 --- a/tests/vkray/closesthit.slang.glsl +++ b/tests/vkray/closesthit.slang.glsl @@ -6,15 +6,14 @@ #define tmp_colors _S2 #define tmp_hitattrs _S3 #define tmp_payload _S4 -#define tmp_localattrs _S5 -#define tmp_customidx _S6 -#define tmp_instanceid _S7 
-#define tmp_add_0 _S8 -#define tmp_primid _S9 -#define tmp_add_1 _S10 -#define tmp_hitkind _S11 -#define tmp_hitt _S12 -#define tmp_tmin _S13 +#define tmp_customidx _S5 +#define tmp_instanceid _S6 +#define tmp_add_0 _S7 +#define tmp_primid _S8 +#define tmp_add_1 _S9 +#define tmp_hitkind _S10 +#define tmp_hitt _S11 +#define tmp_tmin _S12 struct SLANG_ParameterGroup_ShaderRecord_0 { @@ -49,8 +48,6 @@ rayPayloadInNV ReflectionRay_0 tmp_payload; void main() { - BuiltInTriangleIntersectionAttributes_0 tmp_localattrs = tmp_hitattrs; - uint tmp_customidx = gl_InstanceCustomIndexNV; uint tmp_instanceid = gl_InstanceID; |
