diff options
Diffstat (limited to 'source')
27 files changed, 1089 insertions, 102 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index f5b138f25..6c7a2c1d2 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -4793,7 +4793,7 @@ void __executeCallable(uint shaderIndex, int payloadLocation); __generic<Payload> __target_intrinsic(__glslRayTracing, "$XC") [__readNone] -int __callablePayloadLocation(Payload payload); +int __callablePayloadLocation(__ref Payload payload); // Now we provide a hard-coded definition of `CallShader()` for GLSL-based // targets, which maps the generic HLSL operation into the non-generic @@ -4848,7 +4848,7 @@ void __traceRay( __generic<Payload> __target_intrinsic(__glslRayTracing, "$XP") [__readNone] -int __rayPayloadLocation(Payload payload); +int __rayPayloadLocation(__ref Payload payload); __generic<payload_t> __specialized_for_target(glsl) @@ -5678,7 +5678,7 @@ struct VkSubpassInputMS<T> // We access the HitObjectAttributes via this function for the desired type, and it acts *as if* it's just an access // to the global t. [ForceInline] -__ref T __hitObjectAttributes<T>() +Ref<T> __hitObjectAttributes<T>() { [__vulkanHitObjectAttributes] static T t; @@ -5691,7 +5691,7 @@ __ref T __hitObjectAttributes<T>() __generic<Payload> __target_intrinsic(__glslRayTracing, "$XH") [__readNone] -int __hitObjectAttributesLocation(Payload payload); +int __hitObjectAttributesLocation(__ref Payload payload); /// Immutable data type representing a ray hit or a miss. Can be used to invoke hit or miss shading, /// or as a key in ReorderThread. Created by one of several methods described below. HitObject diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 9fdc10807..d49db89a2 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -268,6 +268,13 @@ class SwizzleExpr: public Expr SourceLoc memberOpLoc; }; +// An operation to convert an l-value to a reference type. +class MakeRefExpr : public Expr +{ + SLANG_AST_CLASS(MakeRefExpr) + Expr* base = nullptr; +}; + // A dereference of a pointer or pointer-like type class DerefExpr: public Expr { diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 21611bcb1..5fd9df400 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -78,6 +78,7 @@ namespace Slang // Conversion from a buffer to the type it carries needs to add a minimal // extra cost, just so we can distinguish an overload on `ConstantBuffer<Foo>` // from one on `Foo` + kConversionCost_GetRef = 5, kConversionCost_ImplicitDereference = 10, // Conversions based on explicit sub-typing relationships are the cheapest diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 6decce625..71231b9e5 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -642,11 +642,6 @@ namespace Slang return true; } - if (auto refType = as<RefType>(toType)) - { - return _coerce(refType->getValueType(), outToExpr, fromType, fromExpr, outCost); - } - // If both are string types we assume they are convertable in both directions if (as<StringTypeBase>(fromType) && as<StringTypeBase>(toType)) { @@ -860,6 +855,63 @@ namespace Slang return true; } + if (auto refType = as<RefType>(toType)) + { + if (!refType->getValueType()->equals(fromType)) + return false; + if (!fromExpr->type.isLeftValue) + return false; + + ConversionCost subCost = kConversionCost_GetRef; + + MakeRefExpr* refExpr = nullptr; + if (outToExpr) + { + refExpr = m_astBuilder->create<MakeRefExpr>(); + refExpr->base = fromExpr; + refExpr->type = QualType(refType); + refExpr->type.isLeftValue = false; + *outToExpr = refExpr; + } + if (outCost) + *outCost = subCost; + return true; + } + + + // Allow implicit dereferencing a reference type. + if (auto fromRefType = as<RefType>(fromType)) + { + auto fromValueType = fromRefType->getValueType(); + + // If we convert, e.g., `ConstantBuffer<A> to `A`, we will allow + // subsequent conversion of `A` to `B` if such a conversion + // is possible. + // + ConversionCost subCost = kConversionCost_None; + + Expr* openRefExpr = nullptr; + if (outToExpr) + { + openRefExpr = maybeOpenRef(fromExpr); + } + + if (!_coerce( + toType, + outToExpr, + fromValueType, + openRefExpr, + &subCost)) + { + return false; + } + + if (outCost) + *outCost = subCost + kConversionCost_ImplicitDereference; + return true; + } + + // The main general-purpose approach for conversion is // using suitable marked initializer ("constructor") // declarations on the target type. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 7e2bc3822..7d546e60b 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1863,6 +1863,7 @@ namespace Slang Expr* SemanticsVisitor::checkAssignWithCheckedOperands(AssignExpr* expr) { + expr->left = maybeOpenRef(expr->left); auto type = expr->left->type; auto right = maybeOpenRef(expr->right); expr->right = coerce(type, right); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ac5fc8392..719706635 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1928,6 +1928,7 @@ namespace Slang } CASE(DerefExpr) + CASE(MakeRefExpr) CASE(MatrixSwizzleExpr) CASE(SwizzleExpr) CASE(OverloadedExpr) diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 160585e26..c664449e5 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1082,6 +1082,7 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) case kIROp_Param: case kIROp_Func: case kIROp_Alloca: + case kIROp_Store: return false; // Never fold these, because their result cannot be computed @@ -1997,17 +1998,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO } break; - case kIROp_Store: - { - auto prec = getInfo(EmitOp::Assign); - needClose = maybeEmitParens(outerPrec, prec); - - emitDereferenceOperand(inst->getOperand(0), leftSide(outerPrec, prec)); - m_writer->emit(" = "); - emitOperand(inst->getOperand(1), rightSide(prec, outerPrec)); - } - break; - case kIROp_Call: { emitCallExpr((IRCall*)inst, outerPrec); @@ -2393,6 +2383,22 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) } break; + case kIROp_Store: + { + if (inst->getPrevInst() == inst->getOperand(0) && inst->getOperand(0)->getOp() == kIROp_Var) + { + // If we are storing into a var that is defined right before the store, we have + // already folded the store in the initialization of the var, so we can skip here. + break; + } + auto prec = getInfo(EmitOp::Assign); + emitDereferenceOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), prec)); + m_writer->emit(" = "); + emitOperand(inst->getOperand(1), rightSide(prec, getInfo(EmitOp::General))); + m_writer->emit(";\n"); + } + break; + case kIROp_Param: // Don't emit parameters, since they are declared as part of the function. break; @@ -3321,6 +3327,15 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl) emitLayoutSemantics(varDecl); + if (auto store = as<IRStore>(varDecl->getNextInst())) + { + if (store->getPtr() == varDecl) + { + m_writer->emit(" = "); + emitOperand(store->getVal(), getInfo(EmitOp::General)); + } + } + m_writer->emit(";\n"); } diff --git a/source/slang/slang-emit-source-writer.cpp b/source/slang/slang-emit-source-writer.cpp index bed9a2dbc..7692bd0ec 100644 --- a/source/slang/slang-emit-source-writer.cpp +++ b/source/slang/slang-emit-source-writer.cpp @@ -247,8 +247,21 @@ void SourceWriter::emit(double value) stream.setf(std::ios::fixed, std::ios::floatfield); stream.precision(20); stream << value; - - emit(stream.str().c_str()); + auto str = stream.str(); + auto slice = UnownedStringSlice(str.c_str()); + // Remove redundant trailing 0s. + if (slice.end() > slice.begin()) + { + auto lastChar = slice.end() - 1; + while (lastChar > slice.begin() && *lastChar == '0') + lastChar--; + if (*lastChar == '.') + lastChar++; + if (lastChar > slice.end() - 1) + lastChar = slice.end() - 1; + slice = slice.subString(0, lastChar - slice.begin() + 1); + } + emit(slice); } void SourceWriter::advanceToSourceLocationIfValid(const SourceLoc& sourceLocation) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 3d923179c..2ef0a5647 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -54,6 +54,7 @@ #include "slang-ir-liveness.h" #include "slang-ir-glsl-liveness.h" #include "slang-ir-string-hash.h" +#include "slang-ir-simplify-for-emit.h" #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" @@ -1008,6 +1009,9 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr linkedIR)); auto irModule = linkedIR.module; + + // Perform final simplifications to help emit logic to generate more compact code. + simplifyForEmit(irModule); metadata = linkedIR.metadata; @@ -1015,15 +1019,6 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr // passes have been performed, we can emit target code from // the IR module. // - // TODO: do we want to emit directly from IR, or translate the - // IR back into AST for emission? -#if 0 - { - StringBuilder sb; - StringWriter writer(&sb, Slang::WriterFlag::AutoFlush); - dumpIR(irModule, getIRDumpOptions(), sourceManager, &writer); - } -#endif sourceEmitter->emitModule(irModule, sink); } diff --git a/source/slang/slang-intrinsic-expand.cpp b/source/slang/slang-intrinsic-expand.cpp index 64ef4e761..de3396efb 100644 --- a/source/slang/slang-intrinsic-expand.cpp +++ b/source/slang/slang-intrinsic-expand.cpp @@ -735,14 +735,10 @@ const char* IntrinsicExpandContext::_emitSpecial(const char* cursor) Index argIndex = 0; SLANG_RELEASE_ASSERT(m_argCount > argIndex); auto arg = m_args[argIndex].get(); - auto argLoad = as<IRLoad>(arg); - SLANG_RELEASE_ASSERT(argLoad); - - auto argVar = argLoad->getOperand(0); // Find the associated decoration IRDecoration* foundDecoration = nullptr; - for (auto decoration : argVar->getDecorations()) + for (auto decoration : arg->getDecorations()) { const auto curKind = LocationTracker::getKindFromDecoration(decoration); if (curKind == kind) @@ -755,7 +751,7 @@ const char* IntrinsicExpandContext::_emitSpecial(const char* cursor) // Must have found the decoration SLANG_ASSERT(foundDecoration); - const auto location = m_emitter->getLocationTracker().getValue(kind, argVar, foundDecoration); + const auto location = m_emitter->getLocationTracker().getValue(kind, arg, foundDecoration); m_writer->emit(location); } } diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 16b7a977c..a45a3abf9 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1162,6 +1162,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu inBuilder->addForwardDerivativeDecoration(origFunc, diffFunc); } + inBuilder->addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast); + FuncBodyTranscriptionTask task; task.type = FuncBodyTranscriptionTaskType::Forward; task.originalFunc = origFunc; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index fed53b037..702f9819a 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -12,6 +12,7 @@ #include "slang-ir-addr-inst-elimination.h" #include "slang-ir-eliminate-multilevel-break.h" #include "slang-ir-init-local-var.h" +#include "slang-ir-redundancy-removal.h" namespace Slang { @@ -305,7 +306,14 @@ namespace Slang IRFunc* primalFunc = origFunc; maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); - differentiableTypeConformanceContext.setFunc(origFunc); + + // The original func may not have a type dictionary if it is not originally marked as + // differentiable, in this case we would have already pulled the necessary types from + // the user-provided derivative function, so we are still fine. + if (origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + { + differentiableTypeConformanceContext.setFunc(origFunc); + } auto diffFunc = builder.createFunc(); @@ -334,7 +342,7 @@ namespace Slang builder.setInsertBefore(diffFunc->getFirstDecorationOrChild()); cloneInst(&cloneEnv, &builder, dictDecor); } - + builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast); return InstPair(primalFunc, diffFunc); } @@ -588,43 +596,6 @@ namespace Slang return result; } - void eliminateRedundantLoad(IRFunc* func) - { - for (auto block : func->getBlocks()) - { - for (auto inst = block->getFirstInst(); inst;) - { - auto nextInst = inst->getNextInst(); - if (auto load = as<IRLoad>(inst)) - { - for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst()) - { - if (auto store = as<IRStore>(prev)) - { - if (store->getPtr() == load->getPtr()) - { - // If the load is preceeded by a store without any side-effect insts in-between, remove the load. - auto value = store->getVal(); - load->replaceUsesWith(value); - load->removeAndDeallocate(); - break; - } - } - else if (as<IRCall>(prev)) - { - break; - } - else if (prev->mightHaveSideEffects()) - { - break; - } - } - } - inst = nextInst; - } - } - } - // Create a copy of originalFunc's forward derivative in the same generic context (if any) of // `diffPropagateFunc`. IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc( @@ -669,7 +640,7 @@ namespace Slang primalOuterParent->removeAndDeallocate(); // Remove redundant loads since they interfere with transposition logic. - eliminateRedundantLoad(fwdDiffFunc); + eliminateRedundantLoadStore(fwdDiffFunc); // Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`. if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc))) @@ -976,13 +947,16 @@ namespace Slang { // Create dOut param. auto diffParam = builder->emitParam(diffType); + copyNameHintDecoration(diffParam, fwdParam); result.propagateFuncParams.Add(diffParam); primalRefReplacement = builder->emitParam(builder->getOutType(primalType)); + copyNameHintDecoration(primalRefReplacement, fwdParam); // Create a local var for read access in pre-transpose code. // This will the var from which we will fetch the final resulting derivative // after transposition. auto tempVar = nextBlockBuilder.emitVar(diffType); + copyNameHintDecoration(tempVar, fwdParam); nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType); // Initialize the var with input diff param at start. @@ -999,11 +973,13 @@ namespace Slang else { primalRefReplacement = builder->emitParam(outType); + copyNameHintDecoration(primalRefReplacement, fwdParam); } result.primalFuncParams.Add(primalRefReplacement); // Create a local var for the out param for the primal part of the prop func. auto tempPrimalVar = nextBlockBuilder.emitVar(outType->getValueType()); + copyNameHintDecoration(tempPrimalVar, fwdParam); result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar; instsToRemove.Add(fwdParam); @@ -1023,10 +999,13 @@ namespace Slang // Create an in param for the prop func. auto propParam = builder->emitParam(inoutType->getValueType()); + copyNameHintDecoration(propParam, fwdParam); result.propagateFuncParams.Add(propParam); // Create a local var for the out param for the primal part of the prop func. auto tempPrimalVar = nextBlockBuilder.emitVar(inoutType->getValueType()); + copyNameHintDecoration(tempPrimalVar, fwdParam); + result.propagateFuncSpecificPrimalInsts.add(tempPrimalVar); auto storeInst = nextBlockBuilder.emitStore(tempPrimalVar, propParam); result.propagateFuncSpecificPrimalInsts.add(storeInst); @@ -1054,8 +1033,11 @@ namespace Slang // Create inout version. auto inoutDiffPairType = builder->getInOutType(diffPairType); primalRefReplacement = builder->emitParam(primalType); + copyNameHintDecoration(primalRefReplacement, fwdParam); + result.primalFuncParams.Add(primalRefReplacement); auto propParam = builder->emitParam(inoutDiffPairType); + copyNameHintDecoration(propParam, fwdParam); result.propagateFuncParams.Add(propParam); // A reference to this parameter from the diff blocks should be replaced with a load @@ -1085,9 +1067,11 @@ namespace Slang // Process differentiable inout parameters. auto primalParam = builder->emitParam(builder->getInOutType(primalType)); + copyNameHintDecoration(primalParam, fwdParam); result.primalFuncParams.Add(primalParam); auto diffParam = builder->emitParam(inoutType); + copyNameHintDecoration(diffParam, fwdParam); result.propagateFuncParams.Add(diffParam); // Primal references to this param is the new primal param. @@ -1102,6 +1086,7 @@ namespace Slang // Create a local var for diff read access. auto diffVar = nextBlockBuilder.emitVar(diffType); + copyNameHintDecoration(diffVar, fwdParam); result.propagateFuncSpecificPrimalInsts.add(diffVar); diffBuilder.markInstAsDifferential(diffVar, diffPairType); diffRefReplacement = diffVar; @@ -1114,6 +1099,8 @@ namespace Slang // Create a local var for diff write access. auto diffWriteVar = nextBlockBuilder.emitVar(diffType); + copyNameHintDecoration(diffWriteVar, fwdParam); + // Initialize write var to 0. auto writeStore = nextBlockBuilder.emitStore(diffWriteVar, initDiff); result.propagateFuncSpecificPrimalInsts.add(writeStore); @@ -1122,6 +1109,8 @@ namespace Slang // Create a local var for the primal logic in the propagate func. auto primalVar = nextBlockBuilder.emitVar(primalType); + copyNameHintDecoration(primalVar, fwdParam); + result.propagateFuncSpecificPrimalInsts.add(primalVar); auto initPrimalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(loadedParam); result.propagateFuncSpecificPrimalInsts.add(initPrimalVal); @@ -1213,11 +1202,13 @@ namespace Slang SLANG_ASSERT(dOutParamType); dOutParam = builder->emitParam(dOutParamType); + builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut")); result.propagateFuncParams.Add(dOutParam); } // Add a parameter for intermediate val. auto ctxParam = builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1)); + builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx")); result.primalFuncParams.Add(ctxParam); result.propagateFuncParams.Add(ctxParam); result.dOutParam = dOutParam; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index be20d8aa8..3e7e346d2 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -295,6 +295,7 @@ struct ExtractPrimalFuncContext auto oldIntermediateParam = func->getLastParam(); auto outIntermediary = builder.emitParam(builder.getInOutType((IRType*)intermediateType)); + oldIntermediateParam->transferDecorationsTo(outIntermediary); primalParams.Add(outIntermediary); oldIntermediateParam->replaceUsesWith(outIntermediary); oldIntermediateParam->removeAndDeallocate(); @@ -473,15 +474,34 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( if (inst->getOp() == kIROp_Var) { // This is a var for intermediate context. - auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType(); - auto val = builder.emitFieldExtract( - valType, - intermediateVar, - structKeyDecor->getStructKey()); - auto tempVar = - builder.emitVar(valType); - builder.emitStore(tempVar, val); - inst->replaceUsesWith(tempVar); + // Replace all loads of the var with a field extract. + // Other type of uses will get a temp var that stores a copy of the field. + while (auto use = inst->firstUse) + { + if (as<IRDecoration>(use->getUser())) + { + use->set(builder.getVoidValue()); + continue; + } + builder.setInsertBefore(use->getUser()); + auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType(); + auto val = builder.emitFieldExtract( + valType, + intermediateVar, + structKeyDecor->getStructKey()); + if (use->getUser()->getOp() == kIROp_Load) + { + use->getUser()->replaceUsesWith(val); + use->getUser()->removeAndDeallocate(); + } + else + { + auto tempVar = + builder.emitVar(valType); + builder.emitStore(tempVar, val); + use->set(tempVar); + } + } } else { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index fdaff4960..f38bdfdbd 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -256,6 +256,11 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I builder.setInsertBefore(diffType); auto pairStructType = builder.createStructType(); + StringBuilder nameBuilder; + nameBuilder << "DiffPair_"; + getTypeNameHint(nameBuilder, origBaseType); + builder.addNameHintDecoration(pairStructType, nameBuilder.ToString().getUnownedSlice()); + builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType); builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType); return pairStructType; diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 33e5b3cb4..337caa246 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -3,6 +3,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" namespace Slang { @@ -16,6 +17,7 @@ struct DeadCodeEliminationContext // `eliminateDeadCode` function. // IRModule* module; + IRDeadCodeEliminationOptions options; // If we removed an inst, there may be still "weak references" to the inst. @@ -129,6 +131,9 @@ struct DeadCodeEliminationContext auto inst = workList.getLast(); workList.removeLast(); + if (!isChildInstOf(inst, root)) + continue; + // At this point we know that `inst` is live, // and we want to start considering which other // instructions must be live because of that @@ -426,7 +431,6 @@ bool eliminateDeadCode( DeadCodeEliminationContext context; context.module = module; context.options = options; - return context.processModule(); } @@ -437,7 +441,6 @@ bool eliminateDeadCode( DeadCodeEliminationContext context; context.module = root->getModule(); context.options = options; - return context.processInst(root); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 1cb839751..26a92a17a 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -797,6 +797,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /* Differentiable Type Dictionary */ INST(DifferentiableTypeDictionaryDecoration, DifferentiableTypeDictionaryDecoration, 0, PARENT) + /// Overrides the floating mode for the target function + INST(FloatingPointModeOverrideDecoration, FloatingPointModeOverride, 1, 0) + /// Marks a struct type as being used as a structured buffer block. /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index d1374477f..fad20e900 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -846,6 +846,13 @@ struct IRDifferentiableTypeDictionaryDecoration : IRDecoration IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration) }; +struct IRFloatingModeOverrideDecoration : IRDecoration +{ + IR_LEAF_ISA(FloatingPointModeOverrideDecoration) + + FloatingPointMode getFloatingPointMode() { return (FloatingPointMode)cast<IRIntLit>(getOperand(0))->getValue(); } +}; + // An instruction that specializes another IR value // (representing a generic) to a particular set of generic arguments // (instructions representing types, witness tables, etc.) @@ -2835,6 +2842,8 @@ public: // Add a differentiable type entry to the appropriate dictionary. IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness); + IRInst* addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode); + IRInst* emitSpecializeInst( IRType* type, IRInst* genericVal, diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index fd0b4577a..65b4adcac 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -10,6 +10,7 @@ struct PeepholeContext : InstPassBase {} bool changed = false; + FloatingPointMode floatingPointMode = FloatingPointMode::Precise; bool tryFoldElementExtractFromUpdateInst(IRInst* inst) { @@ -96,8 +97,158 @@ struct PeepholeContext : InstPassBase return false; } + bool isZero(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_IntLit: + return as<IRIntLit>(inst)->getValue() == 0; + case kIROp_FloatLit: + return as<IRFloatLit>(inst)->getValue() == 0.0; + case kIROp_MakeVector: + case kIROp_MakeVectorFromScalar: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MatrixReshape: + case kIROp_VectorReshape: + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (!isZero(inst->getOperand(i))) + { + return false; + } + } + return true; + } + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + return isZero(inst->getOperand(0)); + default: + return false; + } + } + + bool isOne(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_IntLit: + return as<IRIntLit>(inst)->getValue() == 1; + case kIROp_FloatLit: + return as<IRFloatLit>(inst)->getValue() == 1.0; + case kIROp_MakeVector: + case kIROp_MakeVectorFromScalar: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MatrixReshape: + case kIROp_VectorReshape: + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (!isOne(inst->getOperand(i))) + { + return false; + } + } + return true; + } + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + return isOne(inst->getOperand(0)); + default: + return false; + } + } + + bool tryOptimizeArithmeticInst(IRInst* inst) + { + bool allowUnsafeOptimizations = + (floatingPointMode == FloatingPointMode::Fast || + isIntegralScalarOrCompositeType(inst->getDataType())); + + auto tryReplace = [&](IRInst* replacement) -> bool + { + if (replacement->getFullType() != inst->getFullType()) + { + // If the operand type is different from result type, + // we try to convert for some known cases. + if (auto vectorType = as<IRVectorType>(inst->getFullType())) + { + if (vectorType->getElementType() != replacement->getFullType()) + return false; + IRBuilder builder(sharedBuilderStorage); + builder.setInsertBefore(inst); + replacement = builder.emitMakeVectorFromScalar(inst->getFullType(), replacement); + } + else + { + return false; + } + } + + inst->replaceUsesWith(replacement); + inst->removeAndDeallocate(); + return true; + }; + + switch (inst->getOp()) + { + case kIROp_Add: + if (isZero(inst->getOperand(0))) + { + return tryReplace(inst->getOperand(1)); + } + else if (isZero(inst->getOperand(1))) + { + return tryReplace(inst->getOperand(0)); + } + break; + case kIROp_Sub: + if (isZero(inst->getOperand(1))) + { + return tryReplace(inst->getOperand(0)); + } + break; + case kIROp_Mul: + if (isOne(inst->getOperand(0))) + { + return tryReplace(inst->getOperand(1)); + } + else if (isOne(inst->getOperand(1))) + { + return tryReplace(inst->getOperand(0)); + } + else if (allowUnsafeOptimizations && isZero(inst->getOperand(0))) + { + return tryReplace(inst->getOperand(0)); + } + else if (allowUnsafeOptimizations && isZero(inst->getOperand(1))) + { + return tryReplace(inst->getOperand(1)); + } + break; + case kIROp_Div: + if (allowUnsafeOptimizations && isZero(inst->getOperand(0))) + { + return tryReplace(inst->getOperand(0)); + } + else if (isOne(inst->getOperand(1))) + { + return tryReplace(inst->getOperand(0)); + } + } + return false; + } + void processInst(IRInst* inst) { + if (as<IRGlobalValueWithCode>(inst)) + { + if (auto fpModeDecor = inst->findDecoration<IRFloatingModeOverrideDecoration>()) + floatingPointMode = fpModeDecor->getFloatingPointMode(); + } + switch (inst->getOp()) { case kIROp_GetResultError: @@ -432,6 +583,12 @@ struct PeepholeContext : InstPassBase } } break; + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + changed = tryOptimizeArithmeticInst(inst); + break; default: break; } @@ -443,6 +600,7 @@ struct PeepholeContext : InstPassBase sharedBuilder->init(module); sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); bool result = false; + for (;;) { changed = false; diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index 9bd681115..176142601 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -109,6 +109,7 @@ bool removeRedundancy(IRModule* module) if (auto func = as<IRFunc>(inst)) { changed |= removeRedundancyInFunc(func); + changed |= eliminateRedundantLoadStore(func); } } return changed; @@ -126,4 +127,95 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func) return context.removeRedundancyInBlock(deduplicateCtx, root); } +bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func) +{ + bool changed = false; + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst;) + { + auto nextInst = inst->getNextInst(); + if (auto load = as<IRLoad>(inst)) + { + for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst()) + { + if (auto store = as<IRStore>(prev)) + { + if (store->getPtr() == load->getPtr()) + { + // If the load is preceeded by a store without any side-effect insts in-between, remove the load. + auto value = store->getVal(); + load->replaceUsesWith(value); + load->removeAndDeallocate(); + changed = true; + break; + } + } + + if (canInstHaveSideEffectAtAddress(func, prev, load->getPtr())) + { + break; + } + } + } + else if (auto store = as<IRStore>(inst)) + { + // We perform a quick and conservative check: + // A store is redundant if it is followed by another store to the same address in + // the same basic block, and there are no instructions that may use any addresses + // related to this address. + bool hasAddrUse = false; + bool hasOverridingStore = false; + + // Stores to global variables will never get removed. + if (!isChildInstOf(store->getPtr(), func)) + hasAddrUse = true; + + for (auto next = store->getNextInst(); next; next = next->getNextInst()) + { + if (auto nextStore = as<IRStore>(next)) + { + if (nextStore->getPtr() == store->getPtr()) + { + hasOverridingStore = true; + break; + } + } + + // If we see any insts that have reads or modifies the address before seeing + // an overriding store, don't remove the store. + // We can make the test more accurate by collecting all addresses related to + // the target address first, and only bail out if any of the related addresses + // are involved. + switch (next->getOp()) + { + case kIROp_Load: + if (canAddressesPotentiallyAlias(func, next->getOperand(0), store->getPtr())) + { + hasAddrUse = true; + } + break; + default: + if (canInstHaveSideEffectAtAddress(func, next, store->getPtr())) + { + hasAddrUse = true; + } + break; + } + if (hasAddrUse) + break; + } + + if (!hasAddrUse && hasOverridingStore) + { + store->removeAndDeallocate(); + changed = true; + } + } + inst = nextInst; + } + } + return changed; +} + } diff --git a/source/slang/slang-ir-redundancy-removal.h b/source/slang/slang-ir-redundancy-removal.h index 26b265e77..c2df7853e 100644 --- a/source/slang/slang-ir-redundancy-removal.h +++ b/source/slang/slang-ir-redundancy-removal.h @@ -8,4 +8,6 @@ namespace Slang bool removeRedundancy(IRModule* module); bool removeRedundancyInFunc(IRGlobalValueWithCode* func); + + bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func); } diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp new file mode 100644 index 000000000..5e5f61a4a --- /dev/null +++ b/source/slang/slang-ir-simplify-for-emit.cpp @@ -0,0 +1,354 @@ +#include "slang-ir-simplify-for-emit.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-util.h" + +namespace Slang +{ + +struct SimplifyForEmitContext : public InstPassBase +{ + SimplifyForEmitContext(IRModule* inModule) + : InstPassBase(inModule) + {} + + List<IRInst*> followUpWorkList; + HashSet<IRInst*> followUpWorkListSet; + + void addToFollowUpWorkList(IRInst* inst) + { + if (followUpWorkListSet.Add(inst)) + followUpWorkList.add(inst); + } + + void processMakeStruct(IRInst* makeStruct) + { + auto structType = as<IRStructType>(makeStruct->getDataType()); + if (!structType) + return; + for (auto use = makeStruct->firstUse; use;) + { + auto nextUse = use->nextUse; + auto user = use->getUser(); + if (auto store = as<IRStore>(user)) + { + IRBuilder builder(sharedBuilderStorage); + builder.setInsertBefore(user); + UInt i = 0; + for (auto field : structType->getFields()) + { + auto fieldAddr = builder.emitFieldAddress( + builder.getPtrType(field->getFieldType()), + store->getPtr(), + field->getKey()); + builder.emitStore(fieldAddr, makeStruct->getOperand(i)); + addToFollowUpWorkList(makeStruct->getOperand(i)); + i++; + } + store->removeAndDeallocate(); + } + use = nextUse; + } + if (!makeStruct->hasUses()) + makeStruct->removeAndDeallocate(); + } + + void processMakeArray(IRInst* makeArray) + { + auto arrayType = as<IRArrayType>(makeArray->getDataType()); + if (!arrayType) + return; + + for (auto use = makeArray->firstUse; use;) + { + auto nextUse = use->nextUse; + auto user = use->getUser(); + if (auto store = as<IRStore>(user)) + { + IRBuilder builder(sharedBuilderStorage); + builder.setInsertBefore(user); + for (UInt i = 0; i < makeArray->getOperandCount(); i++) + { + auto elementAddr = builder.emitElementAddress( + builder.getPtrType(arrayType->getElementType()), + store->getPtr(), + builder.getIntValue(builder.getIntType(), (IRIntegerValue)i)); + builder.emitStore(elementAddr, makeArray->getOperand(i)); + addToFollowUpWorkList(makeArray->getOperand(i)); + } + store->removeAndDeallocate(); + } + use = nextUse; + } + if (!makeArray->hasUses()) + makeArray->removeAndDeallocate(); + } + + void processMakeArrayFromElement(IRInst* makeArray) + { + auto arrayType = as<IRArrayType>(makeArray->getDataType()); + if (!arrayType) + return; + auto arraySize = as<IRIntLit>(arrayType->getElementCount()); + if (!arraySize) + return; + + for (auto use = makeArray->firstUse; use;) + { + auto nextUse = use->nextUse; + auto user = use->getUser(); + if (auto store = as<IRStore>(user)) + { + IRBuilder builder(sharedBuilderStorage); + builder.setInsertBefore(user); + for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) + { + auto elementAddr = builder.emitElementAddress( + builder.getPtrType(arrayType->getElementType()), + store->getPtr(), + builder.getIntValue(builder.getIntType(), i)); + builder.emitStore(elementAddr, makeArray->getOperand(0)); + addToFollowUpWorkList(makeArray->getOperand(0)); + } + store->removeAndDeallocate(); + } + use = nextUse; + } + if (!makeArray->hasUses()) + makeArray->removeAndDeallocate(); + } + + void processLoadUse(IRGlobalValueWithCode* func, IRLoad* load, IRUse* use) + { + auto user = use->getUser(); + if (user->getParent() != load->getParent()) + return; + for (auto inst = load->getNextInst(); inst; inst = inst->getNextInst()) + { + if (inst == user) + break; + if (canInstHaveSideEffectAtAddress(func, inst, load->getPtr())) + return; + } + + // If we reach here, it is OK to defer the load at use site. + IRBuilder builder(sharedBuilderStorage); + builder.setInsertBefore(user); + auto newLoad = builder.emitLoad(load->getPtr()); + use->set(newLoad); + } + + void processLoad(IRLoad* inst) + { + auto func = getParentFunc(inst); + if (!func) + return; + + for (auto use = inst->firstUse; use;) + { + auto nextUse = use->nextUse; + processLoadUse(func, inst, use); + use = nextUse; + } + + if (!inst->hasUses()) + inst->removeAndDeallocate(); + } + + void processElementExtract(IRInst* inst) + { + // Create a duplicate for each use site. + // This is safe because the result value of this inst should never + // change regardless of where the inst is defined. + // By creating the duplicates right before use sites, we will enable + // the emit logic to always fold these insts. + for (auto use = inst->firstUse; use;) + { + auto nextUse = use->nextUse; + + auto user = use->getUser(); + if (user->getPrevInst() == inst) + { + use = nextUse; + continue; + } + + IRBuilder builder(sharedBuilderStorage); + builder.setInsertBefore(user); + List<IRInst*> args; + for (UInt i = 0; i < inst->getOperandCount(); i++) + args.add(inst->getOperand(i)); + auto newInst = builder.emitIntrinsicInst(inst->getFullType(), inst->getOp(), inst->getOperandCount(), args.getBuffer()); + use->set(newInst); + + use = nextUse; + } + if (!inst->hasUses()) + inst->removeAndDeallocate(); + } + + void processVar(IRInst* var) + { + // Defer var to its first use, if the use is in the same basic block as the var. + HashSet<IRInst*> userInSameBlock; + for (auto use = var->firstUse; use; use = use->nextUse) + if (use->getUser()->getParent() == var->getParent()) + { + userInSameBlock.Add(use->getUser()); + } + IRInst* firstUser = nullptr; + for (auto inst = var->getNextInst(); inst; inst = inst->getNextInst()) + { + if (userInSameBlock.Contains(inst)) + { + firstUser = inst; + break; + } + } + if (!firstUser) + return; + var->insertBefore(firstUser); + } + + void processInst(IRInst* inst) + { + // We inspect each inst and see if the following simplifications + // can be applied: + // 1. If we see `store(addr, MakeArray/Struct)`, we should turn them + // into direct stores into each element/field and remove the need + // to create a temporary for the `MakeArray/Struct` inst. + // 2. If we see `load(addr)`, we duplicate the load right at each + // use site if it can be determined safe to do so. This allows + // emit logic to skip producing a temp var for the loaded result. + switch (inst->getOp()) + { + case kIROp_MakeStruct: + processMakeStruct(inst); + break; + case kIROp_MakeArray: + processMakeArray(inst); + break; + case kIROp_MakeArrayFromElement: + processMakeArrayFromElement(inst); + break; + case kIROp_Load: + processLoad(as<IRLoad>(inst)); + break; + case kIROp_GetElement: + case kIROp_FieldExtract: + processElementExtract(inst); + break; + case kIROp_Var: + processVar(inst); + break; + } + } + + void eliminateCompositeConstruct(IRGlobalValueWithCode* func) + { + followUpWorkList.clear(); + followUpWorkListSet.Clear(); + + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_MakeStruct: + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + addToFollowUpWorkList(inst); + break; + } + } + } + for (Index i = 0; i < followUpWorkList.getCount(); i++) + processInst(followUpWorkList[i]); + } + + void deferAndDuplicateLoad(IRGlobalValueWithCode* func) + { + followUpWorkList.clear(); + followUpWorkListSet.Clear(); + + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_Load: + addToFollowUpWorkList(inst); + break; + } + } + } + for (Index i = 0; i < followUpWorkList.getCount(); i++) + processInst(followUpWorkList[i]); + } + + void deferVarDecl(IRGlobalValueWithCode* func) + { + followUpWorkList.clear(); + followUpWorkListSet.Clear(); + + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_Var: + addToFollowUpWorkList(inst); + break; + } + } + } + for (Index i = 0; i < followUpWorkList.getCount(); i++) + processInst(followUpWorkList[i]); + } + + void deferAndDuplicateElementExtract(IRGlobalValueWithCode* func) + { + followUpWorkList.clear(); + followUpWorkListSet.Clear(); + + for (auto block = func->getLastBlock(); block; block = block->getPrevBlock()) + { + for (auto inst = block->getLastChild(); inst; inst = inst->getPrevInst()) + { + switch (inst->getOp()) + { + case kIROp_GetElement: + case kIROp_FieldExtract: + addToFollowUpWorkList(inst); + break; + } + } + } + for (Index i = 0; i < followUpWorkList.getCount(); i++) + processInst(followUpWorkList[i]); + } + + void processFunc(IRGlobalValueWithCode* func) + { + eliminateCompositeConstruct(func); + deferAndDuplicateElementExtract(func); + deferAndDuplicateLoad(func); + deferVarDecl(func); + } + + void processModule() + { + sharedBuilderStorage.init(module); + processInstsOfType<IRFunc>(kIROp_Func, [this](IRFunc* f) { processFunc(f); }); + } +}; + +void simplifyForEmit(IRModule* module) +{ + SimplifyForEmitContext context(module); + context.processModule(); +} + +} diff --git a/source/slang/slang-ir-simplify-for-emit.h b/source/slang/slang-ir-simplify-for-emit.h new file mode 100644 index 000000000..a6cf3bad8 --- /dev/null +++ b/source/slang/slang-ir-simplify-for-emit.h @@ -0,0 +1,9 @@ +// slang-ir-simplify-for-emit.h +#pragma once + +namespace Slang +{ + struct IRModule; + + void simplifyForEmit(IRModule* inModule); +} diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 5cf074484..af6fd8ac4 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -229,6 +229,217 @@ String dumpIRToString(IRInst* root) return sb.ToString(); } +void copyNameHintDecoration(IRInst* dest, IRInst* src) +{ + auto decor = src->findDecoration<IRNameHintDecoration>(); + if (decor) + { + cloneDecoration(decor, dest); + } +} + +void getTypeNameHint(StringBuilder& sb, IRInst* type) +{ + if (!type) + return; + + switch (type->getOp()) + { + case kIROp_FloatType: + sb << "float"; + break; + case kIROp_HalfType: + sb << "half"; + break; + case kIROp_DoubleType: + sb << "double"; + break; + case kIROp_IntType: + sb << "int"; + break; + case kIROp_Int8Type: + sb << "int8"; + break; + case kIROp_Int16Type: + sb << "int16"; + break; + case kIROp_Int64Type: + sb << "int64"; + break; + case kIROp_IntPtrType: + sb << "intptr"; + break; + case kIROp_UIntType: + sb << "uint"; + break; + case kIROp_UInt8Type: + sb << "uint8"; + break; + case kIROp_UInt16Type: + sb << "uint16"; + break; + case kIROp_UInt64Type: + sb << "uint64"; + break; + case kIROp_UIntPtrType: + sb << "uintptr"; + break; + case kIROp_CharType: + sb << "char"; + break; + case kIROp_StringType: + sb << "string"; + break; + case kIROp_ArrayType: + sb << "array_"; + getTypeNameHint(sb, type->getOperand(0)); + break; + case kIROp_VectorType: + getTypeNameHint(sb, type->getOperand(0)); + getTypeNameHint(sb, as<IRVectorType>(type)->getElementCount()); + break; + case kIROp_MatrixType: + getTypeNameHint(sb, type->getOperand(0)); + getTypeNameHint(sb, as<IRMatrixType>(type)->getRowCount()); + sb << "x"; + getTypeNameHint(sb, as<IRMatrixType>(type)->getColumnCount()); + break; + case kIROp_IntLit: + sb << as<IRIntLit>(type)->getValue(); + break; + default: + if (auto decor = type->findDecoration<IRNameHintDecoration>()) + sb << decor->getName(); + break; + } +} + +static IRInst* _getRootAddr(IRInst* addr) +{ + for (;;) + { + switch (addr->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + addr = addr->getOperand(0); + continue; + default: + break; + } + break; + } + return addr; +} + +// A simple and conservative address aliasing check. +bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IRInst* addr2) +{ + if (addr1 == addr2) + return true; + + // Two variables can never alias. + addr1 = _getRootAddr(addr1); + addr2 = _getRootAddr(addr2); + + // Global addresses can alias with anything. + if (!isChildInstOf(addr1, func)) + return true; + + if (!isChildInstOf(addr2, func)) + return true; + + if (addr1->getOp() == kIROp_Var && addr2->getOp() == kIROp_Var + && addr1 != addr2) + return false; + + // A param and a var can never alias. + if (addr1->getOp() == kIROp_Param && addr1->getParent() == func->getFirstBlock() && + addr2->getOp() == kIROp_Var || + addr1->getOp() == kIROp_Var && addr2->getOp() == kIROp_Param && + addr2->getParent() == func->getFirstBlock()) + return false; + return true; +} + +bool isPtrLikeOrHandleType(IRInst* type) +{ + switch (type->getOp()) + { + case kIROp_ComPtrType: + case kIROp_RawPointerType: + case kIROp_RTTIPointerType: + case kIROp_PseudoPtrType: + case kIROp_OutType: + case kIROp_InOutType: + case kIROp_PtrType: + case kIROp_RefType: + return true; + } + return false; +} + +bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, IRInst* addr) +{ + switch (inst->getOp()) + { + case kIROp_Store: + // If the target of the store inst may overlap addr, return true. + if (canAddressesPotentiallyAlias(func, as<IRStore>(inst)->getPtr(), addr)) + return true; + break; + case kIROp_Call: + { + auto call = as<IRCall>(inst); + + // If addr is a global variable, calling a function may change its value. + // So we need to return true here to be conservative. + if (!isChildInstOf(_getRootAddr(addr), func)) + { + auto callee = call->getCallee(); + if (callee && + callee->findDecoration<IRReadNoneDecoration>() && + callee->findDecoration<IRNoSideEffectDecoration>()) + { + // An exception is if the callee is side-effect free and is not reading from + // memory. + } + else + { + return true; + } + } + + // If any pointer typed argument of the call inst may overlap addr, return true. + for (UInt i = 0; i < call->getArgCount(); i++) + { + if (isPtrLikeOrHandleType(call->getArg(i)->getDataType())) + { + if (canAddressesPotentiallyAlias(func, call->getArg(i), addr)) + return true; + } + } + } + break; + case kIROp_CastPtrToInt: + case kIROp_Reinterpret: + case kIROp_BitCast: + { + // If we are trying to cast an address to something else, return true. + if (isPtrLikeOrHandleType(inst->getOperand(0)->getDataType()) && + canAddressesPotentiallyAlias(func, inst->getOperand(0), addr)) + return true; + } + break; + default: + // Default behavior is that any insts that have side effect may affect `addr`. + if (inst->mightHaveSideEffects()) + return true; + break; + } + return false; +} + bool isPureFunctionalCall(IRCall* call) { auto callee = getResolvedInstForDecorations(call->getCallee()); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 26ad4bc68..c067bde44 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -153,11 +153,20 @@ inline IRInst* unwrapAttributedType(IRInst* type) return type; } +void getTypeNameHint(StringBuilder& sb, IRInst* type); +void copyNameHintDecoration(IRInst* dest, IRInst* src); + +bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IRInst* addr2); + String dumpIRToString(IRInst* root); // Returns whether a call insts can be treated as a pure functional inst // (no writes to memory, no side effects). bool isPureFunctionalCall(IRCall* callInst); + +bool isPtrLikeOrHandleType(IRInst* type); + +bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, IRInst* addr); } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 558574bf6..1b16bfe1f 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4604,6 +4604,14 @@ namespace Slang return inst; } + IRInst* IRBuilder::addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode) + { + return addDecoration( + dest, + kIROp_FloatingPointModeOverrideDecoration, + getIntValue(getIntType(), (IRIntegerValue)mode)); + } + IRInst* IRBuilder::emitSwizzle( IRType* type, IRInst* base, @@ -6418,6 +6426,20 @@ namespace Slang return false; } + bool isIntegralScalarOrCompositeType(IRType* t) + { + if (!t) + return false; + switch (t->getOp()) + { + case kIROp_VectorType: + case kIROp_MatrixType: + return isIntegralType((IRType*)t->getOperand(0)); + default: + return isIntegralType(t); + } + } + void findAllInstsBreadthFirst(IRInst* inst, List<IRInst*>& outInsts) { Index index = outInsts.getCount(); @@ -6577,6 +6599,8 @@ namespace Slang void IRInst::insertBefore(IRInst* other) { SLANG_ASSERT(other); + if (other->getPrevInst() == this) + return; _insertAt(other->getPrevInst(), other, other->getParent()); } diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index e22e41f0c..41b140972 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -869,6 +869,8 @@ bool isTypeEqual(IRType* a, IRType* b); // True if this is an integral IRBasicType, not including Char or Ptr types bool isIntegralType(IRType* t); +bool isIntegralScalarOrCompositeType(IRType* t); + void findAllInstsBreadthFirst(IRInst* inst, List<IRInst*>& outInsts); // Constant Instructions @@ -943,6 +945,13 @@ struct IRIntLit : IRConstant IR_LEAF_ISA(IntLit); }; +struct IRFloatLit : IRConstant +{ + IRFloatingPointValue getValue() { return value.floatVal; } + + IR_LEAF_ISA(FloatLit); +}; + struct IRBoolLit : IRConstant { bool getValue() { return value.intVal != 0; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 149f5f6b9..8377246fb 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3304,6 +3304,15 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } } + LoweredValInfo visitMakeRefExpr(MakeRefExpr* expr) + { + auto loweredBase = lowerLValueExpr(context, expr->base); + + SLANG_ASSERT(loweredBase.flavor == LoweredValInfo::Flavor::Ptr); + loweredBase.flavor = LoweredValInfo::Flavor::Simple; + return loweredBase; + } + LoweredValInfo visitParenExpr(ParenExpr* expr) { return lowerSubExpr(expr->base); @@ -4234,11 +4243,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> switch (baseVal.flavor) { case LoweredValInfo::Flavor::Simple: - return LoweredValInfo::simple( - builder->emitElementExtract( - type, - getSimpleVal(context, baseVal), - indexVal)); + return LoweredValInfo::simple( + builder->emitElementExtract( + type, + getSimpleVal(context, baseVal), + indexVal)); case LoweredValInfo::Flavor::Ptr: return LoweredValInfo::ptr( @@ -4416,7 +4425,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitOpenRefExpr(OpenRefExpr* expr) { - return lowerLValueExpr(context, expr->innerExpr); + auto info = lowerRValueExpr(context, expr->innerExpr); + SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(info.val->getFullType())); + SLANG_RELEASE_ASSERT(info.flavor == LoweredValInfo::Flavor::Simple); + info.flavor = LoweredValInfo::Flavor::Ptr; + return info; } }; @@ -4596,10 +4609,6 @@ LoweredValInfo lowerLValueExpr( LValueExprLoweringVisitor visitor; visitor.context = context; auto info = visitor.dispatch(expr); - if (as<RefType>(expr->type)) - { - info.flavor = LoweredValInfo::Flavor::Ptr; - } return info; } @@ -4612,10 +4621,6 @@ LoweredValInfo lowerRValueExpr( RValueExprLoweringVisitor visitor; visitor.context = context; auto info = visitor.dispatch(expr); - if (as<RefType>(expr->type)) - { - info.val = context->irBuilder->emitLoad(info.val); - } return info; } |
