diff --git a/src/script/drivechain.h b/src/script/drivechain.h index bcf1b4e47..fd5d616af 100644 --- a/src/script/drivechain.h +++ b/src/script/drivechain.h @@ -12,9 +12,52 @@ #include #include +#define LIMITED_VECTOR(obj, n) REF(MakeLimitedVector(obj)) + const unsigned char ACK_LABEL[] = {0x41, 0x43, 0x4B, 0x3A}; // "ACK:" const size_t ACK_LABEL_LENGTH = sizeof(ACK_LABEL); +template +class LimitedVector +{ +protected: + std::vector& vec; + +public: + LimitedVector(std::vector& vec) : vec(vec) {} + + template + void Unserialize(Stream& s, int, int = 0) + { + size_t size = ReadCompactSize(s); + if (size > Limit) { + throw std::ios_base::failure("String length limit exceeded"); + } + vec.resize(size); + if (size != 0) + s.read((char*)&vec[0], size); + } + + template + void Serialize(Stream& s, int, int = 0) const + { + WriteCompactSize(s, vec.size()); + if (!vec.empty()) + s.write((char*)&vec[0], vec.size()); + } + + unsigned int GetSerializeSize(int, int = 0) const + { + return GetSizeOfCompactSize(vec.size()) + vec.size(); + } +}; + +template +LimitedVector MakeLimitedVector(std::vector& obj) +{ + return LimitedVector(obj); +} + class Ack { public: @@ -28,26 +71,28 @@ public: { uint64_t nPayload = 0; if (!ser_action.ForRead()) - nPayload = CalcPayloadSize(); + nPayload = CalcPayloadSize(nType, nVersion); READWRITE(COMPACTSIZE(nPayload)); - READWRITE(prefix); + if (nPayload == 0) + throw std::runtime_error("Not valid ACK"); + READWRITE(LIMITED_VECTOR(prefix, 32)); // Empty preimage should not be serialized if (ser_action.ForRead()) { uint64_t nPrefix = prefix.size(); nPrefix += GetSizeOfCompactSize(nPrefix); if (nPayload > nPrefix) - READWRITE(preimage); - if (CalcPayloadSize() != nPayload) + READWRITE(LIMITED_VECTOR(preimage, 32)); + if (CalcPayloadSize(nType, nVersion) != nPayload) throw std::runtime_error("Not valid ACK"); } else { if (preimage.size() > 0) - READWRITE(preimage); + READWRITE(LIMITED_VECTOR(preimage, 32)); } } - uint64_t CalcPayloadSize() const + unsigned int CalcPayloadSize(int nType, int nVersion) const { - uint64_t nPayload = 0; + unsigned int nPayload = 0; nPayload += GetSizeOfCompactSize(prefix.size()); nPayload += prefix.size(); // Empty preimage should not be serialized @@ -58,12 +103,6 @@ public: return nPayload; } - uint64_t CalcSize() const - { - uint64_t nPayload = CalcPayloadSize(); - return GetSizeOfCompactSize(nPayload) + nPayload; - } - Ack() {} Ack(std::vector prefix, std::vector preimage = std::vector()) : prefix(prefix), preimage(preimage) @@ -83,14 +122,14 @@ public: { uint64_t sizePayload = 0; if (!ser_action.ForRead()) - sizePayload = CalcPayloadSize(); + sizePayload = CalcPayloadSize(nType, nVersion); READWRITE(COMPACTSIZE(sizePayload)); if (ser_action.ForRead()) { - uint64_t read = 0; + unsigned int read = 0; while (read < sizePayload) { Ack ack; READWRITE(ack); - read += ack.CalcSize(); + read += ack.GetSerializeSize(nType, nVersion); vAck.push_back(ack); } if (read != sizePayload) @@ -102,24 +141,15 @@ public: } } - uint64_t CalcPayloadSize() const + unsigned int CalcPayloadSize(int nType, int nVersion) const { - uint64_t nPayload = 0; + unsigned int nPayload = 0; for (const Ack& ack : vAck) { - nPayload += ack.CalcSize(); + nPayload += ack.GetSerializeSize(nType, nVersion); } return nPayload; } - uint64_t CalcSize() const - { - uint64_t nSize = 0; - uint64_t nPayloadSize = CalcPayloadSize(); - nSize += GetSizeOfCompactSize(nPayloadSize); - nSize += nPayloadSize; - return nSize; - } - AckList() {} AckList(std::vector acks) : vAck(acks) {} }; @@ -137,32 +167,23 @@ public: { uint64_t nPayload = 0; if (!ser_action.ForRead()) - nPayload = CalcPayloadSize(); + nPayload = CalcPayloadSize(nType, nVersion); READWRITE(COMPACTSIZE(nPayload)); - READWRITE(chainId); + READWRITE(LIMITED_VECTOR(chainId, 20)); READWRITE(ackList); - if (ser_action.ForRead() && nPayload != CalcPayloadSize()) + if (ser_action.ForRead() && nPayload != CalcPayloadSize(nType, nVersion)) throw std::runtime_error("Not valid CHAIN ACK LIST"); } - uint64_t CalcPayloadSize() const + unsigned int CalcPayloadSize(int nType, int nVersion) const { - uint64_t nPayload = 0; + unsigned int nPayload = 0; nPayload += GetSizeOfCompactSize(chainId.size()); nPayload += chainId.size(); - nPayload += ackList.CalcSize(); + nPayload += ackList.GetSerializeSize(nType, nVersion); return nPayload; } - uint64_t CalcSize() const - { - uint64_t nSize = 0; - uint64_t nPayloadSize = CalcPayloadSize(); - nSize += GetSizeOfCompactSize(nPayloadSize); - nSize += nPayloadSize; - return nSize; - } - ChainAckList& operator<<(Ack ack) { ackList.vAck.push_back(ack); @@ -185,14 +206,14 @@ public: { uint64_t sizePayload = 0; if (!ser_action.ForRead()) - sizePayload = CalcPayloadSize(); + sizePayload = CalcPayloadSize(nType, nVersion); READWRITE(COMPACTSIZE(sizePayload)); if (ser_action.ForRead()) { uint64_t read = 0; while (read < sizePayload) { ChainAckList chainAcks; READWRITE(chainAcks); - read += chainAcks.CalcSize(); + read += chainAcks.GetSerializeSize(nType, nVersion); vChainAcks.push_back(chainAcks); } if (read != sizePayload) @@ -204,24 +225,15 @@ public: } } - uint64_t CalcPayloadSize() const + unsigned int CalcPayloadSize(int nType, int nVersion) const { - uint64_t nPayloadSize = 0; + unsigned int nPayloadSize = 0; for (const auto& chainAcks : vChainAcks) { - nPayloadSize += chainAcks.CalcSize(); + nPayloadSize += chainAcks.GetSerializeSize(nType, nVersion); } return nPayloadSize; } - uint64_t CalcSize() const - { - uint64_t nSize = 0; - uint64_t nPayloadSize = CalcPayloadSize(); - nSize += GetSizeOfCompactSize(nPayloadSize); - nSize += nPayloadSize; - return nSize; - } - FullAckList& operator<<(Ack ack) { if (!vChainAcks.empty()) {