summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-07 18:36:35 -0800
committerGitHub <noreply@github.com>2023-02-07 18:36:35 -0800
commit4be623c52a6518eb86756a0369706c1d6670f6bb (patch)
treec24f54e34db9f1f02c2d51808b15121eba9195a9 /source
parent101f164b036d0c1c012243df69179559b6f40fb8 (diff)
Arithmetic simplifications and more IR clean up logic. (#2632)
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang8
-rw-r--r--source/slang/slang-ast-expr.h7
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-check-conversion.cpp62
-rw-r--r--source/slang/slang-check-expr.cpp1
-rw-r--r--source/slang/slang-check-impl.h1
-rw-r--r--source/slang/slang-emit-c-like.cpp37
-rw-r--r--source/slang/slang-emit-source-writer.cpp17
-rw-r--r--source/slang/slang-emit.cpp13
-rw-r--r--source/slang/slang-intrinsic-expand.cpp8
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp71
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp38
-rw-r--r--source/slang/slang-ir-autodiff.cpp5
-rw-r--r--source/slang/slang-ir-dce.cpp7
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h9
-rw-r--r--source/slang/slang-ir-peephole.cpp158
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp92
-rw-r--r--source/slang/slang-ir-redundancy-removal.h2
-rw-r--r--source/slang/slang-ir-simplify-for-emit.cpp354
-rw-r--r--source/slang/slang-ir-simplify-for-emit.h9
-rw-r--r--source/slang/slang-ir-util.cpp211
-rw-r--r--source/slang/slang-ir-util.h9
-rw-r--r--source/slang/slang-ir.cpp24
-rw-r--r--source/slang/slang-ir.h9
-rw-r--r--source/slang/slang-lower-to-ir.cpp33
27 files changed, 1089 insertions, 102 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index f5b138f25..6c7a2c1d2 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -4793,7 +4793,7 @@ void __executeCallable(uint shaderIndex, int payloadLocation);
__generic<Payload>
__target_intrinsic(__glslRayTracing, "$XC")
[__readNone]
-int __callablePayloadLocation(Payload payload);
+int __callablePayloadLocation(__ref Payload payload);
// Now we provide a hard-coded definition of `CallShader()` for GLSL-based
// targets, which maps the generic HLSL operation into the non-generic
@@ -4848,7 +4848,7 @@ void __traceRay(
__generic<Payload>
__target_intrinsic(__glslRayTracing, "$XP")
[__readNone]
-int __rayPayloadLocation(Payload payload);
+int __rayPayloadLocation(__ref Payload payload);
__generic<payload_t>
__specialized_for_target(glsl)
@@ -5678,7 +5678,7 @@ struct VkSubpassInputMS<T>
// We access the HitObjectAttributes via this function for the desired type, and it acts *as if* it's just an access
// to the global t.
[ForceInline]
-__ref T __hitObjectAttributes<T>()
+Ref<T> __hitObjectAttributes<T>()
{
[__vulkanHitObjectAttributes]
static T t;
@@ -5691,7 +5691,7 @@ __ref T __hitObjectAttributes<T>()
__generic<Payload>
__target_intrinsic(__glslRayTracing, "$XH")
[__readNone]
-int __hitObjectAttributesLocation(Payload payload);
+int __hitObjectAttributesLocation(__ref Payload payload);
/// Immutable data type representing a ray hit or a miss. Can be used to invoke hit or miss shading,
/// or as a key in ReorderThread. Created by one of several methods described below. HitObject
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 9fdc10807..d49db89a2 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -268,6 +268,13 @@ class SwizzleExpr: public Expr
SourceLoc memberOpLoc;
};
+// An operation to convert an l-value to a reference type.
+class MakeRefExpr : public Expr
+{
+ SLANG_AST_CLASS(MakeRefExpr)
+ Expr* base = nullptr;
+};
+
// A dereference of a pointer or pointer-like type
class DerefExpr: public Expr
{
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 21611bcb1..5fd9df400 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -78,6 +78,7 @@ namespace Slang
// Conversion from a buffer to the type it carries needs to add a minimal
// extra cost, just so we can distinguish an overload on `ConstantBuffer<Foo>`
// from one on `Foo`
+ kConversionCost_GetRef = 5,
kConversionCost_ImplicitDereference = 10,
// Conversions based on explicit sub-typing relationships are the cheapest
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 6decce625..71231b9e5 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -642,11 +642,6 @@ namespace Slang
return true;
}
- if (auto refType = as<RefType>(toType))
- {
- return _coerce(refType->getValueType(), outToExpr, fromType, fromExpr, outCost);
- }
-
// If both are string types we assume they are convertable in both directions
if (as<StringTypeBase>(fromType) && as<StringTypeBase>(toType))
{
@@ -860,6 +855,63 @@ namespace Slang
return true;
}
+ if (auto refType = as<RefType>(toType))
+ {
+ if (!refType->getValueType()->equals(fromType))
+ return false;
+ if (!fromExpr->type.isLeftValue)
+ return false;
+
+ ConversionCost subCost = kConversionCost_GetRef;
+
+ MakeRefExpr* refExpr = nullptr;
+ if (outToExpr)
+ {
+ refExpr = m_astBuilder->create<MakeRefExpr>();
+ refExpr->base = fromExpr;
+ refExpr->type = QualType(refType);
+ refExpr->type.isLeftValue = false;
+ *outToExpr = refExpr;
+ }
+ if (outCost)
+ *outCost = subCost;
+ return true;
+ }
+
+
+ // Allow implicit dereferencing a reference type.
+ if (auto fromRefType = as<RefType>(fromType))
+ {
+ auto fromValueType = fromRefType->getValueType();
+
+ // If we convert, e.g., `ConstantBuffer<A> to `A`, we will allow
+ // subsequent conversion of `A` to `B` if such a conversion
+ // is possible.
+ //
+ ConversionCost subCost = kConversionCost_None;
+
+ Expr* openRefExpr = nullptr;
+ if (outToExpr)
+ {
+ openRefExpr = maybeOpenRef(fromExpr);
+ }
+
+ if (!_coerce(
+ toType,
+ outToExpr,
+ fromValueType,
+ openRefExpr,
+ &subCost))
+ {
+ return false;
+ }
+
+ if (outCost)
+ *outCost = subCost + kConversionCost_ImplicitDereference;
+ return true;
+ }
+
+
// The main general-purpose approach for conversion is
// using suitable marked initializer ("constructor")
// declarations on the target type.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 7e2bc3822..7d546e60b 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1863,6 +1863,7 @@ namespace Slang
Expr* SemanticsVisitor::checkAssignWithCheckedOperands(AssignExpr* expr)
{
+ expr->left = maybeOpenRef(expr->left);
auto type = expr->left->type;
auto right = maybeOpenRef(expr->right);
expr->right = coerce(type, right);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index ac5fc8392..719706635 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1928,6 +1928,7 @@ namespace Slang
}
CASE(DerefExpr)
+ CASE(MakeRefExpr)
CASE(MatrixSwizzleExpr)
CASE(SwizzleExpr)
CASE(OverloadedExpr)
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 160585e26..c664449e5 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1082,6 +1082,7 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
case kIROp_Param:
case kIROp_Func:
case kIROp_Alloca:
+ case kIROp_Store:
return false;
// Never fold these, because their result cannot be computed
@@ -1997,17 +1998,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
}
break;
- case kIROp_Store:
- {
- auto prec = getInfo(EmitOp::Assign);
- needClose = maybeEmitParens(outerPrec, prec);
-
- emitDereferenceOperand(inst->getOperand(0), leftSide(outerPrec, prec));
- m_writer->emit(" = ");
- emitOperand(inst->getOperand(1), rightSide(prec, outerPrec));
- }
- break;
-
case kIROp_Call:
{
emitCallExpr((IRCall*)inst, outerPrec);
@@ -2393,6 +2383,22 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst)
}
break;
+ case kIROp_Store:
+ {
+ if (inst->getPrevInst() == inst->getOperand(0) && inst->getOperand(0)->getOp() == kIROp_Var)
+ {
+ // If we are storing into a var that is defined right before the store, we have
+ // already folded the store in the initialization of the var, so we can skip here.
+ break;
+ }
+ auto prec = getInfo(EmitOp::Assign);
+ emitDereferenceOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), prec));
+ m_writer->emit(" = ");
+ emitOperand(inst->getOperand(1), rightSide(prec, getInfo(EmitOp::General)));
+ m_writer->emit(";\n");
+ }
+ break;
+
case kIROp_Param:
// Don't emit parameters, since they are declared as part of the function.
break;
@@ -3321,6 +3327,15 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl)
emitLayoutSemantics(varDecl);
+ if (auto store = as<IRStore>(varDecl->getNextInst()))
+ {
+ if (store->getPtr() == varDecl)
+ {
+ m_writer->emit(" = ");
+ emitOperand(store->getVal(), getInfo(EmitOp::General));
+ }
+ }
+
m_writer->emit(";\n");
}
diff --git a/source/slang/slang-emit-source-writer.cpp b/source/slang/slang-emit-source-writer.cpp
index bed9a2dbc..7692bd0ec 100644
--- a/source/slang/slang-emit-source-writer.cpp
+++ b/source/slang/slang-emit-source-writer.cpp
@@ -247,8 +247,21 @@ void SourceWriter::emit(double value)
stream.setf(std::ios::fixed, std::ios::floatfield);
stream.precision(20);
stream << value;
-
- emit(stream.str().c_str());
+ auto str = stream.str();
+ auto slice = UnownedStringSlice(str.c_str());
+ // Remove redundant trailing 0s.
+ if (slice.end() > slice.begin())
+ {
+ auto lastChar = slice.end() - 1;
+ while (lastChar > slice.begin() && *lastChar == '0')
+ lastChar--;
+ if (*lastChar == '.')
+ lastChar++;
+ if (lastChar > slice.end() - 1)
+ lastChar = slice.end() - 1;
+ slice = slice.subString(0, lastChar - slice.begin() + 1);
+ }
+ emit(slice);
}
void SourceWriter::advanceToSourceLocationIfValid(const SourceLoc& sourceLocation)
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 3d923179c..2ef0a5647 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -54,6 +54,7 @@
#include "slang-ir-liveness.h"
#include "slang-ir-glsl-liveness.h"
#include "slang-ir-string-hash.h"
+#include "slang-ir-simplify-for-emit.h"
#include "slang-legalize-types.h"
#include "slang-lower-to-ir.h"
#include "slang-mangle.h"
@@ -1008,6 +1009,9 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
linkedIR));
auto irModule = linkedIR.module;
+
+ // Perform final simplifications to help emit logic to generate more compact code.
+ simplifyForEmit(irModule);
metadata = linkedIR.metadata;
@@ -1015,15 +1019,6 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
// passes have been performed, we can emit target code from
// the IR module.
//
- // TODO: do we want to emit directly from IR, or translate the
- // IR back into AST for emission?
-#if 0
- {
- StringBuilder sb;
- StringWriter writer(&sb, Slang::WriterFlag::AutoFlush);
- dumpIR(irModule, getIRDumpOptions(), sourceManager, &writer);
- }
-#endif
sourceEmitter->emitModule(irModule, sink);
}
diff --git a/source/slang/slang-intrinsic-expand.cpp b/source/slang/slang-intrinsic-expand.cpp
index 64ef4e761..de3396efb 100644
--- a/source/slang/slang-intrinsic-expand.cpp
+++ b/source/slang/slang-intrinsic-expand.cpp
@@ -735,14 +735,10 @@ const char* IntrinsicExpandContext::_emitSpecial(const char* cursor)
Index argIndex = 0;
SLANG_RELEASE_ASSERT(m_argCount > argIndex);
auto arg = m_args[argIndex].get();
- auto argLoad = as<IRLoad>(arg);
- SLANG_RELEASE_ASSERT(argLoad);
-
- auto argVar = argLoad->getOperand(0);
// Find the associated decoration
IRDecoration* foundDecoration = nullptr;
- for (auto decoration : argVar->getDecorations())
+ for (auto decoration : arg->getDecorations())
{
const auto curKind = LocationTracker::getKindFromDecoration(decoration);
if (curKind == kind)
@@ -755,7 +751,7 @@ const char* IntrinsicExpandContext::_emitSpecial(const char* cursor)
// Must have found the decoration
SLANG_ASSERT(foundDecoration);
- const auto location = m_emitter->getLocationTracker().getValue(kind, argVar, foundDecoration);
+ const auto location = m_emitter->getLocationTracker().getValue(kind, arg, foundDecoration);
m_writer->emit(location);
}
}
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 16b7a977c..a45a3abf9 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1162,6 +1162,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
inBuilder->addForwardDerivativeDecoration(origFunc, diffFunc);
}
+ inBuilder->addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);
+
FuncBodyTranscriptionTask task;
task.type = FuncBodyTranscriptionTaskType::Forward;
task.originalFunc = origFunc;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index fed53b037..702f9819a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -12,6 +12,7 @@
#include "slang-ir-addr-inst-elimination.h"
#include "slang-ir-eliminate-multilevel-break.h"
#include "slang-ir-init-local-var.h"
+#include "slang-ir-redundancy-removal.h"
namespace Slang
{
@@ -305,7 +306,14 @@ namespace Slang
IRFunc* primalFunc = origFunc;
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
- differentiableTypeConformanceContext.setFunc(origFunc);
+
+ // The original func may not have a type dictionary if it is not originally marked as
+ // differentiable, in this case we would have already pulled the necessary types from
+ // the user-provided derivative function, so we are still fine.
+ if (origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ {
+ differentiableTypeConformanceContext.setFunc(origFunc);
+ }
auto diffFunc = builder.createFunc();
@@ -334,7 +342,7 @@ namespace Slang
builder.setInsertBefore(diffFunc->getFirstDecorationOrChild());
cloneInst(&cloneEnv, &builder, dictDecor);
}
-
+ builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);
return InstPair(primalFunc, diffFunc);
}
@@ -588,43 +596,6 @@ namespace Slang
return result;
}
- void eliminateRedundantLoad(IRFunc* func)
- {
- for (auto block : func->getBlocks())
- {
- for (auto inst = block->getFirstInst(); inst;)
- {
- auto nextInst = inst->getNextInst();
- if (auto load = as<IRLoad>(inst))
- {
- for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst())
- {
- if (auto store = as<IRStore>(prev))
- {
- if (store->getPtr() == load->getPtr())
- {
- // If the load is preceeded by a store without any side-effect insts in-between, remove the load.
- auto value = store->getVal();
- load->replaceUsesWith(value);
- load->removeAndDeallocate();
- break;
- }
- }
- else if (as<IRCall>(prev))
- {
- break;
- }
- else if (prev->mightHaveSideEffects())
- {
- break;
- }
- }
- }
- inst = nextInst;
- }
- }
- }
-
// Create a copy of originalFunc's forward derivative in the same generic context (if any) of
// `diffPropagateFunc`.
IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc(
@@ -669,7 +640,7 @@ namespace Slang
primalOuterParent->removeAndDeallocate();
// Remove redundant loads since they interfere with transposition logic.
- eliminateRedundantLoad(fwdDiffFunc);
+ eliminateRedundantLoadStore(fwdDiffFunc);
// Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`.
if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc)))
@@ -976,13 +947,16 @@ namespace Slang
{
// Create dOut param.
auto diffParam = builder->emitParam(diffType);
+ copyNameHintDecoration(diffParam, fwdParam);
result.propagateFuncParams.Add(diffParam);
primalRefReplacement = builder->emitParam(builder->getOutType(primalType));
+ copyNameHintDecoration(primalRefReplacement, fwdParam);
// Create a local var for read access in pre-transpose code.
// This will the var from which we will fetch the final resulting derivative
// after transposition.
auto tempVar = nextBlockBuilder.emitVar(diffType);
+ copyNameHintDecoration(tempVar, fwdParam);
nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType);
// Initialize the var with input diff param at start.
@@ -999,11 +973,13 @@ namespace Slang
else
{
primalRefReplacement = builder->emitParam(outType);
+ copyNameHintDecoration(primalRefReplacement, fwdParam);
}
result.primalFuncParams.Add(primalRefReplacement);
// Create a local var for the out param for the primal part of the prop func.
auto tempPrimalVar = nextBlockBuilder.emitVar(outType->getValueType());
+ copyNameHintDecoration(tempPrimalVar, fwdParam);
result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar;
instsToRemove.Add(fwdParam);
@@ -1023,10 +999,13 @@ namespace Slang
// Create an in param for the prop func.
auto propParam = builder->emitParam(inoutType->getValueType());
+ copyNameHintDecoration(propParam, fwdParam);
result.propagateFuncParams.Add(propParam);
// Create a local var for the out param for the primal part of the prop func.
auto tempPrimalVar = nextBlockBuilder.emitVar(inoutType->getValueType());
+ copyNameHintDecoration(tempPrimalVar, fwdParam);
+
result.propagateFuncSpecificPrimalInsts.add(tempPrimalVar);
auto storeInst = nextBlockBuilder.emitStore(tempPrimalVar, propParam);
result.propagateFuncSpecificPrimalInsts.add(storeInst);
@@ -1054,8 +1033,11 @@ namespace Slang
// Create inout version.
auto inoutDiffPairType = builder->getInOutType(diffPairType);
primalRefReplacement = builder->emitParam(primalType);
+ copyNameHintDecoration(primalRefReplacement, fwdParam);
+
result.primalFuncParams.Add(primalRefReplacement);
auto propParam = builder->emitParam(inoutDiffPairType);
+ copyNameHintDecoration(propParam, fwdParam);
result.propagateFuncParams.Add(propParam);
// A reference to this parameter from the diff blocks should be replaced with a load
@@ -1085,9 +1067,11 @@ namespace Slang
// Process differentiable inout parameters.
auto primalParam = builder->emitParam(builder->getInOutType(primalType));
+ copyNameHintDecoration(primalParam, fwdParam);
result.primalFuncParams.Add(primalParam);
auto diffParam = builder->emitParam(inoutType);
+ copyNameHintDecoration(diffParam, fwdParam);
result.propagateFuncParams.Add(diffParam);
// Primal references to this param is the new primal param.
@@ -1102,6 +1086,7 @@ namespace Slang
// Create a local var for diff read access.
auto diffVar = nextBlockBuilder.emitVar(diffType);
+ copyNameHintDecoration(diffVar, fwdParam);
result.propagateFuncSpecificPrimalInsts.add(diffVar);
diffBuilder.markInstAsDifferential(diffVar, diffPairType);
diffRefReplacement = diffVar;
@@ -1114,6 +1099,8 @@ namespace Slang
// Create a local var for diff write access.
auto diffWriteVar = nextBlockBuilder.emitVar(diffType);
+ copyNameHintDecoration(diffWriteVar, fwdParam);
+
// Initialize write var to 0.
auto writeStore = nextBlockBuilder.emitStore(diffWriteVar, initDiff);
result.propagateFuncSpecificPrimalInsts.add(writeStore);
@@ -1122,6 +1109,8 @@ namespace Slang
// Create a local var for the primal logic in the propagate func.
auto primalVar = nextBlockBuilder.emitVar(primalType);
+ copyNameHintDecoration(primalVar, fwdParam);
+
result.propagateFuncSpecificPrimalInsts.add(primalVar);
auto initPrimalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(loadedParam);
result.propagateFuncSpecificPrimalInsts.add(initPrimalVal);
@@ -1213,11 +1202,13 @@ namespace Slang
SLANG_ASSERT(dOutParamType);
dOutParam = builder->emitParam(dOutParamType);
+ builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut"));
result.propagateFuncParams.Add(dOutParam);
}
// Add a parameter for intermediate val.
auto ctxParam = builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
+ builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx"));
result.primalFuncParams.Add(ctxParam);
result.propagateFuncParams.Add(ctxParam);
result.dOutParam = dOutParam;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index be20d8aa8..3e7e346d2 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -295,6 +295,7 @@ struct ExtractPrimalFuncContext
auto oldIntermediateParam = func->getLastParam();
auto outIntermediary =
builder.emitParam(builder.getInOutType((IRType*)intermediateType));
+ oldIntermediateParam->transferDecorationsTo(outIntermediary);
primalParams.Add(outIntermediary);
oldIntermediateParam->replaceUsesWith(outIntermediary);
oldIntermediateParam->removeAndDeallocate();
@@ -473,15 +474,34 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
if (inst->getOp() == kIROp_Var)
{
// This is a var for intermediate context.
- auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType();
- auto val = builder.emitFieldExtract(
- valType,
- intermediateVar,
- structKeyDecor->getStructKey());
- auto tempVar =
- builder.emitVar(valType);
- builder.emitStore(tempVar, val);
- inst->replaceUsesWith(tempVar);
+ // Replace all loads of the var with a field extract.
+ // Other type of uses will get a temp var that stores a copy of the field.
+ while (auto use = inst->firstUse)
+ {
+ if (as<IRDecoration>(use->getUser()))
+ {
+ use->set(builder.getVoidValue());
+ continue;
+ }
+ builder.setInsertBefore(use->getUser());
+ auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType();
+ auto val = builder.emitFieldExtract(
+ valType,
+ intermediateVar,
+ structKeyDecor->getStructKey());
+ if (use->getUser()->getOp() == kIROp_Load)
+ {
+ use->getUser()->replaceUsesWith(val);
+ use->getUser()->removeAndDeallocate();
+ }
+ else
+ {
+ auto tempVar =
+ builder.emitVar(valType);
+ builder.emitStore(tempVar, val);
+ use->set(tempVar);
+ }
+ }
}
else
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index fdaff4960..f38bdfdbd 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -256,6 +256,11 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I
builder.setInsertBefore(diffType);
auto pairStructType = builder.createStructType();
+ StringBuilder nameBuilder;
+ nameBuilder << "DiffPair_";
+ getTypeNameHint(nameBuilder, origBaseType);
+ builder.addNameHintDecoration(pairStructType, nameBuilder.ToString().getUnownedSlice());
+
builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType);
builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType);
return pairStructType;
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index 33e5b3cb4..337caa246 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -3,6 +3,7 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -16,6 +17,7 @@ struct DeadCodeEliminationContext
// `eliminateDeadCode` function.
//
IRModule* module;
+
IRDeadCodeEliminationOptions options;
// If we removed an inst, there may be still "weak references" to the inst.
@@ -129,6 +131,9 @@ struct DeadCodeEliminationContext
auto inst = workList.getLast();
workList.removeLast();
+ if (!isChildInstOf(inst, root))
+ continue;
+
// At this point we know that `inst` is live,
// and we want to start considering which other
// instructions must be live because of that
@@ -426,7 +431,6 @@ bool eliminateDeadCode(
DeadCodeEliminationContext context;
context.module = module;
context.options = options;
-
return context.processModule();
}
@@ -437,7 +441,6 @@ bool eliminateDeadCode(
DeadCodeEliminationContext context;
context.module = root->getModule();
context.options = options;
-
return context.processInst(root);
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 1cb839751..26a92a17a 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -797,6 +797,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/* Differentiable Type Dictionary */
INST(DifferentiableTypeDictionaryDecoration, DifferentiableTypeDictionaryDecoration, 0, PARENT)
+ /// Overrides the floating mode for the target function
+ INST(FloatingPointModeOverrideDecoration, FloatingPointModeOverride, 1, 0)
+
/// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index d1374477f..fad20e900 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -846,6 +846,13 @@ struct IRDifferentiableTypeDictionaryDecoration : IRDecoration
IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration)
};
+struct IRFloatingModeOverrideDecoration : IRDecoration
+{
+ IR_LEAF_ISA(FloatingPointModeOverrideDecoration)
+
+ FloatingPointMode getFloatingPointMode() { return (FloatingPointMode)cast<IRIntLit>(getOperand(0))->getValue(); }
+};
+
// An instruction that specializes another IR value
// (representing a generic) to a particular set of generic arguments
// (instructions representing types, witness tables, etc.)
@@ -2835,6 +2842,8 @@ public:
// Add a differentiable type entry to the appropriate dictionary.
IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness);
+ IRInst* addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode);
+
IRInst* emitSpecializeInst(
IRType* type,
IRInst* genericVal,
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index fd0b4577a..65b4adcac 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -10,6 +10,7 @@ struct PeepholeContext : InstPassBase
{}
bool changed = false;
+ FloatingPointMode floatingPointMode = FloatingPointMode::Precise;
bool tryFoldElementExtractFromUpdateInst(IRInst* inst)
{
@@ -96,8 +97,158 @@ struct PeepholeContext : InstPassBase
return false;
}
+ bool isZero(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_IntLit:
+ return as<IRIntLit>(inst)->getValue() == 0;
+ case kIROp_FloatLit:
+ return as<IRFloatLit>(inst)->getValue() == 0.0;
+ case kIROp_MakeVector:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MatrixReshape:
+ case kIROp_VectorReshape:
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (!isZero(inst->getOperand(i)))
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+ case kIROp_CastIntToFloat:
+ case kIROp_CastFloatToInt:
+ return isZero(inst->getOperand(0));
+ default:
+ return false;
+ }
+ }
+
+ bool isOne(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_IntLit:
+ return as<IRIntLit>(inst)->getValue() == 1;
+ case kIROp_FloatLit:
+ return as<IRFloatLit>(inst)->getValue() == 1.0;
+ case kIROp_MakeVector:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MatrixReshape:
+ case kIROp_VectorReshape:
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (!isOne(inst->getOperand(i)))
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+ case kIROp_CastIntToFloat:
+ case kIROp_CastFloatToInt:
+ return isOne(inst->getOperand(0));
+ default:
+ return false;
+ }
+ }
+
+ bool tryOptimizeArithmeticInst(IRInst* inst)
+ {
+ bool allowUnsafeOptimizations =
+ (floatingPointMode == FloatingPointMode::Fast ||
+ isIntegralScalarOrCompositeType(inst->getDataType()));
+
+ auto tryReplace = [&](IRInst* replacement) -> bool
+ {
+ if (replacement->getFullType() != inst->getFullType())
+ {
+ // If the operand type is different from result type,
+ // we try to convert for some known cases.
+ if (auto vectorType = as<IRVectorType>(inst->getFullType()))
+ {
+ if (vectorType->getElementType() != replacement->getFullType())
+ return false;
+ IRBuilder builder(sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ replacement = builder.emitMakeVectorFromScalar(inst->getFullType(), replacement);
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ inst->replaceUsesWith(replacement);
+ inst->removeAndDeallocate();
+ return true;
+ };
+
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ if (isZero(inst->getOperand(0)))
+ {
+ return tryReplace(inst->getOperand(1));
+ }
+ else if (isZero(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ break;
+ case kIROp_Sub:
+ if (isZero(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ break;
+ case kIROp_Mul:
+ if (isOne(inst->getOperand(0)))
+ {
+ return tryReplace(inst->getOperand(1));
+ }
+ else if (isOne(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ else if (allowUnsafeOptimizations && isZero(inst->getOperand(0)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ else if (allowUnsafeOptimizations && isZero(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(1));
+ }
+ break;
+ case kIROp_Div:
+ if (allowUnsafeOptimizations && isZero(inst->getOperand(0)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ else if (isOne(inst->getOperand(1)))
+ {
+ return tryReplace(inst->getOperand(0));
+ }
+ }
+ return false;
+ }
+
void processInst(IRInst* inst)
{
+ if (as<IRGlobalValueWithCode>(inst))
+ {
+ if (auto fpModeDecor = inst->findDecoration<IRFloatingModeOverrideDecoration>())
+ floatingPointMode = fpModeDecor->getFloatingPointMode();
+ }
+
switch (inst->getOp())
{
case kIROp_GetResultError:
@@ -432,6 +583,12 @@ struct PeepholeContext : InstPassBase
}
}
break;
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ changed = tryOptimizeArithmeticInst(inst);
+ break;
default:
break;
}
@@ -443,6 +600,7 @@ struct PeepholeContext : InstPassBase
sharedBuilder->init(module);
sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
bool result = false;
+
for (;;)
{
changed = false;
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index 9bd681115..176142601 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -109,6 +109,7 @@ bool removeRedundancy(IRModule* module)
if (auto func = as<IRFunc>(inst))
{
changed |= removeRedundancyInFunc(func);
+ changed |= eliminateRedundantLoadStore(func);
}
}
return changed;
@@ -126,4 +127,95 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
return context.removeRedundancyInBlock(deduplicateCtx, root);
}
+bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func)
+{
+ bool changed = false;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstInst(); inst;)
+ {
+ auto nextInst = inst->getNextInst();
+ if (auto load = as<IRLoad>(inst))
+ {
+ for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst())
+ {
+ if (auto store = as<IRStore>(prev))
+ {
+ if (store->getPtr() == load->getPtr())
+ {
+ // If the load is preceeded by a store without any side-effect insts in-between, remove the load.
+ auto value = store->getVal();
+ load->replaceUsesWith(value);
+ load->removeAndDeallocate();
+ changed = true;
+ break;
+ }
+ }
+
+ if (canInstHaveSideEffectAtAddress(func, prev, load->getPtr()))
+ {
+ break;
+ }
+ }
+ }
+ else if (auto store = as<IRStore>(inst))
+ {
+ // We perform a quick and conservative check:
+ // A store is redundant if it is followed by another store to the same address in
+ // the same basic block, and there are no instructions that may use any addresses
+ // related to this address.
+ bool hasAddrUse = false;
+ bool hasOverridingStore = false;
+
+ // Stores to global variables will never get removed.
+ if (!isChildInstOf(store->getPtr(), func))
+ hasAddrUse = true;
+
+ for (auto next = store->getNextInst(); next; next = next->getNextInst())
+ {
+ if (auto nextStore = as<IRStore>(next))
+ {
+ if (nextStore->getPtr() == store->getPtr())
+ {
+ hasOverridingStore = true;
+ break;
+ }
+ }
+
+ // If we see any insts that have reads or modifies the address before seeing
+ // an overriding store, don't remove the store.
+ // We can make the test more accurate by collecting all addresses related to
+ // the target address first, and only bail out if any of the related addresses
+ // are involved.
+ switch (next->getOp())
+ {
+ case kIROp_Load:
+ if (canAddressesPotentiallyAlias(func, next->getOperand(0), store->getPtr()))
+ {
+ hasAddrUse = true;
+ }
+ break;
+ default:
+ if (canInstHaveSideEffectAtAddress(func, next, store->getPtr()))
+ {
+ hasAddrUse = true;
+ }
+ break;
+ }
+ if (hasAddrUse)
+ break;
+ }
+
+ if (!hasAddrUse && hasOverridingStore)
+ {
+ store->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ inst = nextInst;
+ }
+ }
+ return changed;
+}
+
}
diff --git a/source/slang/slang-ir-redundancy-removal.h b/source/slang/slang-ir-redundancy-removal.h
index 26b265e77..c2df7853e 100644
--- a/source/slang/slang-ir-redundancy-removal.h
+++ b/source/slang/slang-ir-redundancy-removal.h
@@ -8,4 +8,6 @@ namespace Slang
bool removeRedundancy(IRModule* module);
bool removeRedundancyInFunc(IRGlobalValueWithCode* func);
+
+ bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func);
}
diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp
new file mode 100644
index 000000000..5e5f61a4a
--- /dev/null
+++ b/source/slang/slang-ir-simplify-for-emit.cpp
@@ -0,0 +1,354 @@
+#include "slang-ir-simplify-for-emit.h"
+#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+
+struct SimplifyForEmitContext : public InstPassBase
+{
+ SimplifyForEmitContext(IRModule* inModule)
+ : InstPassBase(inModule)
+ {}
+
+ List<IRInst*> followUpWorkList;
+ HashSet<IRInst*> followUpWorkListSet;
+
+ void addToFollowUpWorkList(IRInst* inst)
+ {
+ if (followUpWorkListSet.Add(inst))
+ followUpWorkList.add(inst);
+ }
+
+ void processMakeStruct(IRInst* makeStruct)
+ {
+ auto structType = as<IRStructType>(makeStruct->getDataType());
+ if (!structType)
+ return;
+ for (auto use = makeStruct->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+ auto user = use->getUser();
+ if (auto store = as<IRStore>(user))
+ {
+ IRBuilder builder(sharedBuilderStorage);
+ builder.setInsertBefore(user);
+ UInt i = 0;
+ for (auto field : structType->getFields())
+ {
+ auto fieldAddr = builder.emitFieldAddress(
+ builder.getPtrType(field->getFieldType()),
+ store->getPtr(),
+ field->getKey());
+ builder.emitStore(fieldAddr, makeStruct->getOperand(i));
+ addToFollowUpWorkList(makeStruct->getOperand(i));
+ i++;
+ }
+ store->removeAndDeallocate();
+ }
+ use = nextUse;
+ }
+ if (!makeStruct->hasUses())
+ makeStruct->removeAndDeallocate();
+ }
+
+ void processMakeArray(IRInst* makeArray)
+ {
+ auto arrayType = as<IRArrayType>(makeArray->getDataType());
+ if (!arrayType)
+ return;
+
+ for (auto use = makeArray->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+ auto user = use->getUser();
+ if (auto store = as<IRStore>(user))
+ {
+ IRBuilder builder(sharedBuilderStorage);
+ builder.setInsertBefore(user);
+ for (UInt i = 0; i < makeArray->getOperandCount(); i++)
+ {
+ auto elementAddr = builder.emitElementAddress(
+ builder.getPtrType(arrayType->getElementType()),
+ store->getPtr(),
+ builder.getIntValue(builder.getIntType(), (IRIntegerValue)i));
+ builder.emitStore(elementAddr, makeArray->getOperand(i));
+ addToFollowUpWorkList(makeArray->getOperand(i));
+ }
+ store->removeAndDeallocate();
+ }
+ use = nextUse;
+ }
+ if (!makeArray->hasUses())
+ makeArray->removeAndDeallocate();
+ }
+
+ void processMakeArrayFromElement(IRInst* makeArray)
+ {
+ auto arrayType = as<IRArrayType>(makeArray->getDataType());
+ if (!arrayType)
+ return;
+ auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+ if (!arraySize)
+ return;
+
+ for (auto use = makeArray->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+ auto user = use->getUser();
+ if (auto store = as<IRStore>(user))
+ {
+ IRBuilder builder(sharedBuilderStorage);
+ builder.setInsertBefore(user);
+ for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
+ {
+ auto elementAddr = builder.emitElementAddress(
+ builder.getPtrType(arrayType->getElementType()),
+ store->getPtr(),
+ builder.getIntValue(builder.getIntType(), i));
+ builder.emitStore(elementAddr, makeArray->getOperand(0));
+ addToFollowUpWorkList(makeArray->getOperand(0));
+ }
+ store->removeAndDeallocate();
+ }
+ use = nextUse;
+ }
+ if (!makeArray->hasUses())
+ makeArray->removeAndDeallocate();
+ }
+
+ void processLoadUse(IRGlobalValueWithCode* func, IRLoad* load, IRUse* use)
+ {
+ auto user = use->getUser();
+ if (user->getParent() != load->getParent())
+ return;
+ for (auto inst = load->getNextInst(); inst; inst = inst->getNextInst())
+ {
+ if (inst == user)
+ break;
+ if (canInstHaveSideEffectAtAddress(func, inst, load->getPtr()))
+ return;
+ }
+
+ // If we reach here, it is OK to defer the load at use site.
+ IRBuilder builder(sharedBuilderStorage);
+ builder.setInsertBefore(user);
+ auto newLoad = builder.emitLoad(load->getPtr());
+ use->set(newLoad);
+ }
+
+ void processLoad(IRLoad* inst)
+ {
+ auto func = getParentFunc(inst);
+ if (!func)
+ return;
+
+ for (auto use = inst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+ processLoadUse(func, inst, use);
+ use = nextUse;
+ }
+
+ if (!inst->hasUses())
+ inst->removeAndDeallocate();
+ }
+
+ void processElementExtract(IRInst* inst)
+ {
+ // Create a duplicate for each use site.
+ // This is safe because the result value of this inst should never
+ // change regardless of where the inst is defined.
+ // By creating the duplicates right before use sites, we will enable
+ // the emit logic to always fold these insts.
+ for (auto use = inst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+
+ auto user = use->getUser();
+ if (user->getPrevInst() == inst)
+ {
+ use = nextUse;
+ continue;
+ }
+
+ IRBuilder builder(sharedBuilderStorage);
+ builder.setInsertBefore(user);
+ List<IRInst*> args;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ args.add(inst->getOperand(i));
+ auto newInst = builder.emitIntrinsicInst(inst->getFullType(), inst->getOp(), inst->getOperandCount(), args.getBuffer());
+ use->set(newInst);
+
+ use = nextUse;
+ }
+ if (!inst->hasUses())
+ inst->removeAndDeallocate();
+ }
+
+ void processVar(IRInst* var)
+ {
+ // Defer var to its first use, if the use is in the same basic block as the var.
+ HashSet<IRInst*> userInSameBlock;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ if (use->getUser()->getParent() == var->getParent())
+ {
+ userInSameBlock.Add(use->getUser());
+ }
+ IRInst* firstUser = nullptr;
+ for (auto inst = var->getNextInst(); inst; inst = inst->getNextInst())
+ {
+ if (userInSameBlock.Contains(inst))
+ {
+ firstUser = inst;
+ break;
+ }
+ }
+ if (!firstUser)
+ return;
+ var->insertBefore(firstUser);
+ }
+
+ void processInst(IRInst* inst)
+ {
+ // We inspect each inst and see if the following simplifications
+ // can be applied:
+ // 1. If we see `store(addr, MakeArray/Struct)`, we should turn them
+ // into direct stores into each element/field and remove the need
+ // to create a temporary for the `MakeArray/Struct` inst.
+ // 2. If we see `load(addr)`, we duplicate the load right at each
+ // use site if it can be determined safe to do so. This allows
+ // emit logic to skip producing a temp var for the loaded result.
+ switch (inst->getOp())
+ {
+ case kIROp_MakeStruct:
+ processMakeStruct(inst);
+ break;
+ case kIROp_MakeArray:
+ processMakeArray(inst);
+ break;
+ case kIROp_MakeArrayFromElement:
+ processMakeArrayFromElement(inst);
+ break;
+ case kIROp_Load:
+ processLoad(as<IRLoad>(inst));
+ break;
+ case kIROp_GetElement:
+ case kIROp_FieldExtract:
+ processElementExtract(inst);
+ break;
+ case kIROp_Var:
+ processVar(inst);
+ break;
+ }
+ }
+
+ void eliminateCompositeConstruct(IRGlobalValueWithCode* func)
+ {
+ followUpWorkList.clear();
+ followUpWorkListSet.Clear();
+
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
+ addToFollowUpWorkList(inst);
+ break;
+ }
+ }
+ }
+ for (Index i = 0; i < followUpWorkList.getCount(); i++)
+ processInst(followUpWorkList[i]);
+ }
+
+ void deferAndDuplicateLoad(IRGlobalValueWithCode* func)
+ {
+ followUpWorkList.clear();
+ followUpWorkListSet.Clear();
+
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Load:
+ addToFollowUpWorkList(inst);
+ break;
+ }
+ }
+ }
+ for (Index i = 0; i < followUpWorkList.getCount(); i++)
+ processInst(followUpWorkList[i]);
+ }
+
+ void deferVarDecl(IRGlobalValueWithCode* func)
+ {
+ followUpWorkList.clear();
+ followUpWorkListSet.Clear();
+
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Var:
+ addToFollowUpWorkList(inst);
+ break;
+ }
+ }
+ }
+ for (Index i = 0; i < followUpWorkList.getCount(); i++)
+ processInst(followUpWorkList[i]);
+ }
+
+ void deferAndDuplicateElementExtract(IRGlobalValueWithCode* func)
+ {
+ followUpWorkList.clear();
+ followUpWorkListSet.Clear();
+
+ for (auto block = func->getLastBlock(); block; block = block->getPrevBlock())
+ {
+ for (auto inst = block->getLastChild(); inst; inst = inst->getPrevInst())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_GetElement:
+ case kIROp_FieldExtract:
+ addToFollowUpWorkList(inst);
+ break;
+ }
+ }
+ }
+ for (Index i = 0; i < followUpWorkList.getCount(); i++)
+ processInst(followUpWorkList[i]);
+ }
+
+ void processFunc(IRGlobalValueWithCode* func)
+ {
+ eliminateCompositeConstruct(func);
+ deferAndDuplicateElementExtract(func);
+ deferAndDuplicateLoad(func);
+ deferVarDecl(func);
+ }
+
+ void processModule()
+ {
+ sharedBuilderStorage.init(module);
+ processInstsOfType<IRFunc>(kIROp_Func, [this](IRFunc* f) { processFunc(f); });
+ }
+};
+
+void simplifyForEmit(IRModule* module)
+{
+ SimplifyForEmitContext context(module);
+ context.processModule();
+}
+
+}
diff --git a/source/slang/slang-ir-simplify-for-emit.h b/source/slang/slang-ir-simplify-for-emit.h
new file mode 100644
index 000000000..a6cf3bad8
--- /dev/null
+++ b/source/slang/slang-ir-simplify-for-emit.h
@@ -0,0 +1,9 @@
+// slang-ir-simplify-for-emit.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ void simplifyForEmit(IRModule* inModule);
+}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 5cf074484..af6fd8ac4 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -229,6 +229,217 @@ String dumpIRToString(IRInst* root)
return sb.ToString();
}
+void copyNameHintDecoration(IRInst* dest, IRInst* src)
+{
+ auto decor = src->findDecoration<IRNameHintDecoration>();
+ if (decor)
+ {
+ cloneDecoration(decor, dest);
+ }
+}
+
+void getTypeNameHint(StringBuilder& sb, IRInst* type)
+{
+ if (!type)
+ return;
+
+ switch (type->getOp())
+ {
+ case kIROp_FloatType:
+ sb << "float";
+ break;
+ case kIROp_HalfType:
+ sb << "half";
+ break;
+ case kIROp_DoubleType:
+ sb << "double";
+ break;
+ case kIROp_IntType:
+ sb << "int";
+ break;
+ case kIROp_Int8Type:
+ sb << "int8";
+ break;
+ case kIROp_Int16Type:
+ sb << "int16";
+ break;
+ case kIROp_Int64Type:
+ sb << "int64";
+ break;
+ case kIROp_IntPtrType:
+ sb << "intptr";
+ break;
+ case kIROp_UIntType:
+ sb << "uint";
+ break;
+ case kIROp_UInt8Type:
+ sb << "uint8";
+ break;
+ case kIROp_UInt16Type:
+ sb << "uint16";
+ break;
+ case kIROp_UInt64Type:
+ sb << "uint64";
+ break;
+ case kIROp_UIntPtrType:
+ sb << "uintptr";
+ break;
+ case kIROp_CharType:
+ sb << "char";
+ break;
+ case kIROp_StringType:
+ sb << "string";
+ break;
+ case kIROp_ArrayType:
+ sb << "array_";
+ getTypeNameHint(sb, type->getOperand(0));
+ break;
+ case kIROp_VectorType:
+ getTypeNameHint(sb, type->getOperand(0));
+ getTypeNameHint(sb, as<IRVectorType>(type)->getElementCount());
+ break;
+ case kIROp_MatrixType:
+ getTypeNameHint(sb, type->getOperand(0));
+ getTypeNameHint(sb, as<IRMatrixType>(type)->getRowCount());
+ sb << "x";
+ getTypeNameHint(sb, as<IRMatrixType>(type)->getColumnCount());
+ break;
+ case kIROp_IntLit:
+ sb << as<IRIntLit>(type)->getValue();
+ break;
+ default:
+ if (auto decor = type->findDecoration<IRNameHintDecoration>())
+ sb << decor->getName();
+ break;
+ }
+}
+
+static IRInst* _getRootAddr(IRInst* addr)
+{
+ for (;;)
+ {
+ switch (addr->getOp())
+ {
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ addr = addr->getOperand(0);
+ continue;
+ default:
+ break;
+ }
+ break;
+ }
+ return addr;
+}
+
+// A simple and conservative address aliasing check.
+bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IRInst* addr2)
+{
+ if (addr1 == addr2)
+ return true;
+
+ // Two variables can never alias.
+ addr1 = _getRootAddr(addr1);
+ addr2 = _getRootAddr(addr2);
+
+ // Global addresses can alias with anything.
+ if (!isChildInstOf(addr1, func))
+ return true;
+
+ if (!isChildInstOf(addr2, func))
+ return true;
+
+ if (addr1->getOp() == kIROp_Var && addr2->getOp() == kIROp_Var
+ && addr1 != addr2)
+ return false;
+
+ // A param and a var can never alias.
+ if (addr1->getOp() == kIROp_Param && addr1->getParent() == func->getFirstBlock() &&
+ addr2->getOp() == kIROp_Var ||
+ addr1->getOp() == kIROp_Var && addr2->getOp() == kIROp_Param &&
+ addr2->getParent() == func->getFirstBlock())
+ return false;
+ return true;
+}
+
+bool isPtrLikeOrHandleType(IRInst* type)
+{
+ switch (type->getOp())
+ {
+ case kIROp_ComPtrType:
+ case kIROp_RawPointerType:
+ case kIROp_RTTIPointerType:
+ case kIROp_PseudoPtrType:
+ case kIROp_OutType:
+ case kIROp_InOutType:
+ case kIROp_PtrType:
+ case kIROp_RefType:
+ return true;
+ }
+ return false;
+}
+
+bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, IRInst* addr)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_Store:
+ // If the target of the store inst may overlap addr, return true.
+ if (canAddressesPotentiallyAlias(func, as<IRStore>(inst)->getPtr(), addr))
+ return true;
+ break;
+ case kIROp_Call:
+ {
+ auto call = as<IRCall>(inst);
+
+ // If addr is a global variable, calling a function may change its value.
+ // So we need to return true here to be conservative.
+ if (!isChildInstOf(_getRootAddr(addr), func))
+ {
+ auto callee = call->getCallee();
+ if (callee &&
+ callee->findDecoration<IRReadNoneDecoration>() &&
+ callee->findDecoration<IRNoSideEffectDecoration>())
+ {
+ // An exception is if the callee is side-effect free and is not reading from
+ // memory.
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+ // If any pointer typed argument of the call inst may overlap addr, return true.
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ if (isPtrLikeOrHandleType(call->getArg(i)->getDataType()))
+ {
+ if (canAddressesPotentiallyAlias(func, call->getArg(i), addr))
+ return true;
+ }
+ }
+ }
+ break;
+ case kIROp_CastPtrToInt:
+ case kIROp_Reinterpret:
+ case kIROp_BitCast:
+ {
+ // If we are trying to cast an address to something else, return true.
+ if (isPtrLikeOrHandleType(inst->getOperand(0)->getDataType()) &&
+ canAddressesPotentiallyAlias(func, inst->getOperand(0), addr))
+ return true;
+ }
+ break;
+ default:
+ // Default behavior is that any insts that have side effect may affect `addr`.
+ if (inst->mightHaveSideEffects())
+ return true;
+ break;
+ }
+ return false;
+}
+
bool isPureFunctionalCall(IRCall* call)
{
auto callee = getResolvedInstForDecorations(call->getCallee());
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 26ad4bc68..c067bde44 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -153,11 +153,20 @@ inline IRInst* unwrapAttributedType(IRInst* type)
return type;
}
+void getTypeNameHint(StringBuilder& sb, IRInst* type);
+void copyNameHintDecoration(IRInst* dest, IRInst* src);
+
+bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IRInst* addr2);
+
String dumpIRToString(IRInst* root);
// Returns whether a call insts can be treated as a pure functional inst
// (no writes to memory, no side effects).
bool isPureFunctionalCall(IRCall* callInst);
+
+bool isPtrLikeOrHandleType(IRInst* type);
+
+bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, IRInst* addr);
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 558574bf6..1b16bfe1f 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4604,6 +4604,14 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode)
+ {
+ return addDecoration(
+ dest,
+ kIROp_FloatingPointModeOverrideDecoration,
+ getIntValue(getIntType(), (IRIntegerValue)mode));
+ }
+
IRInst* IRBuilder::emitSwizzle(
IRType* type,
IRInst* base,
@@ -6418,6 +6426,20 @@ namespace Slang
return false;
}
+ bool isIntegralScalarOrCompositeType(IRType* t)
+ {
+ if (!t)
+ return false;
+ switch (t->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ return isIntegralType((IRType*)t->getOperand(0));
+ default:
+ return isIntegralType(t);
+ }
+ }
+
void findAllInstsBreadthFirst(IRInst* inst, List<IRInst*>& outInsts)
{
Index index = outInsts.getCount();
@@ -6577,6 +6599,8 @@ namespace Slang
void IRInst::insertBefore(IRInst* other)
{
SLANG_ASSERT(other);
+ if (other->getPrevInst() == this)
+ return;
_insertAt(other->getPrevInst(), other, other->getParent());
}
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index e22e41f0c..41b140972 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -869,6 +869,8 @@ bool isTypeEqual(IRType* a, IRType* b);
// True if this is an integral IRBasicType, not including Char or Ptr types
bool isIntegralType(IRType* t);
+bool isIntegralScalarOrCompositeType(IRType* t);
+
void findAllInstsBreadthFirst(IRInst* inst, List<IRInst*>& outInsts);
// Constant Instructions
@@ -943,6 +945,13 @@ struct IRIntLit : IRConstant
IR_LEAF_ISA(IntLit);
};
+struct IRFloatLit : IRConstant
+{
+ IRFloatingPointValue getValue() { return value.floatVal; }
+
+ IR_LEAF_ISA(FloatLit);
+};
+
struct IRBoolLit : IRConstant
{
bool getValue() { return value.intVal != 0; }
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 149f5f6b9..8377246fb 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3304,6 +3304,15 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
}
+ LoweredValInfo visitMakeRefExpr(MakeRefExpr* expr)
+ {
+ auto loweredBase = lowerLValueExpr(context, expr->base);
+
+ SLANG_ASSERT(loweredBase.flavor == LoweredValInfo::Flavor::Ptr);
+ loweredBase.flavor = LoweredValInfo::Flavor::Simple;
+ return loweredBase;
+ }
+
LoweredValInfo visitParenExpr(ParenExpr* expr)
{
return lowerSubExpr(expr->base);
@@ -4234,11 +4243,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
switch (baseVal.flavor)
{
case LoweredValInfo::Flavor::Simple:
- return LoweredValInfo::simple(
- builder->emitElementExtract(
- type,
- getSimpleVal(context, baseVal),
- indexVal));
+ return LoweredValInfo::simple(
+ builder->emitElementExtract(
+ type,
+ getSimpleVal(context, baseVal),
+ indexVal));
case LoweredValInfo::Flavor::Ptr:
return LoweredValInfo::ptr(
@@ -4416,7 +4425,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo visitOpenRefExpr(OpenRefExpr* expr)
{
- return lowerLValueExpr(context, expr->innerExpr);
+ auto info = lowerRValueExpr(context, expr->innerExpr);
+ SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(info.val->getFullType()));
+ SLANG_RELEASE_ASSERT(info.flavor == LoweredValInfo::Flavor::Simple);
+ info.flavor = LoweredValInfo::Flavor::Ptr;
+ return info;
}
};
@@ -4596,10 +4609,6 @@ LoweredValInfo lowerLValueExpr(
LValueExprLoweringVisitor visitor;
visitor.context = context;
auto info = visitor.dispatch(expr);
- if (as<RefType>(expr->type))
- {
- info.flavor = LoweredValInfo::Flavor::Ptr;
- }
return info;
}
@@ -4612,10 +4621,6 @@ LoweredValInfo lowerRValueExpr(
RValueExprLoweringVisitor visitor;
visitor.context = context;
auto info = visitor.dispatch(expr);
- if (as<RefType>(expr->type))
- {
- info.val = context->irBuilder->emitLoad(info.val);
- }
return info;
}