From 680c0a3c94f0bb84a2773bc9a95dc5399b6925fb Mon Sep 17 00:00:00 2001 From: pacien Date: Sun, 25 Nov 2018 16:45:35 +0100 Subject: Fix bitreader look-ahead overflow --- src/bitreader.nim | 52 ++++++++++++++++++---------------------------------- src/integers.nim | 7 +++++-- tests/tbitreader.nim | 23 +++++++++++++++++++++++ tests/tintegers.nim | 7 +++++++ 4 files changed, 53 insertions(+), 36 deletions(-) diff --git a/src/bitreader.nim b/src/bitreader.nim index 757c1b3..7afb13d 100644 --- a/src/bitreader.nim +++ b/src/bitreader.nim @@ -17,49 +17,33 @@ import streams import integers -# Stream functions - -proc newEIO(msg: string): ref IOError = - new(result) - result.msg = msg - -proc read[T](s: Stream, t: typedesc[T]): T = - if readData(s, addr(result), sizeof(T)) != sizeof(T): - raise newEIO("cannot read from stream") - -proc peek[T](s: Stream, t: typedesc[T]): T = - if peekData(s, addr(result), sizeof(T)) != sizeof(T): - raise newEIO("cannot read from stream") - -# BitReader - type BitReader* = ref object stream: Stream bitOffset: int + overflowBuffer: uint8 proc bitReader*(stream: Stream): BitReader = - BitReader(stream: stream, bitOffset: 0) + BitReader(stream: stream, bitOffset: 0, overflowBuffer: 0) proc atEnd*(bitReader: BitReader): bool = - bitReader.stream.atEnd() + bitReader.bitOffset == 0 and bitReader.stream.atEnd() proc readBits*[T: SomeUnsignedInt](bitReader: BitReader, bits: int, to: typedesc[T]): T = - let targetBitLength = sizeof(T) * wordBitLength - if bits < 0 or bits > targetBitLength: - raise newException(RangeError, "invalid bit length") - elif bits == 0: - result = 0 - elif bits < targetBitLength - bitReader.bitOffset: - result = bitReader.stream.peek(T) shl (targetBitLength - bits - bitReader.bitOffset) shr (targetBitLength - bits) - elif bits == targetBitLength - bitReader.bitOffset: - result = bitReader.stream.read(T) shl (targetBitLength - bits - bitReader.bitOffset) shr (targetBitLength - bits) - else: - let rightBits = targetBitLength - bitReader.bitOffset - let leftBits = bits - rightBits - let right = bitReader.stream.read(T) shr bitReader.bitOffset - let left = bitReader.stream.peek(T) shl (targetBitLength - leftBits) shr (targetBitLength - bits) - result = left or right - bitReader.bitOffset = (bitReader.bitOffset + bits) mod wordBitLength + if bits < 0 or bits > sizeof(T) * wordBitLength: raise newException(RangeError, "invalid bit length") + if bits == 0: return 0 + var bitsRead = 0 + if bitReader.bitOffset > 0: + let bitsFromBuffer = min(bits, wordBitLength - bitReader.bitOffset) + result = (bitReader.overflowBuffer shr bitReader.bitOffset).leastSignificantBits(bitsFromBuffer) + bitReader.bitOffset = (bitReader.bitOffset + bitsFromBuffer) mod wordBitLength + bitsRead += bitsFromBuffer + while bits - bitsRead >= wordBitLength: + result = result or (bitReader.stream.readUint8().T shl bitsRead) + bitsRead += wordBitLength + if bits - bitsRead > 0: + bitReader.overflowBuffer = bitReader.stream.readUint8() + bitReader.bitOffset = bits - bitsRead + result = result or (bitReader.overflowBuffer.leastSignificantBits(bitReader.bitOffset).T shl bitsRead) proc readBool*(bitReader: BitReader): bool = bitReader.readBits(1, uint8) != 0 diff --git a/src/integers.nim b/src/integers.nim index fddbfdc..7b0f166 100644 --- a/src/integers.nim +++ b/src/integers.nim @@ -15,13 +15,16 @@ # along with this program. If not, see . const wordBitLength* = 8 -const wordBitMask* = 0b1111_1111'u8 proc `/^`*[T: Natural](x, y: T): T = (x + y - 1) div y proc truncateToUint8*(x: SomeUnsignedInt): uint8 = - (x and wordBitMask).uint8 + (x and uint8.high).uint8 + +proc leastSignificantBits*[T: SomeUnsignedInt](x: T, bits: int): T = + let maskOffset = sizeof(T) * wordBitLength - bits + if maskOffset >= 0: (x shl maskOffset) shr maskOffset else: x iterator chunks*(totalBitLength: int, chunkType: typedesc[SomeInteger]): tuple[index: int, chunkBitLength: int] = let chunkBitLength = sizeof(chunkType) * wordBitLength diff --git a/tests/tbitreader.nim b/tests/tbitreader.nim index 8285f63..294f6c9 100644 --- a/tests/tbitreader.nim +++ b/tests/tbitreader.nim @@ -49,6 +49,29 @@ suite "bitreader": expect IOError: discard bitReader.readBits(16, uint16) check bitReader.atEnd() + test "readBits (look-ahead overflow)": + let stream = newStringStream() + defer: stream.close() + stream.write(0xAB'u8) + stream.setPosition(0) + + let bitReader = stream.bitReader() + check bitReader.readBits(4, uint16) == 0x000B'u16 + check bitReader.readBits(4, uint16) == 0x000A'u16 + check bitReader.atEnd() + + test "readBits (from buffer composition)": + let stream = newStringStream() + defer: stream.close() + stream.write(0xABCD'u16) + stream.setPosition(0) + + let bitReader = stream.bitReader() + check bitReader.readBits(4, uint16) == 0x000D'u16 + check bitReader.readBits(8, uint16) == 0x00BC'u16 + check bitReader.readBits(4, uint16) == 0x000A'u16 + check bitReader.atEnd() + test "readSeq": let stream = newStringStream() defer: stream.close() diff --git a/tests/tintegers.nim b/tests/tintegers.nim index c77abec..956e4aa 100644 --- a/tests/tintegers.nim +++ b/tests/tintegers.nim @@ -27,6 +27,13 @@ suite "integers": check truncateToUint8(0x00FA'u16) == 0xFA'u8 check truncateToUint8(0xFFFA'u16) == 0xFA'u8 + test "leastSignificantBits": + check leastSignificantBits(0xFF'u8, 3) == 0b0000_0111'u8 + check leastSignificantBits(0b0001_0101'u8, 3) == 0b0000_0101'u8 + check leastSignificantBits(0xFF'u8, 10) == 0xFF'u8 + check leastSignificantBits(0xFFFF'u16, 16) == 0xFFFF'u16 + check leastSignificantBits(0xFFFF'u16, 8) == 0x00FF'u16 + test "chunks iterator": check toSeq(chunks(70, uint32)) == @[(0, 32), (1, 32), (2, 6)] check toSeq(chunks(32, uint16)) == @[(0, 16), (1, 16)] -- cgit v1.2.3