diff options
author | pacien | 2018-11-29 12:41:59 +0100 |
---|---|---|
committer | pacien | 2018-11-29 12:42:25 +0100 |
commit | 1524ab71168b7c214a531f796c94962776e9d88a (patch) | |
tree | 9d4e7342ea0a0d8a534e474f40d6eabd0c108b91 | |
parent | d661132528d5c27148a0b55d52709ce97124000a (diff) | |
download | gziplike-1524ab71168b7c214a531f796c94962776e9d88a.tar.gz |
add generic huffman tree builder
-rw-r--r-- | src/huffmantree.nim | 31 | ||||
-rw-r--r-- | src/lzsschain.nim | 12 | ||||
-rw-r--r-- | tests/thuffmantree.nim | 33 | ||||
-rw-r--r-- | tests/tlzsschain.nim | 18 |
4 files changed, 84 insertions, 10 deletions
diff --git a/src/huffmantree.nim b/src/huffmantree.nim index 1711879..adcaec7 100644 --- a/src/huffmantree.nim +++ b/src/huffmantree.nim | |||
@@ -14,6 +14,7 @@ | |||
14 | # You should have received a copy of the GNU Affero General Public License | 14 | # You should have received a copy of the GNU Affero General Public License |
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. | 15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. |
16 | 16 | ||
17 | import tables, heapqueue | ||
17 | import integers, bitreader, bitwriter | 18 | import integers, bitreader, bitwriter |
18 | 19 | ||
19 | const valueLengthFieldBitLength* = 6 # 64 | 20 | const valueLengthFieldBitLength* = 6 # 64 |
@@ -28,19 +29,32 @@ type HuffmanTreeNode*[T: SomeUnsignedInt] = ref object | |||
28 | left, right: HuffmanTreeNode[T] | 29 | left, right: HuffmanTreeNode[T] |
29 | of leaf: | 30 | of leaf: |
30 | value: T | 31 | value: T |
32 | weight: int | ||
31 | 33 | ||
32 | proc huffmanBranch*[T](left, right: HuffmanTreeNode[T]): HuffmanTreeNode[T] = | 34 | proc huffmanBranch*[T](left, right: HuffmanTreeNode[T]): HuffmanTreeNode[T] = |
33 | HuffmanTreeNode[T](kind: branch, left: left, right: right) | 35 | HuffmanTreeNode[T](kind: branch, left: left, right: right, weight: left.weight + right.weight) |
34 | 36 | ||
35 | proc huffmanLeaf*[T](value: T): HuffmanTreeNode[T] = | 37 | proc huffmanLeaf*[T](value: T, weight = 0): HuffmanTreeNode[T] = |
36 | HuffmanTreeNode[T](kind: leaf, value: value) | 38 | HuffmanTreeNode[T](kind: leaf, value: value, weight: weight) |
37 | 39 | ||
38 | proc `==`*[T](a, b: HuffmanTreeNode[T]): bool = | 40 | proc `==`*[T](a, b: HuffmanTreeNode[T]): bool = |
39 | if a.kind != b.kind: return false | 41 | if a.kind != b.kind or a.weight != b.weight: return false |
40 | case a.kind: | 42 | case a.kind: |
41 | of branch: a.left == b.left and a.right == b.right | 43 | of branch: a.left == b.left and a.right == b.right |
42 | of leaf: a.value == b.value | 44 | of leaf: a.value == b.value |
43 | 45 | ||
46 | proc `~=`*[T](a, b: HuffmanTreeNode[T]): bool = | ||
47 | if a.kind != b.kind: return false | ||
48 | case a.kind: | ||
49 | of branch: a.left ~= b.left and a.right ~= b.right | ||
50 | of leaf: a.value == b.value | ||
51 | |||
52 | proc `!~`*[T](a, b: HuffmanTreeNode[T]): bool = | ||
53 | not (a ~= b) | ||
54 | |||
55 | proc `<`*[T](left, right: HuffmanTreeNode[T]): bool = | ||
56 | left.weight < right.weight | ||
57 | |||
44 | proc maxValue*[T](node: HuffmanTreeNode[T]): T = | 58 | proc maxValue*[T](node: HuffmanTreeNode[T]): T = |
45 | case node.kind: | 59 | case node.kind: |
46 | of branch: max(node.left.maxValue(), node.right.maxValue()) | 60 | of branch: max(node.left.maxValue(), node.right.maxValue()) |
@@ -68,3 +82,12 @@ proc serialise*[T](tree: HuffmanTreeNode[T], bitWriter: BitWriter) = | |||
68 | bitWriter.writeBits(valueBitLength, node.value) | 82 | bitWriter.writeBits(valueBitLength, node.value) |
69 | bitWriter.writeBits(valueLengthFieldBitLength, valueBitLength.uint8) | 83 | bitWriter.writeBits(valueLengthFieldBitLength, valueBitLength.uint8) |
70 | writeNode(tree) | 84 | writeNode(tree) |
85 | |||
86 | proc symbolQueue*[T](stats: CountTableRef[T]): HeapQueue[HuffmanTreeNode[T]] = | ||
87 | result = newHeapQueue[HuffmanTreeNode[T]]() | ||
88 | for item, count in stats.pairs: result.push(huffmanLeaf(item, count)) | ||
89 | |||
90 | proc buildHuffmanTree*[T: SomeUnsignedInt](stats: CountTableRef[T]): HuffmanTreeNode[T] = | ||
91 | var symbolQueue = symbolQueue(stats) | ||
92 | while symbolQueue.len > 1: symbolQueue.push(huffmanBranch(symbolQueue.pop(), symbolQueue.pop())) | ||
93 | result = symbolQueue[0] | ||
diff --git a/src/lzsschain.nim b/src/lzsschain.nim index 8203cb8..073aa5e 100644 --- a/src/lzsschain.nim +++ b/src/lzsschain.nim | |||
@@ -15,7 +15,7 @@ | |||
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. | 15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. |
16 | 16 | ||
17 | import lists, tables, sugar | 17 | import lists, tables, sugar |
18 | import polyfill, integers, lzssnode | 18 | import polyfill, integers, lzssnode, huffmantree |
19 | 19 | ||
20 | const maxChainByteLength = 32_000 * wordBitLength | 20 | const maxChainByteLength = 32_000 * wordBitLength |
21 | 21 | ||
@@ -34,3 +34,13 @@ proc decode*(lzssChain: LzssChain): seq[uint8] = | |||
34 | of reference: | 34 | of reference: |
35 | let absolutePos = result.len - node.relativePos | 35 | let absolutePos = result.len - node.relativePos |
36 | result.add(result.toOpenArray(absolutePos, absolutePos + node.length - 1)) | 36 | result.add(result.toOpenArray(absolutePos, absolutePos + node.length - 1)) |
37 | |||
38 | proc stats*(lzssChain: LzssChain): tuple[characters: CountTableRef[uint8], lengths, positions: CountTableRef[int]] = | ||
39 | result = (newCountTable[uint8](), newCountTable[int](), newCountTable[int]()) | ||
40 | for node in lzssChain.items: | ||
41 | case node.kind: | ||
42 | of character: | ||
43 | result.characters.inc(node.character) | ||
44 | of reference: | ||
45 | result.lengths.inc(node.length) | ||
46 | result.positions.inc(node.relativePos) | ||
diff --git a/tests/thuffmantree.nim b/tests/thuffmantree.nim index ec40bdb..705ac17 100644 --- a/tests/thuffmantree.nim +++ b/tests/thuffmantree.nim | |||
@@ -14,24 +14,43 @@ | |||
14 | # You should have received a copy of the GNU Affero General Public License | 14 | # You should have received a copy of the GNU Affero General Public License |
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. | 15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. |
16 | 16 | ||
17 | import unittest, streams | 17 | import unittest, streams, sequtils, tables, heapqueue |
18 | import bitreader, bitwriter, huffmantree | 18 | import bitreader, bitwriter, huffmantree |
19 | 19 | ||
20 | suite "huffmantree": | 20 | suite "huffmantree": |
21 | let stats = newCountTable(concat(repeat(1'u, 3), repeat(2'u, 1), repeat(3'u, 2))) | ||
21 | let tree = huffmanBranch( | 22 | let tree = huffmanBranch( |
22 | huffmanLeaf(1'u), | 23 | huffmanLeaf(1'u), |
23 | huffmanBranch( | 24 | huffmanBranch( |
24 | huffmanLeaf(2'u), | 25 | huffmanLeaf(2'u), |
25 | huffmanLeaf(3'u))) | 26 | huffmanLeaf(3'u))) |
26 | 27 | ||
28 | test "equivalence": | ||
29 | check huffmanLeaf(12'u) ~= huffmanLeaf(12'u) | ||
30 | check huffmanLeaf(12'u) ~= huffmanLeaf(12'u, 2) | ||
31 | check huffmanLeaf(12'u) !~ huffmanLeaf(21'u) | ||
32 | check huffmanLeaf(12'u) !~ huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(12'u)) | ||
33 | check huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) ~= huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) | ||
34 | check huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) !~ huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(1'u)) | ||
35 | check huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 1)) ~= huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 2)) | ||
36 | check huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 1)) !~ huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(12'u, 2)) | ||
37 | |||
27 | test "equality": | 38 | test "equality": |
28 | check huffmanLeaf(12'u) == huffmanLeaf(12'u) | 39 | check huffmanLeaf(12'u) == huffmanLeaf(12'u) |
29 | check huffmanLeaf(12'u) != huffmanLeaf(21'u) | 40 | check huffmanLeaf(12'u) != huffmanLeaf(21'u) |
30 | check huffmanLeaf(12'u) != huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(12'u)) | 41 | check huffmanLeaf(12'u) != huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(12'u)) |
31 | check huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) == huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) | 42 | check huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) == huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) |
32 | check huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) != huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(1'u)) | 43 | check huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(21'u)) != huffmanBranch(huffmanLeaf(12'u), huffmanLeaf(1'u)) |
44 | check huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 1)) == huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 1)) | ||
45 | check huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 1)) != huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 2)) | ||
33 | check tree == tree | 46 | check tree == tree |
34 | 47 | ||
48 | test "weight comparison": | ||
49 | check huffmanLeaf(12'u, 1) < huffmanLeaf(12'u, 2) | ||
50 | check huffmanLeaf(12'u, 2) > huffmanLeaf(12'u, 1) | ||
51 | check huffmanLeaf(12'u, 1) < huffmanLeaf(12'u, 1) == false | ||
52 | check huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 1)) < huffmanBranch(huffmanLeaf(12'u, 1), huffmanLeaf(21'u, 2)) | ||
53 | |||
35 | test "maxValue": | 54 | test "maxValue": |
36 | check tree.maxValue() == 3 | 55 | check tree.maxValue() == 3 |
37 | 56 | ||
@@ -52,7 +71,7 @@ suite "huffmantree": | |||
52 | 71 | ||
53 | stream.setPosition(0) | 72 | stream.setPosition(0) |
54 | let bitReader = stream.bitReader() | 73 | let bitReader = stream.bitReader() |
55 | check huffmantree.deserialise(bitReader, uint) == tree | 74 | check huffmantree.deserialise(bitReader, uint) ~= tree |
56 | 75 | ||
57 | test "serialise": | 76 | test "serialise": |
58 | let stream = newStringStream() | 77 | let stream = newStringStream() |
@@ -72,3 +91,13 @@ suite "huffmantree": | |||
72 | check bitReader.readBits(2, uint8) == 2 | 91 | check bitReader.readBits(2, uint8) == 2 |
73 | check bitReader.readBool() == true # 3 leaf | 92 | check bitReader.readBool() == true # 3 leaf |
74 | check bitReader.readBits(2, uint8) == 3 | 93 | check bitReader.readBits(2, uint8) == 3 |
94 | |||
95 | test "symbolQueue": | ||
96 | var symbolQueue = symbolQueue(stats) | ||
97 | check symbolQueue.len == 3 | ||
98 | check symbolQueue.pop() == huffmanLeaf(2'u, 1) | ||
99 | check symbolQueue.pop() == huffmanLeaf(3'u, 2) | ||
100 | check symbolQueue.pop() == huffmanLeaf(1'u, 3) | ||
101 | |||
102 | test "buildHuffmanTree": | ||
103 | check buildHuffmanTree(stats) ~= tree | ||
diff --git a/tests/tlzsschain.nim b/tests/tlzsschain.nim index 241a0f1..a8c2012 100644 --- a/tests/tlzsschain.nim +++ b/tests/tlzsschain.nim | |||
@@ -14,11 +14,11 @@ | |||
14 | # You should have received a copy of the GNU Affero General Public License | 14 | # You should have received a copy of the GNU Affero General Public License |
15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. | 15 | # along with this program. If not, see <https://www.gnu.org/licenses/>. |
16 | 16 | ||
17 | import unittest | 17 | import unittest, sequtils, tables |
18 | import polyfill, lzssnode, lzsschain | 18 | import polyfill, lzssnode, lzsschain |
19 | 19 | ||
20 | suite "lzsschain": | 20 | suite "lzsschain": |
21 | test "decode": | 21 | proc chain(): LzssChain = |
22 | let chainArray = [ | 22 | let chainArray = [ |
23 | lzssCharacter(0), lzssCharacter(1), lzssCharacter(2), | 23 | lzssCharacter(0), lzssCharacter(1), lzssCharacter(2), |
24 | lzssCharacter(3), lzssCharacter(4), lzssCharacter(5), | 24 | lzssCharacter(3), lzssCharacter(4), lzssCharacter(5), |
@@ -27,4 +27,16 @@ suite "lzsschain": | |||
27 | lzssReference(3, 3), lzssCharacter(5)] | 27 | lzssReference(3, 3), lzssCharacter(5)] |
28 | var chain = lzssChain() | 28 | var chain = lzssChain() |
29 | for node in chainArray: chain.append(node) | 29 | for node in chainArray: chain.append(node) |
30 | check chain.decode() == @[0'u8, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0, 1, 4, 5, 0, 5, 5, 0, 5, 5] | 30 | result = chain |
31 | |||
32 | test "decode": | ||
33 | check chain().decode() == @[0'u8, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0, 1, 4, 5, 0, 5, 5, 0, 5, 5] | ||
34 | |||
35 | test "stats": | ||
36 | let stats = chain().stats() | ||
37 | check stats.characters == newCountTable(concat( | ||
38 | repeat(0'u8, 2), repeat(1'u8, 2), repeat(2'u8, 1), repeat(3'u8, 1), repeat(4'u8, 1), repeat(5'u8, 3))) | ||
39 | check stats.lengths == newCountTable(concat( | ||
40 | repeat(3, 2), repeat(4, 1))) | ||
41 | check stats.positions == newCountTable(concat( | ||
42 | repeat(3, 1), repeat(6, 1), repeat(8, 1))) | ||