summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-09-01 10:01:13 -0700
committerGitHub <noreply@github.com>2022-09-01 10:01:13 -0700
commit4a94473eb34376dd8474f8ca3f2834b5c1daac14 (patch)
tree218714e897a2821c2b09727590f364519afe3915
parent3c0177134d126956336865623ea3d6861be59cfa (diff)
Deduplicate consts and IRSpecialize in IR, propagate type info for `IntVal`. (#2388)
-rw-r--r--source/slang/slang-ast-val.cpp30
-rw-r--r--source/slang/slang-ast-val.h26
-rw-r--r--source/slang/slang-check-conversion.cpp2
-rw-r--r--source/slang/slang-check-decl.cpp5
-rw-r--r--source/slang/slang-check-expr.cpp17
-rw-r--r--source/slang/slang-check-shader.cpp2
-rw-r--r--source/slang/slang-ir-deduplicate.cpp11
-rw-r--r--source/slang/slang-ir-specialize.cpp1
-rw-r--r--source/slang/slang-ir.cpp3
-rw-r--r--source/slang/slang-lower-to-ir.cpp10
-rw-r--r--tests/bugs/generic-type-duplication.slang33
-rw-r--r--tests/bugs/generic-type-duplication.slang.expected.txt4
12 files changed, 103 insertions, 41 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index dd5aff238..64d32f4b4 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -722,10 +722,10 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
*ioDiff += diff;
if (evaluatedTerms.getCount() == 0)
- return astBuilder->create<ConstantIntVal>(evaluatedConstantTerm);
+ return astBuilder->create<ConstantIntVal>(type, evaluatedConstantTerm);
if (diff != 0)
{
- auto newPolynomial = astBuilder->create<PolynomialIntVal>();
+ auto newPolynomial = astBuilder->create<PolynomialIntVal>(type);
newPolynomial->constantTerm = evaluatedConstantTerm;
newPolynomial->terms = _Move(evaluatedTerms);
return newPolynomial->canonicalize(astBuilder);
@@ -770,7 +770,7 @@ bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal*
PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base)
{
- auto result = astBuilder->create<PolynomialIntVal>();
+ auto result = astBuilder->create<PolynomialIntVal>(base->type);
if (!addToPolynomialTerm(astBuilder, result, base, -1))
return nullptr;
result->canonicalize(astBuilder);
@@ -779,7 +779,7 @@ PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base)
PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
{
- auto result = astBuilder->create<PolynomialIntVal>();
+ auto result = astBuilder->create<PolynomialIntVal>(op0->type);
if (!addToPolynomialTerm(astBuilder, result, op0, 1))
return nullptr;
if (!addToPolynomialTerm(astBuilder, result, op1, -1))
@@ -790,7 +790,7 @@ PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, Int
PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
{
- auto result = astBuilder->create<PolynomialIntVal>();
+ auto result = astBuilder->create<PolynomialIntVal>(op0->type);
if (!addToPolynomialTerm(astBuilder, result, op0, 1))
return nullptr;
if (!addToPolynomialTerm(astBuilder, result, op1, 1))
@@ -805,7 +805,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
{
if (auto poly1 = as<PolynomialIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>();
+ auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
// add poly0.constant * poly1.constant
result->constantTerm = poly0->constantTerm * poly1->constantTerm;
// add poly0.constant * poly1.terms
@@ -847,7 +847,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
}
else if (auto cVal1 = as<ConstantIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>();
+ auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
result->constantTerm = poly0->constantTerm * cVal1->value;
auto factor1 = astBuilder->create<PolynomialIntValFactor>();
for (auto term : poly0->terms)
@@ -863,7 +863,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
}
else if (auto val1 = as<IntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>();
+ auto result = astBuilder->create<PolynomialIntVal>(poly0->type);
result->constantTerm = 0;
auto factor1 = astBuilder->create<PolynomialIntValFactor>();
factor1->power = 1;
@@ -901,7 +901,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
}
else if (auto cVal1 = as<ConstantIntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>();
+ auto result = astBuilder->create<PolynomialIntVal>(val0->type);
auto term = astBuilder->create<PolynomialIntValTerm>();
term->constFactor = cVal1->value;
auto factor0 = astBuilder->create<PolynomialIntValFactor>();
@@ -914,7 +914,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
}
else if (auto val1 = as<IntVal>(op1))
{
- auto result = astBuilder->create<PolynomialIntVal>();
+ auto result = astBuilder->create<PolynomialIntVal>(val0->type);
auto term = astBuilder->create<PolynomialIntValTerm>();
term->constFactor = 1;
auto factor0 = astBuilder->create<PolynomialIntValFactor>();
@@ -1035,7 +1035,7 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
return terms[0]->paramFactors[0]->param;
}
if (terms.getCount() == 0)
- return builder->create<ConstantIntVal>(constantTerm);
+ return builder->create<ConstantIntVal>(type, constantTerm);
return this;
}
@@ -1127,7 +1127,7 @@ static bool nameIs(Name* name, const char* val)
return false;
}
-Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
+Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
{
// Are all args const now?
List<ConstantIntVal*> constArgs;
@@ -1207,7 +1207,7 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDe
{
SLANG_UNREACHABLE("constant folding of FuncCallIntVal");
}
- return astBuilder->create<ConstantIntVal>(resultValue);
+ return astBuilder->create<ConstantIntVal>(resultType, resultValue);
}
return nullptr;
}
@@ -1228,12 +1228,12 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
if (diff)
{
// TODO: report diagnostics back.
- auto newVal = tryFoldImpl(astBuilder, newFuncDeclRef, newArgs, nullptr);
+ auto newVal = tryFoldImpl(astBuilder, type, newFuncDeclRef, newArgs, nullptr);
if (newVal)
return newVal;
else
{
- auto result = astBuilder->create<FuncCallIntVal>();
+ auto result = astBuilder->create<FuncCallIntVal>(type);
result->args = _Move(newArgs);
result->funcDeclRef = newFuncDeclRef;
result->funcType = funcType;
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 259be75c3..69797b3a5 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -12,6 +12,13 @@ namespace Slang {
class IntVal : public Val
{
SLANG_ABSTRACT_AST_CLASS(IntVal)
+
+ Type* type;
+
+ IntVal(Type* inType)
+ : type(inType)
+ {}
+
};
// Trivial case of a value that is just a constant integer
@@ -27,8 +34,8 @@ class ConstantIntVal : public IntVal
HashCode _getHashCodeOverride();
protected:
- ConstantIntVal(IntegerLiteralValue inValue)
- : value(inValue)
+ ConstantIntVal(Type* inType, IntegerLiteralValue inValue)
+ : IntVal(inType), value(inValue)
{}
};
@@ -47,8 +54,8 @@ class GenericParamIntVal : public IntVal
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
protected:
- GenericParamIntVal(DeclRef<VarDeclBase> inDeclRef)
- : declRef(inDeclRef)
+ GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef)
+ : IntVal(inType), declRef(inDeclRef)
{}
};
@@ -66,7 +73,9 @@ class FuncCallIntVal : public IntVal
Type* funcType;
List<IntVal*> args;
- static Val* tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink);
+ FuncCallIntVal(Type* inType) : IntVal(inType) {}
+
+ static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink);
};
class WitnessLookupIntVal : public IntVal
@@ -80,7 +89,8 @@ class WitnessLookupIntVal : public IntVal
SubtypeWitness* witness;
Decl* key;
- Type* type;
+
+ WitnessLookupIntVal(Type* inType) : IntVal(inType) {}
static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key);
@@ -140,6 +150,7 @@ public:
{
return power == other.power && param->equalsVal(other.param);
}
+
};
class PolynomialIntValTerm : public NodeBase
{
@@ -207,6 +218,7 @@ public:
static PolynomialIntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
static PolynomialIntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
static PolynomialIntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
+ PolynomialIntVal(Type* inType) : IntVal(inType) {}
};
@@ -215,6 +227,8 @@ class ErrorIntVal : public IntVal
{
SLANG_AST_CLASS(ErrorIntVal)
+ ErrorIntVal(Type* inType) : IntVal(inType) {}
+
// TODO: We should probably eventually just have an `ErrorVal` here
// and have all `Val`s that represent ordinary values hold their
// `Type` so that we can have an `ErrorVal` of any type.
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index a3273f942..4efacd703 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -349,7 +349,7 @@ namespace Slang
// We have a new type for the conversion, based on what
// we learned.
toType = m_astBuilder->getArrayType(toElementType,
- m_astBuilder->create<ConstantIntVal>(elementCount));
+ m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount));
}
}
else if(auto toMatrixType = as<MatrixExpressionType>(toType))
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 9e3fbaa8d..9f19023ee 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -536,7 +536,9 @@ namespace Slang
}
else if( auto genericValueParamDecl = as<GenericValueParamDecl>(mm) )
{
- genericSubst->args.add(astBuilder->create<GenericParamIntVal>(DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst)));
+ genericSubst->args.add(astBuilder->create<GenericParamIntVal>(
+ genericValueParamDecl->getType(),
+ DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst)));
}
}
@@ -4235,6 +4237,7 @@ namespace Slang
else if (auto valueParam = as<GenericValueParamDecl>(dd))
{
auto val = m_astBuilder->create<GenericParamIntVal>(
+ valueParam->getType(),
makeDeclRef(valueParam));
subst->args.add(val);
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index e33d26c0c..c7bfdd3a6 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -866,7 +866,7 @@ namespace Slang
IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr)
{
// TODO(tfoley): don't keep allocating here!
- return m_astBuilder->create<ConstantIntVal>(expr->value);
+ return m_astBuilder->create<ConstantIntVal>(expr->type.type, expr->value);
}
IntVal* SemanticsVisitor::tryConstantFoldExpr(
@@ -982,7 +982,7 @@ namespace Slang
|| opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") ||
opName == getName("?:") || opName == getName("<<") || opName == getName(">>"))
{
- auto result = m_astBuilder->create<FuncCallIntVal>();
+ auto result = m_astBuilder->create<FuncCallIntVal>(invokeExpr.getExpr()->type.type);
result->args.addRange(argVals, argCount);
result->funcDeclRef = funcDeclRef;
result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute(
@@ -1091,7 +1091,7 @@ namespace Slang
}
}
- IntVal* result = m_astBuilder->create<ConstantIntVal>(resultValue);
+ IntVal* result = m_astBuilder->create<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue);
return result;
}
@@ -1176,7 +1176,7 @@ namespace Slang
{
// If it's a boolean, we allow promotion to int.
const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value);
- return m_astBuilder->create<ConstantIntVal>(value);
+ return m_astBuilder->create<ConstantIntVal>(m_astBuilder->getBoolType(), value);
}
// it is possible that we are referring to a generic value param
@@ -1186,8 +1186,9 @@ namespace Slang
if (auto genericValParamRef = declRef.as<GenericValueParamDecl>())
{
- // TODO(tfoley): handle the case of non-`int` value parameters...
- Val* valResult = m_astBuilder->create<GenericParamIntVal>(genericValParamRef);
+ Val* valResult = m_astBuilder->create<GenericParamIntVal>(
+ declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()),
+ genericValParamRef);
valResult = valResult->substitute(m_astBuilder, expr.getSubsts());
return as<IntVal>(valResult);
}
@@ -2145,7 +2146,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->create<ConstantIntVal>(elementCount)));
+ m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
}
// A swizzle can be used as an l-value as long as there
@@ -2266,7 +2267,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->create<ConstantIntVal>(elementCount)));
+ m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)));
}
// A swizzle can be used as an l-value as long as there
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 184ff3350..2e577b6d5 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -1120,7 +1120,7 @@ namespace Slang
if(!intVal)
{
sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl);
- intVal = getLinkage()->getASTBuilder()->create<ConstantIntVal>(0);
+ intVal = getLinkage()->getASTBuilder()->create<ConstantIntVal>(m_astBuilder->getIntType(), 0);
}
ModuleSpecializationInfo::GenericArgInfo expandedArg;
diff --git a/source/slang/slang-ir-deduplicate.cpp b/source/slang/slang-ir-deduplicate.cpp
index 8aef7736c..51a677627 100644
--- a/source/slang/slang-ir-deduplicate.cpp
+++ b/source/slang/slang-ir-deduplicate.cpp
@@ -53,17 +53,22 @@ namespace Slang
context.builder = this;
m_constantMap.Clear();
m_globalValueNumberingMap.Clear();
+ List<IRInst*> instToRemove;
for (auto inst : m_module->getGlobalInsts())
{
if (auto constVal = as<IRConstant>(inst))
{
- context.addConstantValue(constVal);
+ auto newConst = context.addConstantValue(constVal);
+ if (newConst != constVal)
+ {
+ constVal->replaceUsesWith(newConst);
+ instToRemove.add(constVal);
+ }
}
}
- List<IRInst*> instToRemove;
for (auto inst : m_module->getGlobalInsts())
{
- if (as<IRType>(inst))
+ if (as<IRType>(inst) || as<IRSpecialize>(inst))
{
auto newInst = context.addTypeValue(inst);
if (newInst != inst)
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 1dd18ea47..b9a8f68ab 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -731,6 +731,7 @@ struct SpecializationContext
// This prevents us from generating duplicated specializations
// when this pass is invoked iteratively.
readSpecializationDictionaries();
+ sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
// The unspecialized IR we receive as input will have
// `IRBindGlobalGenericParam` instructions that associate
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 20bf6060d..19d3ce59a 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2146,6 +2146,9 @@ namespace Slang
keyInst.value.intVal = static_cast<uint16_t>(inValue);
break;
case kIROp_BoolType:
+ keyInst.m_op = kIROp_BoolLit;
+ keyInst.value.intVal = ((inValue != 0) ? 1 : 0);
+ break;
case kIROp_UIntType:
keyInst.value.intVal = static_cast<uint32_t>(inValue);
break;
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 0921d36fb..d9080169a 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1321,11 +1321,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val)
{
auto irBuilder = getBuilder();
- auto constTerm = irBuilder->getIntValue(irBuilder->getIntType(), val->constantTerm);
+ auto type = lowerType(context, val->type);
+ auto constTerm = irBuilder->getIntValue(type, val->constantTerm);
auto resultVal = constTerm;
for (auto term : val->terms)
{
- auto termVal = irBuilder->getIntValue(irBuilder->getIntType(), term->constFactor);
+ auto termVal = irBuilder->getIntValue(type, term->constFactor);
for (auto factor : term->paramFactors)
{
auto factorVal = lowerVal(context, factor->param).val;
@@ -1701,10 +1702,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredValInfo visitConstantIntVal(ConstantIntVal* val)
{
- // TODO: it is a bit messy here that the `ConstantIntVal` representation
- // has no notion of a *type* associated with the value...
-
- auto type = getIntType(context);
+ auto type = lowerType(context, val->type);
return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value));
}
diff --git a/tests/bugs/generic-type-duplication.slang b/tests/bugs/generic-type-duplication.slang
new file mode 100644
index 000000000..4117a7f81
--- /dev/null
+++ b/tests/bugs/generic-type-duplication.slang
@@ -0,0 +1,33 @@
+// Test that the same generic type specialization does not get emitted as different types in target code.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+//TEST(compute,vulkan):COMPARE_COMPUTE_EX:-vk -slang -compute -shaderobj
+
+struct MyGeneric<let addOne: bool>
+{
+ int value;
+
+ [mutating]
+ void load(RWStructuredBuffer<MyGeneric<addOne>> buffer)
+ {
+ var m = buffer.Load(0);
+ if (addOne)
+ value = m.value + 1;
+ else
+ value = m.value;
+ }
+};
+
+//TEST_INPUT:set myBuffer = ubuffer(data=[1],stride=4)
+RWStructuredBuffer<MyGeneric<true>> myBuffer;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(4, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ MyGeneric<true> obj;
+ obj.load(myBuffer);
+ outputBuffer[dispatchThreadID.x] = obj.value;
+}
diff --git a/tests/bugs/generic-type-duplication.slang.expected.txt b/tests/bugs/generic-type-duplication.slang.expected.txt
new file mode 100644
index 000000000..487b11653
--- /dev/null
+++ b/tests/bugs/generic-type-duplication.slang.expected.txt
@@ -0,0 +1,4 @@
+2
+2
+2
+2