diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2023-06-30 15:25:59 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-06-30 12:25:59 -0700 |
| commit | c5b0708ead5de2d90ef14f20b5b8e3ed4f576373 (patch) | |
| tree | 0e7691cda679b5f937d7ed73eba021ad9170245c /source | |
| parent | a3ad4dd77bba6c87abad4f76b72055c9fed94bad (diff) | |
Fix for operator assignment issue (#2951)
* WIP handling LValue coercion via LValueImplicitCast
* Need to have the ptr type for the cast.
* Casting conversion working on C++.
* Make the LValue casts record if in or in/out as we can produce better code if we know the difference.
* WIP LValueCast pass
* Fix tests so we don't fail because downstream compilers detect use of uninitialized variable.
* Do conversions through through tmp for l-value scenarios that can't work other ways.
* Fix a typo.
* Change diagnostic implicit-cast-lvalue for a type that still exhibits the issue.
* Add matrix test.
* Added a bit more clarity around LValue casting choices.
* Small comment improvements.
Improvements based on comments on PR.
* Use findOuterGeneric.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-expr.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 107 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-l-value-cast.cpp | 247 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-l-value-cast.h | 24 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang-type-layout.cpp | 56 |
12 files changed, 513 insertions, 12 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index da213e8d4..36d6546de 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -302,6 +302,32 @@ class ImplicitCastExpr : public TypeCastExpr SLANG_AST_CLASS(ImplicitCastExpr) }; +class LValueImplicitCastExpr : public TypeCastExpr +{ + SLANG_AST_CLASS(LValueImplicitCastExpr) + + explicit LValueImplicitCastExpr(const TypeCastExpr& rhs) :Super(rhs) {} +}; + +// To work around situations like int += uint +// where we want to allow an LValue to work with an implicit cast. +// The argument being cast *must* be an LValue. +class OutImplicitCastExpr : public LValueImplicitCastExpr +{ + SLANG_AST_CLASS(OutImplicitCastExpr) + + /// Allow explict construction from any TypeCastExpr + explicit OutImplicitCastExpr(const TypeCastExpr& rhs) :Super(rhs) {} +}; + +class InOutImplicitCastExpr : public LValueImplicitCastExpr +{ + SLANG_AST_CLASS(InOutImplicitCastExpr) + + /// Allow explict construction from any TypeCastExpr + explicit InOutImplicitCastExpr(const TypeCastExpr& rhs) :Super(rhs) {} +}; + /// A cast of a value to a super-type of its type. /// /// The type being cast to is stored as this expression's `type`. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index e0084e08e..00ece3628 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1952,6 +1952,51 @@ namespace Slang return checkedExpr; } + static bool _canLValueCoerceScalarType(Type* a, Type* b) + { + auto basicTypeA = as<BasicExpressionType>(a); + auto basicTypeB = as<BasicExpressionType>(b); + + if (basicTypeA && basicTypeB) + { + const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->baseType); + const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->baseType); + + // TODO(JS): Initially this tries to limit where LValueImplict casts happen. + // We could in principal allow different sizes, as long as we converted to a temprorary + // and back again. + // + // For now we just stick with the simple case. + // // We only allow on integer types for now. In effect just allowing any size uint/int conversions + if (infoA.sizeInBytes == infoB.sizeInBytes && + (infoA.flags & infoB.flags & BaseTypeInfo::Flag::Integer)) + { + return true; + } + + } + return false; + } + + static bool _canLValueCoerce(Type* a, Type* b) + { + // We can *assume* here that if they are coercable, that dimensions of vectors + // and matrices match. We might want to assert to be sure... + SLANG_ASSERT(a != b); + if (a->astNodeType == b->astNodeType) + { + if (auto matA = as<MatrixExpressionType>(a)) + { + return _canLValueCoerceScalarType(matA->getElementType(), static_cast<MatrixExpressionType*>(b)->getElementType()); + } + else if (auto vecA = as<VectorExpressionType>(a)) + { + return _canLValueCoerceScalarType(vecA->getScalarType(), static_cast<VectorExpressionType*>(b)->getScalarType()); + } + } + return _canLValueCoerceScalarType(a, b); + } + Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr *expr) { auto rs = ResolveInvoke(expr); @@ -1985,23 +2030,63 @@ namespace Slang if( pp < expr->arguments.getCount() ) { auto argExpr = expr->arguments[pp]; - if( !argExpr->type.isLeftValue ) + if( !argExpr->type.isLeftValue) { - getSink()->diagnose( - argExpr, - Diagnostics::argumentExpectedLValue, - pp); + auto implicitCastExpr = as<ImplicitCastExpr>(argExpr); - if( auto implicitCastExpr = as<ImplicitCastExpr>(argExpr) ) + if (implicitCastExpr && _canLValueCoerce(implicitCastExpr->arguments[0]->type, implicitCastExpr->type)) + { + // This is to work around issues like + // + // ``` + // int a = 0; + // uint b = 1; + // a += b; + // ``` + // That strictly speaking it's not allowed, but we are going to allow it for now + // for situations were the types are uint/int and vector/matrix varieties of those types + // + // Then in lowering we are going to insert code to do something like + // ``` + // var OutType: tmp = arg; + // f(... tmp); + // arg = tmp; + // ``` + + TypeCastExpr* lValueImplicitCast; + + // We want to record if the cast is being used for `out` or `inout`/`ref` as + // if it's just `out` we won't need to convert before passing in. + if (as<OutType>(paramType)) + { + lValueImplicitCast = getASTBuilder()->create<OutImplicitCastExpr>(*implicitCastExpr); + } + else + { + lValueImplicitCast = getASTBuilder()->create<InOutImplicitCastExpr>(*implicitCastExpr); + } + + // Replace the expression. This should make this situation easier to detect. + expr->arguments[pp] = lValueImplicitCast; + } + else { getSink()->diagnose( argExpr, - Diagnostics::implicitCastUsedAsLValue, - implicitCastExpr->arguments[0]->type, - implicitCastExpr->type); + Diagnostics::argumentExpectedLValue, + pp); + + if(implicitCastExpr) + { + getSink()->diagnose( + argExpr, + Diagnostics::implicitCastUsedAsLValue, + implicitCastExpr->arguments[pp]->type, + implicitCastExpr->type); + } + + maybeDiagnoseThisNotLValue(argExpr); } - - maybeDiagnoseThisNotLValue(argExpr); } } else diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 4e396e1f5..7082eccd7 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1633,6 +1633,12 @@ namespace Slang /// Are we generating code for a CUDA API (CUDA / OptiX)? bool isCUDATarget(TargetRequest* targetReq); + /// Given a target request returns which (if any) intermediate source language is required + /// to produce it. + /// + /// If no intermediate source language is required, will return SourceLanguage::Unknown + SourceLanguage getIntermediateSourceLanguageForTarget(TargetRequest* req); + /// Are resource types "bindless" (implemented as ordinary data) on the given `target`? bool areResourceTypesBindlessOnTarget(TargetRequest* target); diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 49edec92b..ef9b7f54e 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1200,6 +1200,20 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut { return false; } + + case kIROp_InOutImplicitCast: + case kIROp_OutImplicitCast: + { + // We'll just the LValue to be the desired type + m_writer->emit("reinterpret_cast<"); + emitType(inst->getDataType()); + m_writer->emit(">("); + + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + + m_writer->emit(")"); + return true; + } case kIROp_MakeVector: { IRType* retType = inst->getFullType(); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index bd6542e2a..827c69e50 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -35,6 +35,7 @@ #include "slang-ir-lower-result-type.h" #include "slang-ir-lower-optional-type.h" #include "slang-ir-lower-bit-cast.h" +#include "slang-ir-lower-l-value-cast.h" #include "slang-ir-lower-reinterpret.h" #include "slang-ir-loop-unroll.h" #include "slang-ir-metadata.h" @@ -866,6 +867,9 @@ Result linkAndOptimizeIR( legalizeUniformBufferLoad(irModule); } + // Lower all the LValue implict casts (used for out/inout/ref scenarios) + lowerLValueCast(targetRequest, irModule); + // Lower all bit_cast operations on complex types into leaf-level // bit_cast on basic types. lowerBitCast(targetRequest, irModule); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index b33565700..de3735a55 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -923,6 +923,8 @@ INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) INST(BitCast, bitCast, 1, 0) INST(Reinterpret, reinterpret, 1, 0) +INST(OutImplicitCast, outImplicitCast, 1, 0) +INST(InOutImplicitCast, inOutImplicitCast, 1, 0) INST(IntCast, intCast, 1, 0) INST(FloatCast, floatCast, 1, 0) INST(CastIntToFloat, castIntToFloat, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4c540bfdd..f2c00f406 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3302,6 +3302,8 @@ public: IRUndefined* emitUndefined(IRType* type); IRInst* emitReinterpret(IRInst* type, IRInst* value); + IRInst* emitOutImplicitCast(IRInst* type, IRInst* value); + IRInst* emitInOutImplicitCast(IRInst* type, IRInst* value); IRFunc* createFunc(); IRGlobalVar* createGlobalVar( diff --git a/source/slang/slang-ir-lower-l-value-cast.cpp b/source/slang/slang-ir-lower-l-value-cast.cpp new file mode 100644 index 000000000..cd03d2bd5 --- /dev/null +++ b/source/slang/slang-ir-lower-l-value-cast.cpp @@ -0,0 +1,247 @@ +#include "slang-ir-lower-l-value-cast.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +#include "slang-ir-clone.h" + +#include "slang-ir-util.h" + +namespace Slang +{ + +struct LValueCastLoweringContext +{ + void _addToWorkList(IRInst* inst) + { + if (!findOuterGeneric(inst) && !m_workList.contains(inst)) + { + m_workList.add(inst); + } + } + + void _processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_InOutImplicitCast: + case kIROp_OutImplicitCast: + _processLValueCast(inst); + break; + default: + break; + } + } + + void processModule() + { + _addToWorkList(m_module->getModuleInst()); + + while (m_workList.getCount() != 0) + { + IRInst* inst = m_workList.getLast(); + m_workList.removeLast(); + + _processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + _addToWorkList(child); + } + } + } + + /// True if the conversion from a to b, can be achieved + /// via a reinterpret cast/bitcast + /// Only some targets can allow such conversions + bool _canReinterpretCast(IRType* a, IRType* b) + { + auto ptrA = as<IRPtrType>(a); + auto ptrB = as<IRPtrType>(b); + + // They must both be pointers... + SLANG_ASSERT(ptrA && ptrB); + + a = ptrA->getValueType(); + b = ptrB->getValueType(); + + if (a->m_op == b->m_op) + { + if (auto matA = as<IRMatrixType>(a)) + { + auto matB = static_cast<IRMatrixType*>(b); + + if (getIntVal(matA->getColumnCount()) != getIntVal(matB->getColumnCount())) + { + return false; + } + + a = matA->getElementType(); + b = matB->getElementType(); + } + else if (auto vecA = as<IRVectorType>(a)) + { + auto vecB = static_cast<IRVectorType*>(b); + + if (getIntVal(vecA->getElementCount()) != getIntVal(vecB->getElementCount())) + { + return false; + } + + a = vecA->getElementType(); + b = vecB->getElementType(); + } + } + + auto basicA = as<IRBasicType>(a); + auto basicB = as<IRBasicType>(b); + + if (basicA && basicB) + { + auto baseA = basicA->getBaseType(); + auto baseB = basicB->getBaseType(); + + const auto& infoA = BaseTypeInfo::getInfo(baseA); + const auto& infoB = BaseTypeInfo::getInfo(baseB); + + // We allow reinterpret case for int type conversions of the same bit size for now + if (infoA.sizeInBytes == infoB.sizeInBytes && + (infoA.flags & infoB.flags & BaseTypeInfo::Flag::Integer)) + { + return true; + } + } + + return false; + } + + /// True if for HLSL the cast can be removed entirely + bool _canRemoveCastForHLSL(IRType* a, IRType* b) + { + // Currently _canReinterpret is exactly the same class of types that we can just ignore the cast totally + // for HLSL + // If _canReinterpretCast changes, this will need to be updated + return _canReinterpretCast(a, b); + } + + void _processLValueCast(IRInst* castInst) + { + auto castOperand = castInst->getOperand(0); + auto fromType = castOperand->getDataType(); + auto toType = castInst->getDataType(); + + switch (m_intermediateSourceLanguage) + { + case SourceLanguage::HLSL: + { + // If the conversion can just be ignored for HLSL, just remove it + if (_canRemoveCastForHLSL(fromType, toType)) + { + castInst->replaceUsesWith(castOperand); + castInst->removeAndDeallocate(); + return; + } + break; + } + case SourceLanguage::C: + case SourceLanguage::CPP: + case SourceLanguage::CUDA: + { + // For languages with pointers, out parameter differences can *sometimes* just be sidestepped with + // a reinterpret cast. + if (_canReinterpretCast(fromType, toType)) + { + return; + } + break; + } + default: break; + } + + // If we can't use the other mechanisms we are going to do a conversion + // via a cast into a temporary of the approprite time before the useSite, + // then immediately after converting back into the original location. + // + // With a special case for uses which are just out - where we don't need to + // convert in. + + // Okay we are going to replace the implicit casts with temporaries around call sites/uses. + List<IRInst*> useSites; + for (auto use = castInst->firstUse; use; use = use->nextUse) + { + auto useSite = use->getUser(); + + if (useSites.indexOf(useSite) < 0) + { + useSites.add(useSite); + } + } + + // If there is a name hint on the source, we'll copy it over to the temporaries + auto nameHintDecoration = castOperand->findDecoration<IRNameHintDecoration>(); + + IRBuilder builder(m_module); + + IRType* toValueType = as<IRPtrType>(toType)->getValueType(); + IRType* fromValueType = as<IRPtrType>(fromType)->getValueType(); + + for (auto useSite : useSites) + { + builder.setInsertBefore(useSite); + auto tmpVar = builder.emitVar(toValueType); + + if (nameHintDecoration) + { + cloneDecoration(nameHintDecoration, tmpVar); + } + + // If it's inout we convert via cast whats in the castOperand + if (castInst->getOp() == kIROp_InOutImplicitCast) + { + builder.emitStore(tmpVar, builder.emitCast(toValueType, builder.emitLoad(castOperand))); + } + + // Convert the temporary back to the original location + builder.setInsertAfter(useSite); + builder.emitStore(castOperand, builder.emitCast(fromValueType, builder.emitLoad(tmpVar))); + + // Go through all of the operands of the use inst relacing, with the temporary + const auto operandCount = Count(useSite->getOperandCount()); + auto operands = useSite->getOperands(); + + for (Index i = 0; i < operandCount; ++i) + { + auto& callSiteOperand = operands[i]; + if(callSiteOperand.get() == castInst) + { + callSiteOperand.set(tmpVar); + } + } + } + + // When we are done we can destroy the inst + castInst->removeAndDeallocate(); + } + + LValueCastLoweringContext(TargetRequest* targetRequest, IRModule* module): + m_targetReq(targetRequest), + m_module(module) + { + m_intermediateSourceLanguage = getIntermediateSourceLanguageForTarget(targetRequest); + } + + // The intermediate source language used to produce code for the target. + // If no intermediate source language is used will be SourceLanguage::Unknown. + SourceLanguage m_intermediateSourceLanguage = SourceLanguage::Unknown; + TargetRequest* m_targetReq; + IRModule* m_module; + OrderedHashSet<IRInst*> m_workList; +}; + +void lowerLValueCast(TargetRequest* targetReq, IRModule* module) +{ + LValueCastLoweringContext context(targetReq, module); + context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-lower-l-value-cast.h b/source/slang/slang-ir-lower-l-value-cast.h new file mode 100644 index 000000000..3d93d1200 --- /dev/null +++ b/source/slang/slang-ir-lower-l-value-cast.h @@ -0,0 +1,24 @@ +#ifndef SLANG_IR_LOWER_L_VALUE_CAST_H +#define SLANG_IR_LOWER_L_VALUE_CAST_H + +// This defines an IR pass that lowers LValue implicit casts. These are typically formed +// when an in/inout paramter is passed a type that doesn't match. +// +// Depending on the target this could produce +// +// * Nothing - some kinds of casts are implicit for some targets such as HLSL on out parameters for same sized integer types +// * A reinterpret cast. On targets with pointers, such as C++/CUDA we can fix the problem by just casting to the appropriate pointer (for some kinds of conversions) +// * Creating a temporary of the right type and calling the function, and *converting* to the target (say an out parameter) +// * Creating a temporary, converting the value into the temporary, calling the function, and converting back to the source + +namespace Slang +{ + +struct IRModule; +class TargetRequest; + +void lowerLValueCast(TargetRequest* targetReq, IRModule* module); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index fb121d245..789349b4c 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3130,7 +3130,14 @@ namespace Slang { return emitIntrinsicInst((IRType*)type, kIROp_Reinterpret, 1, &value); } - + IRInst* IRBuilder::emitInOutImplicitCast(IRInst* type, IRInst* value) + { + return emitIntrinsicInst((IRType*)type, kIROp_InOutImplicitCast, 1, &value); + } + IRInst* IRBuilder::emitOutImplicitCast(IRInst* type, IRInst* value) + { + return emitIntrinsicInst((IRType*)type, kIROp_OutImplicitCast, 1, &value); + } IRLiveRangeStart* IRBuilder::emitLiveRangeStart(IRInst* referenced) { // This instruction doesn't produce any result, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 58768f2ad..85b17dafb 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4677,6 +4677,34 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis { static bool _isLValueContext() { return true; } + LoweredValInfo visitLValueImplicitCastExpr(LValueImplicitCastExpr* expr) + { + auto builder = getBuilder(); + + auto irType = lowerType(context, expr->type); + auto irPtrType = builder->getPtrType(irType); + + auto loweredArg = lowerLValueExpr(context, expr->arguments[0]); + + // It should be a ptr, because it is a LValue + SLANG_ASSERT(loweredArg.flavor == LoweredValInfo::Flavor::Ptr); + + // We have the irValue (which should be a Ptr because it's an LValue) + auto irLValue = loweredArg.val; + + IRInst* irCast = nullptr; + if (as<OutImplicitCastExpr>(expr)) + { + irCast = builder->emitOutImplicitCast(irPtrType, irLValue); + } + else + { + irCast = builder->emitInOutImplicitCast(irPtrType, irLValue); + } + + return LoweredValInfo::ptr(irCast); + } + // When visiting a swizzle expression in an l-value context, // we need to construct a "swizzled l-value." LoweredValInfo visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 29cf86f5e..4d52b76eb 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1689,6 +1689,62 @@ bool isCUDATarget(TargetRequest* targetReq) } } +SourceLanguage getIntermediateSourceLanguageForTarget(TargetRequest* req) +{ + // If we are emitting directly, there is no intermediate source language + if (req->shouldEmitSPIRVDirectly()) + { + return SourceLanguage::Unknown; + } + + switch (req->getTarget()) + { + case CodeGenTarget::GLSL: + case CodeGenTarget::GLSL_Vulkan: + case CodeGenTarget::GLSL_Vulkan_OneDesc: + // If we aren't emitting directly we are going to output GLSL to feed to GLSLANG + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + { + return SourceLanguage::GLSL; + } + case CodeGenTarget::HLSL: + case CodeGenTarget::DXBytecode: + case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::DXIL: + case CodeGenTarget::DXILAssembly: + { + // Currently DXBytecode and DXIL are generated via HLSL + return SourceLanguage::HLSL; + } + case CodeGenTarget::CSource: + { + return SourceLanguage::C; + } + case CodeGenTarget::ShaderSharedLibrary: + case CodeGenTarget::ObjectCode: + case CodeGenTarget::HostExecutable: + case CodeGenTarget::HostHostCallable: + case CodeGenTarget::ShaderHostCallable: + case CodeGenTarget::CPPSource: + case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: + { + // For CPU based scenarios are generated via C++ + return SourceLanguage::CPP; + } + case CodeGenTarget::CUDAObjectCode: + case CodeGenTarget::CUDASource: + case CodeGenTarget::PTX: + { + return SourceLanguage::CUDA; + } + default: break; + } + + return SourceLanguage::Unknown; +} + bool areResourceTypesBindlessOnTarget(TargetRequest* targetReq) { return isCPUTarget(targetReq) || isCUDATarget(targetReq); |
