summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-03 15:54:16 -0700
committerGitHub <noreply@github.com>2023-04-03 15:54:16 -0700
commitb68516e2c2e39af79dda2ec7871fe4d821ef67c4 (patch)
treeec61ca320368f8128cd531a9272e8e49d5353247
parent7a346b2982c69ef97ebc4b308c77a1f1c88c548f (diff)
Emit simpler vector element access code. (#2770)
* Emit simpler vector element access code * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-emit-cpp.cpp56
-rw-r--r--source/slang/slang-ir-legalize-types.cpp9
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp3
-rw-r--r--source/slang/slang-ir-lower-generic-type.cpp9
-rw-r--r--tests/bugs/type-legalize-bug-1.slang56
-rw-r--r--tests/bugs/type-legalize-bug-1.slang.expected.txt4
6 files changed, 114 insertions, 23 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index a178dfe67..d6764ae06 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1270,13 +1270,26 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
IRInst* baseInst = getElementInst->getBase();
IRType* baseType = baseInst->getDataType();
- if (as<IRVectorType>(baseType))
+ if (auto vectorBaseType = as<IRVectorType>(baseType))
{
- m_writer->emit("_slang_vector_get_element(");
- emitOperand(baseInst, getInfo(EmitOp::General));
- m_writer->emit(", ");
- emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
- m_writer->emit(")");
+ if (auto intLitIndex = as<IRIntLit>(getElementInst->getIndex()))
+ {
+ // For static index, we can emit simpler code using the `.x`, `.y` members.
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(baseInst, leftSide(outerPrec, prec));
+ m_writer->emit(".");
+ m_writer->emit(getVectorElementNames(vectorBaseType)[intLitIndex->getValue()]);
+ }
+ else
+ {
+ // For dynamic index, we emit using `_slang_vector_get_element` intrinsics.
+ m_writer->emit("_slang_vector_get_element(");
+ emitOperand(baseInst, getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ }
return true;
}
else if (as<IRMatrixType>(baseType))
@@ -1297,24 +1310,37 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
IRInst* baseInst = getElementInst->getBase();
IRType* baseType = as<IRPtrTypeBase>(baseInst->getDataType())->getValueType();
- if (as<IRVectorType>(baseType))
+ if (auto vectorBaseType = as<IRVectorType>(baseType))
{
- m_writer->emit("_slang_vector_get_element_ptr(");
- emitOperand(baseInst, getInfo(EmitOp::General));
- m_writer->emit(", ");
- emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
- m_writer->emit(")");
+ if (auto intLitIndex = as<IRIntLit>(getElementInst->getIndex()))
+ {
+ // For static index, we can emit simpler code using the `.x`, `.y` members.
+ m_writer->emit("&(");
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(baseInst, leftSide(outerPrec, prec));
+ m_writer->emit("->");
+ m_writer->emit(getVectorElementNames(vectorBaseType)[intLitIndex->getValue()]);
+ m_writer->emit(")");
+ }
+ else
+ {
+ m_writer->emit("_slang_vector_get_element_ptr(");
+ emitOperand(baseInst, getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ }
return true;
}
else if (as<IRMatrixType>(baseType))
{
- m_writer->emit("&(");
+ m_writer->emit("(");
auto outerPrec = getInfo(EmitOp::General);
auto prec = getInfo(EmitOp::Postfix);
emitOperand(baseInst, leftSide(outerPrec, prec));
- m_writer->emit("->rows[");
+ m_writer->emit("->rows + ");
emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
- m_writer->emit("]");
m_writer->emit(")");
return true;
}
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index b9d494c21..52b5bc72f 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -2159,6 +2159,8 @@ static LegalVal legalizeInst(
if (newArgs[aa] != inst->getOperand(aa))
recreate = true;
}
+ if (inst->getFullType() != legalType.getSimple())
+ recreate = true;
if (recreate)
{
IRBuilder builder(inst->getModule());
@@ -2169,8 +2171,6 @@ static LegalVal legalizeInst(
context->replacedInstructions.add(inst);
return LegalVal::simple(newInst);
}
- if (inst->getFullType() != legalType.getSimple())
- inst->setFullType(legalType.getSimple());
return LegalVal::simple(inst);
}
@@ -3641,7 +3641,10 @@ struct IRTypeLegalizationPass
//
// * `i` is a user of `inst`, or
// * `i` is a child of `inst`.
- //
+ //
+ if (legalVal.flavor == LegalVal::Flavor::simple)
+ inst = legalVal.irValue;
+
for( auto use = inst->firstUse; use; use = use->nextUse )
{
auto user = use->getUser();
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index ad43aff95..e45b20563 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -322,7 +322,8 @@ namespace Slang
return;
auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType()));
interfaceRequirementVal = sharedContext->findInterfaceRequirementVal(interfaceType, lookupInst->getRequirementKey());
- lookupInst->setFullType((IRType*)interfaceRequirementVal);
+ IRBuilder builder(lookupInst);
+ builder.replaceOperand(&lookupInst->typeUse, interfaceRequirementVal);
}
void lowerSpecialize(IRSpecialize* specializeInst)
diff --git a/source/slang/slang-ir-lower-generic-type.cpp b/source/slang/slang-ir-lower-generic-type.cpp
index 398db4f78..256978346 100644
--- a/source/slang/slang-ir-lower-generic-type.cpp
+++ b/source/slang/slang-ir-lower-generic-type.cpp
@@ -14,7 +14,7 @@ namespace Slang
{
SharedGenericsLoweringContext* sharedContext;
- void processInst(IRInst* inst)
+ IRInst* processInst(IRInst* inst)
{
// Ensure public struct types has RTTI object defined.
if (as<IRStructType>(inst))
@@ -27,7 +27,7 @@ namespace Slang
// Don't modify type insts themselves.
if (as<IRType>(inst))
- return;
+ return inst;
IRBuilder builderStorage(sharedContext->module);
auto builder = &builderStorage;
@@ -35,7 +35,7 @@ namespace Slang
auto newType = sharedContext->lowerType(builder, inst->getFullType());
if (newType != inst->getFullType())
- inst->setFullType((IRType*)newType);
+ inst = builder->replaceOperand(&inst->typeUse, newType);
switch (inst->getOp())
{
@@ -51,6 +51,7 @@ namespace Slang
}
break;
}
+ return inst;
}
void processModule()
@@ -64,7 +65,7 @@ namespace Slang
sharedContext->workList.removeLast();
sharedContext->workListSet.Remove(inst);
- processInst(inst);
+ inst = processInst(inst);
for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
{
diff --git a/tests/bugs/type-legalize-bug-1.slang b/tests/bugs/type-legalize-bug-1.slang
new file mode 100644
index 000000000..83e28a509
--- /dev/null
+++ b/tests/bugs/type-legalize-bug-1.slang
@@ -0,0 +1,56 @@
+//TEST(compute):COMPARE_COMPUTE: -shaderobj
+
+//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name outputBuffer
+//TEST_INPUT:type_conformance A:IFoo=0
+//TEST_INPUT:type_conformance B:IFoo=1
+
+RWStructuredBuffer<int> outputBuffer : register(u0);
+interface IFoo
+{
+ associatedtype T : IFoo;
+ T getT();
+ void doSomething();
+}
+
+A createA() { return {}; }
+B createB() { return {}; }
+ParameterBlock<B> gB;
+void user()
+{
+ IFoo a = createDynamicObject<IFoo>(0, 0);
+ IFoo b = createDynamicObject<IFoo>(1, 0);
+ test(a.getT(), b);
+ test(a, gB.getT());
+}
+B test<T:IFoo>(T a, IFoo b)
+{
+ a.doSomething();
+ b.doSomething();
+ return {};
+}
+struct B :IFoo
+{
+ A a;
+ typealias T = A;
+ T getT() { return {};}
+ void doSomething()
+ {
+ outputBuffer[0] = 1;
+ }
+}
+struct A : IFoo
+{
+ typealias T = B;
+ T getT() { return {};}
+ void doSomething()
+ {
+ outputBuffer[0] = 1;
+ }
+}
+
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ user();
+} \ No newline at end of file
diff --git a/tests/bugs/type-legalize-bug-1.slang.expected.txt b/tests/bugs/type-legalize-bug-1.slang.expected.txt
new file mode 100644
index 000000000..89999d47a
--- /dev/null
+++ b/tests/bugs/type-legalize-bug-1.slang.expected.txt
@@ -0,0 +1,4 @@
+1
+9
+9
+9