summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit-c-like.cpp60
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp60
-rw-r--r--source/slang/slang-serialize-ast-type-info.h2
-rw-r--r--tests/metal/vector-get-element-ptr.slang24
4 files changed, 140 insertions, 6 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 6412ff730..81b68038e 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1689,6 +1689,61 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const&
maybeCloseParens(innerNeedClose);
return;
}
+ case kIROp_GetElementPtr:
+ {
+ const auto info = getInfo(EmitOp::Prefix);
+ IRVectorType* vectorType = nullptr;
+ if (auto ptrType = as<IRPtrTypeBase>(inst->getOperand(0)->getDataType()))
+ {
+ vectorType = as<IRVectorType>(ptrType->getValueType());
+ }
+ if (vectorType)
+ {
+ // Can't use simplified emit logic for get vector element operations on CUDA targets.
+ if (isCUDATarget(m_codeGenContext->getTargetReq()))
+ break;
+ }
+
+ auto rightSidePrec = rightSide(outerPrec, info);
+ auto postfixInfo = getInfo(EmitOp::Postfix);
+ bool rightSideNeedClose = maybeEmitParens(rightSidePrec, postfixInfo);
+ emitDereferenceOperand(inst->getOperand(0), leftSide(rightSidePrec, postfixInfo));
+ bool emitBracketPostfix = true;
+ if (vectorType)
+ {
+ // Simplify the emitted code if we are referencing a known vector element.
+ if (auto intLit = as<IRIntLit>(inst->getOperand(1)))
+ {
+ emitBracketPostfix = false;
+ switch (intLit->getValue())
+ {
+ case 0:
+ m_writer->emit(".x");
+ break;
+ case 1:
+ m_writer->emit(".y");
+ break;
+ case 2:
+ m_writer->emit(".z");
+ break;
+ case 3:
+ m_writer->emit(".w");
+ break;
+ default:
+ emitBracketPostfix = true;
+ break;
+ }
+ }
+ }
+ if (emitBracketPostfix)
+ {
+ m_writer->emit("[");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit("]");
+ }
+ maybeCloseParens(rightSideNeedClose);
+ return;
+ }
default:
break;
}
@@ -2538,10 +2593,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
auto rightSidePrec = rightSide(outerPrec, info);
auto postfixInfo = getInfo(EmitOp::Postfix);
bool rightSideNeedClose = maybeEmitParens(rightSidePrec, postfixInfo);
- if (isPtrToArrayType(inst->getOperand(0)->getDataType()))
- emitDereferenceOperand(inst->getOperand(0), leftSide(rightSidePrec, postfixInfo));
- else
- emitOperand(inst->getOperand(0), leftSide(rightSidePrec, postfixInfo));
+ emitDereferenceOperand(inst->getOperand(0), leftSide(rightSidePrec, postfixInfo));
m_writer->emit("[");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit("]");
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 70f4cbd27..6b6c86040 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -725,6 +725,65 @@ namespace Slang
legalizeDispatchMeshPayloadForMetal(entryPoint);
}
+ void legalizeFuncBody(IRFunc* func)
+ {
+ IRBuilder builder(func);
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getModifiableChildren())
+ {
+ if (auto call = as<IRCall>(inst))
+ {
+ ShortList<IRUse*> argsToFixup;
+ // Metal doesn't support taking the address of a vector element.
+ // If such an address is used as an argument to a call, we need to replace it with a temporary.
+ // for example, if we see:
+ // ```
+ // void foo(inout float x) { x = 1; }
+ // float4 v;
+ // foo(v.x);
+ // ```
+ // We need to transform it into:
+ // ```
+ // float4 v;
+ // float temp = v.x;
+ // foo(temp);
+ // v.x = temp;
+ // ```
+ //
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ if (auto addr = as<IRGetElementPtr>(call->getArg(i)))
+ {
+ auto ptrType = addr->getBase()->getDataType();
+ auto valueType = tryGetPointedToType(&builder, ptrType);
+ if (!valueType)
+ continue;
+ if (as<IRVectorType>(valueType))
+ argsToFixup.add(call->getArgs() + i);
+ }
+ }
+ if (argsToFixup.getCount() == 0)
+ continue;
+
+ // Define temp vars for all args that need fixing up.
+ for (auto arg : argsToFixup)
+ {
+ auto addr = as<IRGetElementPtr>(arg->get());
+ auto ptrType = addr->getDataType();
+ auto valueType = tryGetPointedToType(&builder, ptrType);
+ builder.setInsertBefore(call);
+ auto temp = builder.emitVar(valueType);
+ auto initialValue = builder.emitLoad(valueType, addr);
+ builder.emitStore(temp, initialValue);
+ builder.setInsertAfter(call);
+ builder.emitStore(addr, builder.emitLoad(valueType, temp));
+ arg->set(temp);
+ }
+ }
+ }
+ }
+ }
void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
{
@@ -740,6 +799,7 @@ namespace Slang
info.entryPointFunc = func;
entryPoints.add(info);
}
+ legalizeFuncBody(func);
}
}
diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h
index a8f459247..20b6e656f 100644
--- a/source/slang/slang-serialize-ast-type-info.h
+++ b/source/slang/slang-serialize-ast-type-info.h
@@ -199,11 +199,9 @@ struct SerialTypeInfo<CapabilityTargetSet>
auto& shaderStageSets = dst.shaderStageSets;
shaderStageSets.clear();
shaderStageSets.reserve(items.getCount());
- Index iter = 0;
for (auto& i : items)
{
dst.shaderStageSets[i.stage] = i;
- iter++;
}
}
};
diff --git a/tests/metal/vector-get-element-ptr.slang b/tests/metal/vector-get-element-ptr.slang
new file mode 100644
index 000000000..af2acabbc
--- /dev/null
+++ b/tests/metal/vector-get-element-ptr.slang
@@ -0,0 +1,24 @@
+//TEST:SIMPLE(filecheck=CHECK): -target metal
+
+//TEST(smoke,compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -mtl
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+void modify(inout int v)
+{
+ v = 2;
+}
+
+[numthreads(1,1,1)]
+void computeMain(int3 v : SV_DispatchThreadID)
+{
+ int3 u = v;
+ // CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = u{{.*}}.x;
+ // CHECK: modify{{.*}}(&[[TEMP]])
+ // CHECK: u{{.*}}.x = [[TEMP]];
+
+ modify(u.x);
+ // BUF: 2
+ outputBuffer[0] = u.x + u.y;
+} \ No newline at end of file