Skip to content

Commit

Permalink
Assert correct values in APInt constructor
Browse files Browse the repository at this point in the history
If the uint64_t constructor is used, assert that the value is
actuall a signed or unsigned N-bit integer depending on whether
the isSigned flag is set.

Currently, we allow values to be silently truncated, which is
a constant source of subtle bugs -- a particularly common mistake
is to create -1 values without setting the isSigned flag, which
will work fine for all common bit widths (<= 64-bit) and miscompile
for larger integers.
  • Loading branch information
nikic authored and akiramenai committed Aug 26, 2024
1 parent b87daf6 commit f9f732c
Show file tree
Hide file tree
Showing 25 changed files with 391 additions and 343 deletions.
4 changes: 3 additions & 1 deletion llvm/include/llvm/ADT/APFixedPoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ class APFixedPoint {
}

APFixedPoint(uint64_t Val, const FixedPointSemantics &Sema)
: APFixedPoint(APInt(Sema.getWidth(), Val, Sema.isSigned()), Sema) {}
: APFixedPoint(APInt(Sema.getWidth(), Val, Sema.isSigned(),
/*implicitTrunc=*/true),
Sema) {}

// Zero initialization.
APFixedPoint(const FixedPointSemantics &Sema) : APFixedPoint(0, Sema) {}
Expand Down
19 changes: 17 additions & 2 deletions llvm/include/llvm/ADT/APInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,26 @@ class [[nodiscard]] APInt {
/// \param numBits the bit width of the constructed APInt
/// \param val the initial value of the APInt
/// \param isSigned how to treat signedness of val
APInt(unsigned numBits, uint64_t val, bool isSigned = false)
/// \param implicitTrunc allow implicit truncation of non-zero/sign bits of
/// val beyond the range of numBits
APInt(unsigned numBits, uint64_t val, bool isSigned = false,
bool implicitTrunc = false)
: BitWidth(numBits) {
if (!implicitTrunc) {
if (BitWidth == 0) {
assert(val == 0 && "Value must be zero for 0-bit APInt");
} else if (isSigned) {
assert(llvm::isIntN(BitWidth, val) &&
"Value is not an N-bit signed value");
} else {
assert(llvm::isUIntN(BitWidth, val) &&
"Value is not an N-bit unsigned value");
}
}
if (isSingleWord()) {
U.VAL = val;
clearUnusedBits();
if (implicitTrunc || isSigned)
clearUnusedBits();
} else {
initSlowCase(val, isSigned);
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
SrcElemTy,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
ArrayRef((Value *const *)Ops.data() + 1, Ops.size() - 1)),
true);
/*isSigned=*/true, /*implicitTrunc=*/true);
// EraVM local end
Ptr = StripPtrCastKeepAS(Ptr);

Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Analysis/Loads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,8 @@ static bool isDereferenceableAndAlignedPointer(
}

