Enforce limits to fields when deserializing
This commit is contained in:
@@ -12,9 +12,52 @@
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#define LIMITED_VECTOR(obj, n) REF(MakeLimitedVector<n>(obj))
|
||||
|
||||
const unsigned char ACK_LABEL[] = {0x41, 0x43, 0x4B, 0x3A}; // "ACK:"
|
||||
const size_t ACK_LABEL_LENGTH = sizeof(ACK_LABEL);
|
||||
|
||||
template <size_t Limit, typename U>
|
||||
class LimitedVector
|
||||
{
|
||||
protected:
|
||||
std::vector<U>& vec;
|
||||
|
||||
public:
|
||||
LimitedVector(std::vector<U>& vec) : vec(vec) {}
|
||||
|
||||
template <typename Stream>
|
||||
void Unserialize(Stream& s, int, int = 0)
|
||||
{
|
||||
size_t size = ReadCompactSize(s);
|
||||
if (size > Limit) {
|
||||
throw std::ios_base::failure("String length limit exceeded");
|
||||
}
|
||||
vec.resize(size);
|
||||
if (size != 0)
|
||||
s.read((char*)&vec[0], size);
|
||||
}
|
||||
|
||||
template <typename Stream>
|
||||
void Serialize(Stream& s, int, int = 0) const
|
||||
{
|
||||
WriteCompactSize(s, vec.size());
|
||||
if (!vec.empty())
|
||||
s.write((char*)&vec[0], vec.size());
|
||||
}
|
||||
|
||||
unsigned int GetSerializeSize(int, int = 0) const
|
||||
{
|
||||
return GetSizeOfCompactSize(vec.size()) + vec.size();
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t N, typename U>
|
||||
LimitedVector<N, U> MakeLimitedVector(std::vector<U>& obj)
|
||||
{
|
||||
return LimitedVector<N, U>(obj);
|
||||
}
|
||||
|
||||
class Ack
|
||||
{
|
||||
public:
|
||||
@@ -28,26 +71,28 @@ public:
|
||||
{
|
||||
uint64_t nPayload = 0;
|
||||
if (!ser_action.ForRead())
|
||||
nPayload = CalcPayloadSize();
|
||||
nPayload = CalcPayloadSize(nType, nVersion);
|
||||
READWRITE(COMPACTSIZE(nPayload));
|
||||
READWRITE(prefix);
|
||||
if (nPayload == 0)
|
||||
throw std::runtime_error("Not valid ACK");
|
||||
READWRITE(LIMITED_VECTOR(prefix, 32));
|
||||
// Empty preimage should not be serialized
|
||||
if (ser_action.ForRead()) {
|
||||
uint64_t nPrefix = prefix.size();
|
||||
nPrefix += GetSizeOfCompactSize(nPrefix);
|
||||
if (nPayload > nPrefix)
|
||||
READWRITE(preimage);
|
||||
if (CalcPayloadSize() != nPayload)
|
||||
READWRITE(LIMITED_VECTOR(preimage, 32));
|
||||
if (CalcPayloadSize(nType, nVersion) != nPayload)
|
||||
throw std::runtime_error("Not valid ACK");
|
||||
} else {
|
||||
if (preimage.size() > 0)
|
||||
READWRITE(preimage);
|
||||
READWRITE(LIMITED_VECTOR(preimage, 32));
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t CalcPayloadSize() const
|
||||
unsigned int CalcPayloadSize(int nType, int nVersion) const
|
||||
{
|
||||
uint64_t nPayload = 0;
|
||||
unsigned int nPayload = 0;
|
||||
nPayload += GetSizeOfCompactSize(prefix.size());
|
||||
nPayload += prefix.size();
|
||||
// Empty preimage should not be serialized
|
||||
@@ -58,12 +103,6 @@ public:
|
||||
return nPayload;
|
||||
}
|
||||
|
||||
uint64_t CalcSize() const
|
||||
{
|
||||
uint64_t nPayload = CalcPayloadSize();
|
||||
return GetSizeOfCompactSize(nPayload) + nPayload;
|
||||
}
|
||||
|
||||
Ack() {}
|
||||
Ack(std::vector<unsigned char> prefix, std::vector<unsigned char> preimage = std::vector<unsigned char>())
|
||||
: prefix(prefix), preimage(preimage)
|
||||
@@ -83,14 +122,14 @@ public:
|
||||
{
|
||||
uint64_t sizePayload = 0;
|
||||
if (!ser_action.ForRead())
|
||||
sizePayload = CalcPayloadSize();
|
||||
sizePayload = CalcPayloadSize(nType, nVersion);
|
||||
READWRITE(COMPACTSIZE(sizePayload));
|
||||
if (ser_action.ForRead()) {
|
||||
uint64_t read = 0;
|
||||
unsigned int read = 0;
|
||||
while (read < sizePayload) {
|
||||
Ack ack;
|
||||
READWRITE(ack);
|
||||
read += ack.CalcSize();
|
||||
read += ack.GetSerializeSize(nType, nVersion);
|
||||
vAck.push_back(ack);
|
||||
}
|
||||
if (read != sizePayload)
|
||||
@@ -102,24 +141,15 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t CalcPayloadSize() const
|
||||
unsigned int CalcPayloadSize(int nType, int nVersion) const
|
||||
{
|
||||
uint64_t nPayload = 0;
|
||||
unsigned int nPayload = 0;
|
||||
for (const Ack& ack : vAck) {
|
||||
nPayload += ack.CalcSize();
|
||||
nPayload += ack.GetSerializeSize(nType, nVersion);
|
||||
}
|
||||
return nPayload;
|
||||
}
|
||||
|
||||
uint64_t CalcSize() const
|
||||
{
|
||||
uint64_t nSize = 0;
|
||||
uint64_t nPayloadSize = CalcPayloadSize();
|
||||
nSize += GetSizeOfCompactSize(nPayloadSize);
|
||||
nSize += nPayloadSize;
|
||||
return nSize;
|
||||
}
|
||||
|
||||
AckList() {}
|
||||
AckList(std::vector<Ack> acks) : vAck(acks) {}
|
||||
};
|
||||
@@ -137,32 +167,23 @@ public:
|
||||
{
|
||||
uint64_t nPayload = 0;
|
||||
if (!ser_action.ForRead())
|
||||
nPayload = CalcPayloadSize();
|
||||
nPayload = CalcPayloadSize(nType, nVersion);
|
||||
READWRITE(COMPACTSIZE(nPayload));
|
||||
READWRITE(chainId);
|
||||
READWRITE(LIMITED_VECTOR(chainId, 20));
|
||||
READWRITE(ackList);
|
||||
if (ser_action.ForRead() && nPayload != CalcPayloadSize())
|
||||
if (ser_action.ForRead() && nPayload != CalcPayloadSize(nType, nVersion))
|
||||
throw std::runtime_error("Not valid CHAIN ACK LIST");
|
||||
}
|
||||
|
||||
uint64_t CalcPayloadSize() const
|
||||
unsigned int CalcPayloadSize(int nType, int nVersion) const
|
||||
{
|
||||
uint64_t nPayload = 0;
|
||||
unsigned int nPayload = 0;
|
||||
nPayload += GetSizeOfCompactSize(chainId.size());
|
||||
nPayload += chainId.size();
|
||||
nPayload += ackList.CalcSize();
|
||||
nPayload += ackList.GetSerializeSize(nType, nVersion);
|
||||
return nPayload;
|
||||
}
|
||||
|
||||
uint64_t CalcSize() const
|
||||
{
|
||||
uint64_t nSize = 0;
|
||||
uint64_t nPayloadSize = CalcPayloadSize();
|
||||
nSize += GetSizeOfCompactSize(nPayloadSize);
|
||||
nSize += nPayloadSize;
|
||||
return nSize;
|
||||
}
|
||||
|
||||
ChainAckList& operator<<(Ack ack)
|
||||
{
|
||||
ackList.vAck.push_back(ack);
|
||||
@@ -185,14 +206,14 @@ public:
|
||||
{
|
||||
uint64_t sizePayload = 0;
|
||||
if (!ser_action.ForRead())
|
||||
sizePayload = CalcPayloadSize();
|
||||
sizePayload = CalcPayloadSize(nType, nVersion);
|
||||
READWRITE(COMPACTSIZE(sizePayload));
|
||||
if (ser_action.ForRead()) {
|
||||
uint64_t read = 0;
|
||||
while (read < sizePayload) {
|
||||
ChainAckList chainAcks;
|
||||
READWRITE(chainAcks);
|
||||
read += chainAcks.CalcSize();
|
||||
read += chainAcks.GetSerializeSize(nType, nVersion);
|
||||
vChainAcks.push_back(chainAcks);
|
||||
}
|
||||
if (read != sizePayload)
|
||||
@@ -204,24 +225,15 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t CalcPayloadSize() const
|
||||
unsigned int CalcPayloadSize(int nType, int nVersion) const
|
||||
{
|
||||
uint64_t nPayloadSize = 0;
|
||||
unsigned int nPayloadSize = 0;
|
||||
for (const auto& chainAcks : vChainAcks) {
|
||||
nPayloadSize += chainAcks.CalcSize();
|
||||
nPayloadSize += chainAcks.GetSerializeSize(nType, nVersion);
|
||||
}
|
||||
return nPayloadSize;
|
||||
}
|
||||
|
||||
uint64_t CalcSize() const
|
||||
{
|
||||
uint64_t nSize = 0;
|
||||
uint64_t nPayloadSize = CalcPayloadSize();
|
||||
nSize += GetSizeOfCompactSize(nPayloadSize);
|
||||
nSize += nPayloadSize;
|
||||
return nSize;
|
||||
}
|
||||
|
||||
FullAckList& operator<<(Ack ack)
|
||||
{
|
||||
if (!vChainAcks.empty()) {
|
||||
|
||||
Reference in New Issue
Block a user