EeBlog(テクニカルブログ)

ハフマン符号化

ランレングス符号化の次はハフマン符号化を実装してみます。

ハフマン符号化は連続する固定長のデータを可変長のデータに変換することで圧縮を行うアルゴリズムです。固定長のデータから可変長のデータへの変換に用いる変換規則は、データの出現率から最適なものを生成します。

出現率の高いデータを短い可変長のデータに、出現率の低いデータを長い可変長データに変換するので、データに偏りがあるほど圧縮率は高まります。

データ全体を調べて変換規則を生成してから圧縮を行う方法を静的ハフマン符号化といいます。最適な変換規則を生成できるので圧縮の効率がよいのですが、ストリームで処理するのに向いていません。

これに対して、データを調べて変換規則を構築しつつ圧縮も同時に行っていく手法を動的ハフマン符号化といいます。これは静的ハフマン符号化に比べて効率は落ちますが、ストリームでの処理が容易です。

以下のサンプルコードでは静的ハフマン符号化を実装しています。ここでは符号化したものをその場で復号化しており、変換規則を流用していますが、ファイルなどに保存する際には変換規則も保存する必要があります。変換規則を保存するとオーバーヘッドが発生するので、ある程度の大木さのデータでないと圧縮の恩恵は得られません。

import java.io.ByteArrayOutputStream;
import java.util.PriorityQueue;

public class HuffmanSample {

    private static abstract class Tree implements Comparable {

	public abstract int getWeight();

	public abstract void fillCode(String prefix, String[] table);

	@Override
	public int compareTo(Tree t) {
	    return getWeight() - t.getWeight();
	}
    }

    private static class Octet extends Tree {

	private int n;
	private int weight;

	public Octet(int n) {
	    this(n, 0);
	}

	public Octet(int n, int weight) {
	    this.n = n;
	    this.weight = weight;
	}

	public void setWeight(int weight) {
	    this.weight = weight;
	}

	@Override
	public int getWeight() {
	    return this.weight;
	}

	@Override
	public void fillCode(String prefix, String[] table) {
	    table[n] = prefix;
	}

	@Override
	public boolean equals(Object obj) {
	    if (!(obj instanceof Octet)) {
		return false;
	    }
	    Octet oct = (Octet) obj;
	    return this.n == oct.n && getWeight() == oct.getWeight();
	}

	@Override
	public String toString() {
	    return n + "#" + getWeight();
	}
    }

    private static class Pair extends Tree {

	private Tree lhs;
	private Tree rhs;

	public Pair(Tree lhs, Tree rhs) {
	    this.lhs = lhs;
	    this.rhs = rhs;
	}

	@Override
	public int getWeight() {
	    return lhs.getWeight() + rhs.getWeight();
	}

	@Override
	public void fillCode(String prefix, String[] table) {
	    lhs.fillCode(prefix + "0", table);
	    rhs.fillCode(prefix + "1", table);
	}

	@Override
	public boolean equals(Object obj) {
	    if (!(obj instanceof Pair)) {
		return false;
	    }
	    Pair pair = (Pair) obj;
	    return lhs.equals(pair.lhs) && rhs.equals(pair.rhs);
	}

	@Override
	public String toString() {
	    return "(" + lhs + " . " + rhs + ")";
	}
    }

    private static int[] makeStats(byte[] data) {
	int[] result = new int[256];
	for (int i = 0; i < data.length; i++) {
	    int octet = (data[i] + 256) % 256;
	    result[octet]++;
	}
	return result;
    }

    public static Tree makeTree(int[] stats) {
	PriorityQueue sortedTrees = new PriorityQueue();
	for (int i = 0; i < stats.length; i++) {
	    Octet oct = new Octet(i);
	    oct.setWeight(stats[i]);
	    sortedTrees.offer(oct);
	}

	while (sortedTrees.size() > 1) {
	    Tree rhs = sortedTrees.poll();
	    Tree lhs = sortedTrees.poll();
	    sortedTrees.offer(new Pair(lhs, rhs));
	}
	return sortedTrees.poll();
    }

    private static String[] makeCodeTable(Tree tree) {
	String[] result = new String[256];
	tree.fillCode("", result);
	return result;
    }

    public static String encode(byte[] data, String[] codeTable) {
	String result = "";
	for (byte datum : data) {
	    result += codeTable[datum];
	}
	return result;
    }

    public static byte[] decode(String data, String[] codeTable) {
	ByteArrayOutputStream baos = new ByteArrayOutputStream();
	OUTER: while (data.length() > 0) {
	    for (int i = 0; i < codeTable.length; i++) {
		if (data.startsWith(codeTable[i])) {
		    baos.write(i);
		    data = data.substring(codeTable[i].length());
		    continue OUTER;
		}
	    }
	    data = data.substring(1);
	}
	return baos.toByteArray();
    }

    public static void main(String[] args) {
	byte[] data = "abbcccddddeeeeeffffff".getBytes();
	String[] codeTable = makeCodeTable(makeTree(makeStats(data)));
	String bits = encode(data, codeTable);

	System.out.println("符号化前のビット数: " + data.length * 8);
	System.out.println("符号化前の文字列: " + new String(data));
	System.out.println();
	System.out.println("符号化後のビット数: " + bits.length());
	System.out.println("復号化した文字列: " + new String(decode(bits, codeTable)));
    }
}