bool CheckForNonNull, CheckForFreed;
APInt KnownDerefBytes(Size.getBitWidth(),
V->getPointerDereferenceableBytes(DL, CheckForNonNull,
CheckForFreed));
if (KnownDerefBytes.getBoolValue() && KnownDerefBytes.uge(Size) &&
if (Size.ule(V->getPointerDereferenceableBytes(DL, CheckForNonNull,
CheckForFreed)) &&
!CheckForFreed)
if (!CheckForNonNull || isKnownNonZero(V, DL, 0, AC, CtxI, DT)) {
// As we recursed through GEPs to get here, we've incrementally checked
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Analysis/MemoryBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,8 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitAllocaInst(AllocaInst &I) {
TypeSize ElemSize = DL.getTypeAllocSize(I.getAllocatedType());
if (ElemSize.isScalable() && Options.EvalMode != ObjectSizeOpts::Mode::Min)
return unknown();
if (!isUIntN(IntTyBits, ElemSize.getKnownMinValue()))
return unknown();
APInt Size(IntTyBits, ElemSize.getKnownMinValue());
if (!I.isArrayAllocation())
return std::make_pair(align(Size, I.getAlign()), Zero);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8562,7 +8562,7 @@ static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II) {
case Intrinsic::cttz:
// Maximum of set/clear bits is the bit width.
return ConstantRange::getNonEmpty(APInt::getZero(Width),
APInt(Width, Width + 1));
APInt(Width, Width) + 1);
case Intrinsic::uadd_sat:
// uadd.sat(x, C) produces [C, UINT_MAX].
if (match(II.getOperand(0), m_APInt(C)) ||
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1552,7 +1552,10 @@ SDValue SelectionDAG::getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
assert((EltVT.getSizeInBits() >= 64 ||
(uint64_t)((int64_t)Val >> EltVT.getSizeInBits()) + 1 < 2) &&
"getConstant with a uint64_t value that doesn't fit in the type!");
return getConstant(APInt(EltVT.getSizeInBits(), Val), DL, VT, isT, isO);
// TODO: Avoid implicit trunc?
return getConstant(APInt(EltVT.getSizeInBits(), Val, /*isSigned=*/false,
/*implicitTrunc=*/true),
DL, VT, isT, isO);
}

SDValue SelectionDAG::getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4028,7 +4028,8 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
DAG.getDataLayout().getTypeAllocSize(GTI.getIndexedType());
// We intentionally mask away the high bits here; ElementSize may not
// fit in IdxTy.
APInt ElementMul(IdxSize, ElementSize.getKnownMinValue());
APInt ElementMul(IdxSize, ElementSize.getKnownMinValue(),
/*isSigned=*/false, /*implicitTrunc=*/true);
bool ElementScalable = ElementSize.isScalable();

// If this is a scalar constant or a splat vector of constants,
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2004,7 +2004,9 @@ ScheduleDAGSDNodes *SelectionDAGISel::CreateScheduler() {
bool SelectionDAGISel::CheckAndMask(SDValue LHS, ConstantSDNode *RHS,
int64_t DesiredMaskS) const {
const APInt &ActualMask = RHS->getAPIntValue();
const APInt &DesiredMask = APInt(LHS.getValueSizeInBits(), DesiredMaskS);
// TODO: Avoid implicit trunc?
const APInt &DesiredMask = APInt(LHS.getValueSizeInBits(), DesiredMaskS,
/*isSigned=*/false, /*implicitTrunc=*/true);

// If the actual mask exactly matches, success!
if (ActualMask == DesiredMask)
Expand Down
8 changes: 6 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6442,7 +6442,9 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,

PAmts.push_back(DAG.getConstant(P, DL, SVT));
KAmts.push_back(
DAG.getConstant(APInt(ShSVT.getSizeInBits(), K), DL, ShSVT));
DAG.getConstant(APInt(ShSVT.getSizeInBits(), K, /*isSigned=*/false,
/*implicitTrunc=*/true),
DL, ShSVT));
QAmts.push_back(DAG.getConstant(Q, DL, SVT));
return true;
};
Expand Down Expand Up @@ -6696,7 +6698,9 @@ TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
PAmts.push_back(DAG.getConstant(P, DL, SVT));
AAmts.push_back(DAG.getConstant(A, DL, SVT));
KAmts.push_back(
DAG.getConstant(APInt(ShSVT.getSizeInBits(), K), DL, ShSVT));
DAG.getConstant(APInt(ShSVT.getSizeInBits(), K, /*isSigned=*/false,
/*implicitTrunc=*/true),
DL, ShSVT));
QAmts.push_back(DAG.getConstant(Q, DL, SVT));
return true;
};
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/ExecutionEngine/MCJIT/MCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ GenericValue MCJIT::runFunction(Function *F, ArrayRef<GenericValue> ArgValues) {
return rv;
}
case Type::VoidTyID:
rv.IntVal = APInt(32, ((int(*)())(intptr_t)FPtr)());
rv.IntVal = APInt(32, ((int (*)())(intptr_t)FPtr)(), true);
return rv;
case Type::FloatTyID:
rv.FloatVal = ((float(*)())(intptr_t)FPtr)();
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/IR/ConstantRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1722,7 +1722,7 @@ ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const {
// Zero is either safe or not in the range. The output range is composed by
// the result of countLeadingZero of the two extremes.
return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()),
APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
APInt(getBitWidth(), getUnsignedMin().countl_zero()) + 1);
}

ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,9 @@ Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
}

ConstantInt *ConstantInt::get(IntegerType *Ty, uint64_t V, bool isSigned) {
return get(Ty->getContext(), APInt(Ty->getBitWidth(), V, isSigned));
// TODO: Avoid implicit trunc?
return get(Ty->getContext(),
APInt(Ty->getBitWidth(), V, isSigned, /*implicitTrunc=*/true));
}

