diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-03 15:54:16 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-03 15:54:16 -0700 |
| commit | b68516e2c2e39af79dda2ec7871fe4d821ef67c4 (patch) | |
| tree | ec61ca320368f8128cd531a9272e8e49d5353247 | |
| parent | 7a346b2982c69ef97ebc4b308c77a1f1c88c548f (diff) | |
Emit simpler vector element access code. (#2770)
* Emit simpler vector element access code
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 56 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-function.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-type.cpp | 9 | ||||
| -rw-r--r-- | tests/bugs/type-legalize-bug-1.slang | 56 | ||||
| -rw-r--r-- | tests/bugs/type-legalize-bug-1.slang.expected.txt | 4 |
6 files changed, 114 insertions, 23 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index a178dfe67..d6764ae06 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1270,13 +1270,26 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut IRInst* baseInst = getElementInst->getBase(); IRType* baseType = baseInst->getDataType(); - if (as<IRVectorType>(baseType)) + if (auto vectorBaseType = as<IRVectorType>(baseType)) { - m_writer->emit("_slang_vector_get_element("); - emitOperand(baseInst, getInfo(EmitOp::General)); - m_writer->emit(", "); - emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); - m_writer->emit(")"); + if (auto intLitIndex = as<IRIntLit>(getElementInst->getIndex())) + { + // For static index, we can emit simpler code using the `.x`, `.y` members. + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(baseInst, leftSide(outerPrec, prec)); + m_writer->emit("."); + m_writer->emit(getVectorElementNames(vectorBaseType)[intLitIndex->getValue()]); + } + else + { + // For dynamic index, we emit using `_slang_vector_get_element` intrinsics. + m_writer->emit("_slang_vector_get_element("); + emitOperand(baseInst, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit(")"); + } return true; } else if (as<IRMatrixType>(baseType)) @@ -1297,24 +1310,37 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut IRInst* baseInst = getElementInst->getBase(); IRType* baseType = as<IRPtrTypeBase>(baseInst->getDataType())->getValueType(); - if (as<IRVectorType>(baseType)) + if (auto vectorBaseType = as<IRVectorType>(baseType)) { - m_writer->emit("_slang_vector_get_element_ptr("); - emitOperand(baseInst, getInfo(EmitOp::General)); - m_writer->emit(", "); - emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); - m_writer->emit(")"); + if (auto intLitIndex = as<IRIntLit>(getElementInst->getIndex())) + { + // For static index, we can emit simpler code using the `.x`, `.y` members. + m_writer->emit("&("); + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(baseInst, leftSide(outerPrec, prec)); + m_writer->emit("->"); + m_writer->emit(getVectorElementNames(vectorBaseType)[intLitIndex->getValue()]); + m_writer->emit(")"); + } + else + { + m_writer->emit("_slang_vector_get_element_ptr("); + emitOperand(baseInst, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit(")"); + } return true; } else if (as<IRMatrixType>(baseType)) { - m_writer->emit("&("); + m_writer->emit("("); auto outerPrec = getInfo(EmitOp::General); auto prec = getInfo(EmitOp::Postfix); emitOperand(baseInst, leftSide(outerPrec, prec)); - m_writer->emit("->rows["); + m_writer->emit("->rows + "); emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); - m_writer->emit("]"); m_writer->emit(")"); return true; } diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index b9d494c21..52b5bc72f 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -2159,6 +2159,8 @@ static LegalVal legalizeInst( if (newArgs[aa] != inst->getOperand(aa)) recreate = true; } + if (inst->getFullType() != legalType.getSimple()) + recreate = true; if (recreate) { IRBuilder builder(inst->getModule()); @@ -2169,8 +2171,6 @@ static LegalVal legalizeInst( context->replacedInstructions.add(inst); return LegalVal::simple(newInst); } - if (inst->getFullType() != legalType.getSimple()) - inst->setFullType(legalType.getSimple()); return LegalVal::simple(inst); } @@ -3641,7 +3641,10 @@ struct IRTypeLegalizationPass // // * `i` is a user of `inst`, or // * `i` is a child of `inst`. - // + // + if (legalVal.flavor == LegalVal::Flavor::simple) + inst = legalVal.irValue; + for( auto use = inst->firstUse; use; use = use->nextUse ) { auto user = use->getUser(); diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index ad43aff95..e45b20563 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -322,7 +322,8 @@ namespace Slang return; auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType())); interfaceRequirementVal = sharedContext->findInterfaceRequirementVal(interfaceType, lookupInst->getRequirementKey()); - lookupInst->setFullType((IRType*)interfaceRequirementVal); + IRBuilder builder(lookupInst); + builder.replaceOperand(&lookupInst->typeUse, interfaceRequirementVal); } void lowerSpecialize(IRSpecialize* specializeInst) diff --git a/source/slang/slang-ir-lower-generic-type.cpp b/source/slang/slang-ir-lower-generic-type.cpp index 398db4f78..256978346 100644 --- a/source/slang/slang-ir-lower-generic-type.cpp +++ b/source/slang/slang-ir-lower-generic-type.cpp @@ -14,7 +14,7 @@ namespace Slang { SharedGenericsLoweringContext* sharedContext; - void processInst(IRInst* inst) + IRInst* processInst(IRInst* inst) { // Ensure public struct types has RTTI object defined. if (as<IRStructType>(inst)) @@ -27,7 +27,7 @@ namespace Slang // Don't modify type insts themselves. if (as<IRType>(inst)) - return; + return inst; IRBuilder builderStorage(sharedContext->module); auto builder = &builderStorage; @@ -35,7 +35,7 @@ namespace Slang auto newType = sharedContext->lowerType(builder, inst->getFullType()); if (newType != inst->getFullType()) - inst->setFullType((IRType*)newType); + inst = builder->replaceOperand(&inst->typeUse, newType); switch (inst->getOp()) { @@ -51,6 +51,7 @@ namespace Slang } break; } + return inst; } void processModule() @@ -64,7 +65,7 @@ namespace Slang sharedContext->workList.removeLast(); sharedContext->workListSet.Remove(inst); - processInst(inst); + inst = processInst(inst); for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) { diff --git a/tests/bugs/type-legalize-bug-1.slang b/tests/bugs/type-legalize-bug-1.slang new file mode 100644 index 000000000..83e28a509 --- /dev/null +++ b/tests/bugs/type-legalize-bug-1.slang @@ -0,0 +1,56 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj + +//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name outputBuffer +//TEST_INPUT:type_conformance A:IFoo=0 +//TEST_INPUT:type_conformance B:IFoo=1 + +RWStructuredBuffer<int> outputBuffer : register(u0); +interface IFoo +{ + associatedtype T : IFoo; + T getT(); + void doSomething(); +} + +A createA() { return {}; } +B createB() { return {}; } +ParameterBlock<B> gB; +void user() +{ + IFoo a = createDynamicObject<IFoo>(0, 0); + IFoo b = createDynamicObject<IFoo>(1, 0); + test(a.getT(), b); + test(a, gB.getT()); +} +B test<T:IFoo>(T a, IFoo b) +{ + a.doSomething(); + b.doSomething(); + return {}; +} +struct B :IFoo +{ + A a; + typealias T = A; + T getT() { return {};} + void doSomething() + { + outputBuffer[0] = 1; + } +} +struct A : IFoo +{ + typealias T = B; + T getT() { return {};} + void doSomething() + { + outputBuffer[0] = 1; + } +} + + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + user(); +}
\ No newline at end of file diff --git a/tests/bugs/type-legalize-bug-1.slang.expected.txt b/tests/bugs/type-legalize-bug-1.slang.expected.txt new file mode 100644 index 000000000..89999d47a --- /dev/null +++ b/tests/bugs/type-legalize-bug-1.slang.expected.txt @@ -0,0 +1,4 @@ +1 +9 +9 +9 |
