1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
|
#include "slang-ir-legalize-binary-operator.h"
#include "slang-ir-insts.h"
namespace Slang
{
// Rewrites the operands of the binary instruction `inst` in place so that both
// operands agree in shape and signedness. New instructions (casts and scalar
// splats) are inserted immediately before `inst`.
//
// Three adjustments are applied, in this order:
//   1. For shifts, a signed shift amount (scalar or vector) is cast to the
//      unsigned integer type described by the same `IntInfo`.
//   2. If exactly one operand is a vector/matrix and the other is a basic
//      (scalar) type, the scalar is expanded to a composite value.
//   3. Otherwise, for non-shift operations on two integers whose signedness
//      differs, the signed operand is cast to unsigned.
void legalizeBinaryOp(IRInst* inst)
{
    const bool isShift = inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh;

    // For shifts, ensure that the shift amount is unsigned, as required by
    // https://www.w3.org/TR/WGSL/#bit-expr.
    if (isShift)
    {
        IRInst* shiftAmount = inst->getOperand(1);
        IRType* shiftAmountType = shiftAmount->getDataType();
        auto shiftAmountVectorType = as<IRVectorType>(shiftAmountType);

        // For a vector shift amount the signedness is a property of the
        // element type; for a scalar it is the type itself.
        IRType* shiftAmountElementType =
            shiftAmountVectorType ? shiftAmountVectorType->getElementType() : shiftAmountType;

        if (shiftAmountVectorType || isIntegralType(shiftAmountElementType))
        {
            IntInfo opIntInfo = getIntTypeInfo(shiftAmountElementType);
            if (opIntInfo.isSigned)
            {
                IRBuilder builder(inst);
                builder.setInsertBefore(inst);

                // Same integer description, but unsigned.
                opIntInfo.isSigned = false;
                IRType* unsignedType = builder.getType(getIntTypeOpFromInfo(opIntInfo));
                if (shiftAmountVectorType)
                {
                    // Keep the original element count for vector shift amounts.
                    unsignedType = builder.getVectorType(
                        unsignedType,
                        shiftAmountVectorType->getElementCount());
                }

                IRInst* newShiftAmount = builder.emitCast(unsignedType, shiftAmount);
                builder.replaceOperand(inst->getOperands() + 1, newShiftAmount);
            }
        }
    }

    auto isVectorOrMatrix = [](IRType* type)
    {
        switch (type->getOp())
        {
        case kIROp_VectorType:
        case kIROp_MatrixType:
            return true;
        default:
            return false;
        }
    };

    // Read the operand types only now: the shift handling above may have
    // replaced operand 1 with a cast.
    IRType* lhsType = inst->getOperand(0)->getDataType();
    IRType* rhsType = inst->getOperand(1)->getDataType();

    // Index of the scalar operand in a mixed composite/scalar operation,
    // or -1 if the operation is not of that form.
    int scalarIndex = -1;
    if (isVectorOrMatrix(lhsType) && as<IRBasicType>(rhsType))
    {
        scalarIndex = 1;
    }
    else if (as<IRBasicType>(lhsType) && isVectorOrMatrix(rhsType))
    {
        scalarIndex = 0;
    }

    if (scalarIndex >= 0)
    {
        const int compositeIndex = 1 - scalarIndex;

        IRBuilder builder(inst);
        builder.setInsertBefore(inst);
        IRType* compositeType = inst->getOperand(compositeIndex)->getDataType();
        IRInst* scalarValue = inst->getOperand(scalarIndex);

        // Retain the scalar type for shifts
        if (isShift)
        {
            // The composite operand may be a matrix, in which case the cast
            // yields null. Only rebuild the type for vectors; otherwise fall
            // back to the composite operand's own type instead of
            // dereferencing a null pointer.
            if (auto vectorType = as<IRVectorType>(compositeType))
            {
                compositeType = builder.getVectorType(
                    scalarValue->getDataType(),
                    vectorType->getElementCount());
            }
        }

        auto newOperand = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
        builder.replaceOperand(inst->getOperands() + scalarIndex, newOperand);
    }
    else if (isIntegralType(lhsType) && isIntegralType(rhsType))
    {
        // Unless the operator is a shift, and if the integer operands differ in signedness,
        // then convert the signed one to unsigned.
        // We're assuming that the cases where this is bad have already been caught by
        // common validation checks.
        IntInfo opIntInfo[2] = {getIntTypeInfo(lhsType), getIntTypeInfo(rhsType)};
        const bool signednessDiffers = opIntInfo[0].isSigned != opIntInfo[1].isSigned;
        if (!isShift && signednessDiffers)
        {
            // Exactly one operand is signed here, so if it is not operand 1
            // it must be operand 0.
            const int signedOpIndex = opIntInfo[1].isSigned ? 1 : 0;
            opIntInfo[signedOpIndex].isSigned = false;

            IRBuilder builder(inst);
            builder.setInsertBefore(inst);
            auto newOp = builder.emitCast(
                builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])),
                inst->getOperand(signedOpIndex));
            builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp);
        }
    }
}
} // namespace Slang
|