Constant *ConstantInt::get(Type *Ty, const APInt& V) {
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/Support/APInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ APInt& APInt::operator-=(uint64_t RHS) {
APInt APInt::operator*(const APInt& RHS) const {
assert(BitWidth == RHS.BitWidth && "Bit widths must be the same");
if (isSingleWord())
return APInt(BitWidth, U.VAL * RHS.U.VAL);
return APInt(BitWidth, U.VAL * RHS.U.VAL, /*isSigned=*/false,
/*implicitTrunc=*/true);

APInt Result(getMemory(getNumWords()), getBitWidth());
tcMultiply(Result.U.pVal, U.pVal, RHS.U.pVal, getNumWords());
Expand Down Expand Up @@ -455,15 +456,17 @@ APInt APInt::extractBits(unsigned numBits, unsigned bitPosition) const {
"Illegal bit extraction");

if (isSingleWord())
return APInt(numBits, U.VAL >> bitPosition);
return APInt(numBits, U.VAL >> bitPosition, /*isSigned=*/false,
/*implicitTrunc=*/true);

unsigned loBit = whichBit(bitPosition);
unsigned loWord = whichWord(bitPosition);
unsigned hiWord = whichWord(bitPosition + numBits - 1);

// Single word result extracting bits from a single word source.
if (loWord == hiWord)
return APInt(numBits, U.pVal[loWord] >> loBit);
return APInt(numBits, U.pVal[loWord] >> loBit, /*isSigned=*/false,
/*implicitTrunc=*/true);

// Extracting bits that start on a source word boundary can be done
// as a fast memory copy.
Expand Down Expand Up @@ -907,7 +910,8 @@ APInt APInt::trunc(unsigned width) const {
assert(width <= BitWidth && "Invalid APInt Truncate request");

if (width <= APINT_BITS_PER_WORD)
return APInt(width, getRawData()[0]);
return APInt(width, getRawData()[0], /*isSigned=*/false,
/*implicitTrunc=*/true);

if (width == BitWidth)
return *this;
Expand Down Expand Up @@ -955,7 +959,7 @@ APInt APInt::sext(unsigned Width) const {
assert(Width >= BitWidth && "Invalid APInt SignExtend request");

if (Width <= APINT_BITS_PER_WORD)
return APInt(Width, SignExtend64(U.VAL, BitWidth));
return APInt(Width, SignExtend64(U.VAL, BitWidth), /*isSigned*/ true);

if (Width == BitWidth)
return *this;
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ using OffsetAndArgPart = std::pair<int64_t, ArgPart>;
static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
Value *Ptr, Type *ResElemTy, int64_t Offset) {
if (Offset != 0) {
APInt APOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset);
APInt APOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset,
/*isSigned=*/true);
Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(APOffset));
}
return Ptr;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6753,7 +6753,7 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,

for (auto Case : SI->cases()) {
auto *Orig = Case.getCaseValue();
auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base);
auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base, true);
Case.setValue(
cast<ConstantInt>(ConstantInt::get(Ty, Sub.lshr(ShiftC->getValue()))));
}
Expand Down
9 changes: 5 additions & 4 deletions llvm/unittests/ADT/APFixedPointTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,27 +240,28 @@ void CheckIntPart(const FixedPointSemantics &Sema, int64_t IntPart) {
APFixedPoint ValWithFract(
APInt(Sema.getWidth(),
relativeShr(IntPart, Sema.getLsbWeight()) + FullFactPart,
Sema.isSigned()),
Sema.isSigned(), /*implicitTrunc=*/true),
Sema);
ASSERT_EQ(ValWithFract.getIntPart(), IntPart);

// Just fraction
APFixedPoint JustFract(APInt(Sema.getWidth(), FullFactPart, Sema.isSigned()),
APFixedPoint JustFract(APInt(Sema.getWidth(), FullFactPart, Sema.isSigned(),
/*implicitTrunc=*/true),
Sema);
ASSERT_EQ(JustFract.getIntPart(), 0);

// Whole number
APFixedPoint WholeNum(APInt(Sema.getWidth(),
relativeShr(IntPart, Sema.getLsbWeight()),
Sema.isSigned()),
Sema.isSigned(), /*implicitTrunc=*/true),
Sema);
ASSERT_EQ(WholeNum.getIntPart(), IntPart);

// Negative
if (Sema.isSigned()) {
APFixedPoint Negative(APInt(Sema.getWidth(),
relativeShr(IntPart, Sema.getLsbWeight()),
Sema.isSigned()),
Sema.isSigned(), /*implicitTrunc=*/true),
Sema);
ASSERT_EQ(Negative.getIntPart(), IntPart);
}
Expand Down
Loading

0 comments on commit f9f732c

Please sign in to comment.