diff options
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-serialize-ast-type-info.h | 2 | ||||
| -rw-r--r-- | tests/metal/vector-get-element-ptr.slang | 24 |
4 files changed, 140 insertions, 6 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 6412ff730..81b68038e 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1689,6 +1689,61 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& maybeCloseParens(innerNeedClose); return; } + case kIROp_GetElementPtr: + { + const auto info = getInfo(EmitOp::Prefix); + IRVectorType* vectorType = nullptr; + if (auto ptrType = as<IRPtrTypeBase>(inst->getOperand(0)->getDataType())) + { + vectorType = as<IRVectorType>(ptrType->getValueType()); + } + if (vectorType) + { + // Can't use simplified emit logic for get vector element operations on CUDA targets. + if (isCUDATarget(m_codeGenContext->getTargetReq())) + break; + } + + auto rightSidePrec = rightSide(outerPrec, info); + auto postfixInfo = getInfo(EmitOp::Postfix); + bool rightSideNeedClose = maybeEmitParens(rightSidePrec, postfixInfo); + emitDereferenceOperand(inst->getOperand(0), leftSide(rightSidePrec, postfixInfo)); + bool emitBracketPostfix = true; + if (vectorType) + { + // Simplify the emitted code if we are referencing a known vector element. + if (auto intLit = as<IRIntLit>(inst->getOperand(1))) + { + emitBracketPostfix = false; + switch (intLit->getValue()) + { + case 0: + m_writer->emit(".x"); + break; + case 1: + m_writer->emit(".y"); + break; + case 2: + m_writer->emit(".z"); + break; + case 3: + m_writer->emit(".w"); + break; + default: + emitBracketPostfix = true; + break; + } + } + } + if (emitBracketPostfix) + { + m_writer->emit("["); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("]"); + } + maybeCloseParens(rightSideNeedClose); + return; + } default: break; } @@ -2538,10 +2593,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO auto rightSidePrec = rightSide(outerPrec, info); auto postfixInfo = getInfo(EmitOp::Postfix); bool rightSideNeedClose = maybeEmitParens(rightSidePrec, postfixInfo); - if (isPtrToArrayType(inst->getOperand(0)->getDataType())) - emitDereferenceOperand(inst->getOperand(0), leftSide(rightSidePrec, postfixInfo)); - else - emitOperand(inst->getOperand(0), leftSide(rightSidePrec, postfixInfo)); + emitDereferenceOperand(inst->getOperand(0), leftSide(rightSidePrec, postfixInfo)); m_writer->emit("["); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); m_writer->emit("]"); diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 70f4cbd27..6b6c86040 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -725,6 +725,65 @@ namespace Slang legalizeDispatchMeshPayloadForMetal(entryPoint); } + void legalizeFuncBody(IRFunc* func) + { + IRBuilder builder(func); + for (auto block : func->getBlocks()) + { + for (auto inst : block->getModifiableChildren()) + { + if (auto call = as<IRCall>(inst)) + { + ShortList<IRUse*> argsToFixup; + // Metal doesn't support taking the address of a vector element. + // If such an address is used as an argument to a call, we need to replace it with a temporary. + // for example, if we see: + // ``` + // void foo(inout float x) { x = 1; } + // float4 v; + // foo(v.x); + // ``` + // We need to transform it into: + // ``` + // float4 v; + // float temp = v.x; + // foo(temp); + // v.x = temp; + // ``` + // + for (UInt i = 0; i < call->getArgCount(); i++) + { + if (auto addr = as<IRGetElementPtr>(call->getArg(i))) + { + auto ptrType = addr->getBase()->getDataType(); + auto valueType = tryGetPointedToType(&builder, ptrType); + if (!valueType) + continue; + if (as<IRVectorType>(valueType)) + argsToFixup.add(call->getArgs() + i); + } + } + if (argsToFixup.getCount() == 0) + continue; + + // Define temp vars for all args that need fixing up. + for (auto arg : argsToFixup) + { + auto addr = as<IRGetElementPtr>(arg->get()); + auto ptrType = addr->getDataType(); + auto valueType = tryGetPointedToType(&builder, ptrType); + builder.setInsertBefore(call); + auto temp = builder.emitVar(valueType); + auto initialValue = builder.emitLoad(valueType, addr); + builder.emitStore(temp, initialValue); + builder.setInsertAfter(call); + builder.emitStore(addr, builder.emitLoad(valueType, temp)); + arg->set(temp); + } + } + } + } + } void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) { @@ -740,6 +799,7 @@ namespace Slang info.entryPointFunc = func; entryPoints.add(info); } + legalizeFuncBody(func); } } diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h index a8f459247..20b6e656f 100644 --- a/source/slang/slang-serialize-ast-type-info.h +++ b/source/slang/slang-serialize-ast-type-info.h @@ -199,11 +199,9 @@ struct SerialTypeInfo<CapabilityTargetSet> auto& shaderStageSets = dst.shaderStageSets; shaderStageSets.clear(); shaderStageSets.reserve(items.getCount()); - Index iter = 0; for (auto& i : items) { dst.shaderStageSets[i.stage] = i; - iter++; } } }; diff --git a/tests/metal/vector-get-element-ptr.slang b/tests/metal/vector-get-element-ptr.slang new file mode 100644 index 000000000..af2acabbc --- /dev/null +++ b/tests/metal/vector-get-element-ptr.slang @@ -0,0 +1,24 @@ +//TEST:SIMPLE(filecheck=CHECK): -target metal + +//TEST(smoke,compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -mtl + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +void modify(inout int v) +{ + v = 2; +} + +[numthreads(1,1,1)] +void computeMain(int3 v : SV_DispatchThreadID) +{ + int3 u = v; + // CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = u{{.*}}.x; + // CHECK: modify{{.*}}(&[[TEMP]]) + // CHECK: u{{.*}}.x = [[TEMP]]; + + modify(u.x); + // BUF: 2 + outputBuffer[0] = u.x + u.y; +}
\ No newline at end of file |
