Wednesday, May 25, 2011

Threaded IO streams in Java

While working on the project I'm currently involved in, I started analysing the following problem: you have a large data set (e.g. a file) to which you need to apply some expensive transformation (e.g. compression). The key requirement is performance.

Demanding performance requirements can usually be met with multi-threaded algorithms, especially on multi-core hardware. But you may not have a multi-threaded version of the transformation algorithm. You may also be required to support many pluggable algorithms, not all of which are designed for parallel work.

At first it seems easy: divide the source data into chunks and start a thread pool to do the job. This works when you can have multiple output streams, but writing everything to a single file is more problematic. If you synchronize all threads on a single output stream, you are likely to end up with a multi-threaded algorithm that performs no better than a simple single-threaded one. There is no benefit from concurrency when every thread has to wait for the thread that currently holds the lock and is writing. A thread that is waiting for the lock is not doing the transformation, so in effect only one thread works at a time.

To benefit from concurrency, the design must let every thread do its expensive work independently of the others, without waiting or synchronizing.

The concept is as follows. We have N worker threads. Each one progresses at its own pace and was started at a different time from the others. If each thread has its own buffer of a fixed, pre-established size, the threads can write to their buffers concurrently without any collisions. When a buffer fills up, it is flushed to the output stream. Because the lock is taken only for flushing, the other threads can keep working on their own buffers in the meantime instead of waiting for the shared resource. A Java implementation that conforms to the standard Java input/output stream contract is simple (I use Apache Commons IO to save time on the obvious parts):

import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

/**
 * Wraps a target output stream and allows writing to it from multiple threads.
 * The data in the target output stream is interleaved. It can then be read back
 * by ThreadedInputStream.
 *
 * The stream uses internal buffering and synchronization. The default buffer size is 512 KB.
 * This means that each thread gets its own 512 KB buffer, allocated when the thread first
 * writes to the stream. A buffer is flushed into the target output stream whenever it fills up.
 *
 * Maximum number of threads: Byte.MAX_VALUE
 *
 * The lifecycle is as follows:
 * <ol>
 * <li>Create a new stream in the main thread
 * <li>For each worker thread:
 *   <ul>
 *   <li>use the stream in the usual way
 *   <li>at the end of processing, call close(); the stream will not really be closed,
 *       but the thread's buffer will be flushed and removed
 *   </ul>
 * <li>Call close() in the main thread to close the stream permanently (as well as the target stream)
 * </ol>
 */
public class ThreadedOutputStream extends OutputStream {

  protected DataOutputStream target;
  protected int bufSize = 512*1024; // default buffer size = 512 KB
  protected volatile byte threadsCount = 0;
  protected Thread creatorThread;

  /** Internal per-thread data holder and buffer **/
  protected class ThreadStreamHolder {
    byte index = 0;
    int size = 0;
    byte[] buffer = new byte[bufSize];

    public ThreadStreamHolder(byte index) {
      super();
      this.index = index;
    }

    /** Flush data to the target stream; the lock is held only while writing the block **/
    public void flush() throws IOException {
      if (size>0) {
        synchronized (target) {
          target.writeByte(index);       // write thread index
          target.writeInt(size);         // write block size
          target.write(buffer, 0, size); // write data
          size = 0;
        }
      }
    }

    public void write(int b) throws IOException {
      buffer[size++] = (byte) b;
      if (size>=bufSize)
        flush();
    }
  }

  protected ThreadLocal<ThreadStreamHolder> threads = new ThreadLocal<ThreadedOutputStream.ThreadStreamHolder>();

  /**
   * Creates the stream using the default buffer size (512 KB).
   * @param target Target output stream where the data will really be written.
   */
  public ThreadedOutputStream(OutputStream target) {
    super();
    this.target = new DataOutputStream(target);
    creatorThread = Thread.currentThread();
  }

  /**
   * Creates the stream using a custom buffer size.
   * @param target Target output stream where the data will really be written.
   * @param bufSize Buffer size in bytes.
   */
  public ThreadedOutputStream(OutputStream target, int bufSize) {
    this(target);
    this.bufSize = bufSize;
  }

  @Override
  public void write(int b) throws IOException {
    ThreadStreamHolder sh = threads.get();
    if (sh==null) {
      // threadsCount++ is not atomic: assign indices under a lock so no two threads share one
      synchronized (this) {
        if (threadsCount==Byte.MAX_VALUE)
          throw new IOException("Cannot serve for more than Byte.MAX_VALUE threads");
        sh = new ThreadStreamHolder(threadsCount++);
        threads.set(sh);
      }
    }
    sh.write(b);
  }

  @Override
  public void flush() throws IOException {
    super.flush();
    // flush the calling thread's buffer
    ThreadStreamHolder sh = threads.get();
    if (sh!=null)
      sh.flush();
  }

  @Override
  public void close() throws IOException {
    flush();
    threads.remove();
    // in the main (creator) thread, close the target stream as well
    if (Thread.currentThread().equals(creatorThread))
      target.close();
  }

  public static final int TEST_THREADS = 64; // number of threads
  public static final double TEST_DPT_MAX = 1024*1024*10; // maximum data amount per thread
  public static final int TEST_BLOCKSIZE = 1024*512; // buffer (block) size

  public static void main(String[] args) throws IOException {
    File f = new File("threados");
    OutputStream target = new BufferedOutputStream(new FileOutputStream(f, false));
    final ThreadedOutputStream out = new ThreadedOutputStream(target, TEST_BLOCKSIZE);
    List<Thread> workers = new ArrayList<Thread>();

    // write some data by threads
    for (int i=0; i<TEST_THREADS; i++) {
      final int valueToWrite = (i+5)*20;
      Thread t = new Thread(new Runnable() {
        @Override
        public void run() {
          try {
            int jMax = (int) (Math.random()*TEST_DPT_MAX) + 1;
            byte crc = 0;
            for (int j=0; j<jMax; j++) {
              out.write(valueToWrite+j);
              crc+=(valueToWrite+j);
            }
            out.write(crc); // checksum as the last byte of this thread's entry
            System.out.println("Bytes written by a thread: "+(jMax+1));
            out.close(); // flushes and releases only this thread's buffer
          } catch (IOException e) {
            e.printStackTrace();
          }
        }
      });
      workers.add(t);
      t.start();
    }

    // wait for all writer threads to finish before closing the target stream
    for (Thread t: workers) {
      try {
        t.join();
      } catch (InterruptedException e) {
        e.printStackTrace();
      }
    }
    out.close();
  }
}
The static main() method shows example usage of the stream from multiple threads. You can adjust the constants to see how it performs.
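For reference, every flush in the code above writes one block framed as [thread index: 1 byte][block size: 4-byte big-endian int][data]. That is what DataOutputStream.writeByte and writeInt produce. The minimal sketch below encodes a single block the same way and reads its header back. BlockFraming, frame and readHeader are hypothetical names used only for illustration. The framing costs 5 bytes per block, which is negligible with 512 KB blocks.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical helper illustrating the block layout used by ThreadStreamHolder.flush(). */
public class BlockFraming {

  /** Encodes one block: [index: 1 byte][size: 4-byte int][data]. */
  public static byte[] frame(byte index, byte[] data) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    out.writeByte(index);      // thread index
    out.writeInt(data.length); // block size (big-endian)
    out.write(data);           // block data
    out.flush();
    return bytes.toByteArray();
  }

  /** Reads the header of the first block and returns {index, size}. */
  public static int[] readHeader(byte[] framed) throws IOException {
    DataInputStream in = new DataInputStream(new ByteArrayInputStream(framed));
    int index = in.readByte();
    int size = in.readInt();
    return new int[] {index, size};
  }

  public static void main(String[] args) throws IOException {
    byte[] framed = frame((byte) 3, new byte[] {10, 20, 30});
    int[] header = readHeader(framed);
    // 1 index byte + 4 size bytes + 3 data bytes
    System.out.println(framed.length + " bytes, entry " + header[0] + ", size " + header[1]);
  }
}
```

Running it prints `8 bytes, entry 3, size 3`.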

With such a multi-threaded stream, when the cost of the transformation depends on the input data, the buffers fill up at unpredictable rates. Flushes from different threads therefore tend not to coincide, and multi-threading brings real benefits. In my tests it performed very well.

However, the output file produced this way consists of blocks of interleaved stream data in an unpredictable order. How do we restore the data from such a file to apply the reverse transformation? This is why the thread index and the block size are written to the target stream before each block. We can now restore the data sequentially, reading each thread's stream separately. Alternatively, we can restore it concurrently, using a thread pool with the same number of threads as in writing mode. All the required information can be easily retrieved from the interleaved stream:

import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.commons.io.IOUtils;

/**
 * Reads data previously written by ThreadedOutputStream. An instance of this class
 * works in a single thread and retrieves, from the interleaved stream, the data written
 * by a single writer thread. The raw source stream is divided into entries, one for
 * each thread that was writing to the ThreadedOutputStream.
 */
public class ThreadedInputStream extends InputStream {
        
  protected DataInputStream source;
  protected byte entry;
        
  protected int bytesToRead = 0; // bytes remaining in the current data block
        
  protected long pos = 0;
        
  /**
   * Creates the input stream.
   *
   * @param source Newly instantiated, clean source input stream to raw data
   *      produced by ThreadedOutputStream.
   * @param entry Entry number. There can be many entries in the source stream (0, 1, 2... etc.,
   *      depending on the number of writing threads).
   * @throws EOFException If there is no such entry (the caller must then close the source).
   */
  public ThreadedInputStream(InputStream source, byte entry) throws IOException {
    super();

    // a byte can never exceed Byte.MAX_VALUE, so only negative values are invalid
    if (entry<0)
      throw new IOException("Entry number must be between 0 and Byte.MAX_VALUE");

    this.source = new DataInputStream(source);
    this.entry = entry;
    lookupNextBlock();
  }
        
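  /**
   * Advances the source to the next data block of this entry, skipping blocks of
   * other entries. Throws EOFException when the source is exhausted.
   */
  protected void lookupNextBlock() throws IOException {
    while (true) {
      byte currentEntry = source.readByte(); // throws EOFException at end of stream

      if (currentEntry==entry) {
        // found the next data block of this entry
        bytesToRead = source.readInt();
        break;
      } else {
        // found a data block of a different entry: skip it
        int blockSize = source.readInt();
        long toSkip = blockSize;
        while (toSkip>0) {
          long skipped = source.skip(toSkip);
          if (skipped<=0) {
            // skip() may return 0 without being at EOF; read one byte to tell the cases apart
            if (source.read()<0)
              throw new EOFException("Cannot skip full data block");
            skipped = 1;
          }
          toSkip -= skipped;
        }
      }
    }
  }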

  @Override
  public int read() throws IOException {
    if (bytesToRead<=0)
      try {
        lookupNextBlock();
      } catch (EOFException e) {
        return -1;
      }
                
    bytesToRead--;
    return source.read();
  }

  @Override
  public void close() throws IOException {
    source.close();
  }

  // test: reads the file produced by ThreadedOutputStream.main(), one thread per entry
  public static void main(String[] args) throws IOException {
    File f = new File("threados");

    Map<Byte, ByteArrayOutputStream> outmap = new LinkedHashMap<Byte, ByteArrayOutputStream>();
    Map<Byte, Thread> workers = new LinkedHashMap<Byte, Thread>();

    byte entry = 0;
    while (entry>=0) {
      InputStream source = new BufferedInputStream(new FileInputStream(f));
      try {
        final ThreadedInputStream is = new ThreadedInputStream(source, entry);
        final ByteArrayOutputStream out = new ByteArrayOutputStream();
        outmap.put(entry, out); // key by the entry actually being read

        Thread t = new Thread(new Runnable() {
          @Override
          public void run() {
            try {
              IOUtils.copy(is, out);
              is.close();
            } catch (IOException e) {
              e.printStackTrace();
            }
          }
        });
        workers.put(entry, t);
        t.start();
      } catch (EOFException e) {
        source.close(); // no more interleaved streams
        break;
      }
      entry++;
    }

    // wait for all reader threads to finish
    for (Thread t: workers.values()) {
      try {
        t.join();
      } catch (InterruptedException e) {
        e.printStackTrace();
      }
    }

    for (byte b: outmap.keySet()) {
      byte[] ba = outmap.get(b).toByteArray();
      System.out.println(b+" ["+ba.length+"]: "+dumpByteArray(ba));
    }
  }
        
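  /**
   * Returns the first 20 bytes (as unsigned values) followed by the result of checking
   * the last byte against the sum of all preceding bytes.
   */
  public static String dumpByteArray(byte[] b) {
    StringBuilder sb = new StringBuilder();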
    int i = 0;
    byte crc = 0;
    for (byte b1: b) {
      if (i++<20) {
        if (b1<10 && b1>=0)
          sb.append(0);
        sb.append((int) b1 & 0xFF);
        sb.append(',');
      }
                        
      if (i==b.length) {
        if (crc==b1)
          sb.append("crc ok");
        else
          sb.append("crc error");
      } else
          crc+=b1;
    }
    return sb.toString();
  }

}
This time, the main() method reads the file produced by the output stream test, reads each entry in a separate thread, and performs a simple checksum check.
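The checksum used by both test programs is the sum of all bytes of an entry, truncated to a byte (i.e. modulo 256) and appended as the entry's last byte. Despite the variable name `crc`, this is a simple additive checksum rather than a cyclic redundancy check. The sketch below isolates that scheme. EntryChecksum, appendChecksum and isValid are hypothetical names, not part of the classes above.

```java
import java.util.Arrays;

/** Hypothetical helper mirroring the additive checksum used by the test programs. */
public class EntryChecksum {

  /** Returns data followed by one byte holding the sum of all data bytes, mod 256. */
  public static byte[] appendChecksum(byte[] data) {
    byte crc = 0;
    for (byte b : data)
      crc += b; // byte compound assignment wraps around, keeping the sum mod 256
    byte[] result = Arrays.copyOf(data, data.length + 1);
    result[data.length] = crc;
    return result;
  }

  /** Checks that the last byte equals the sum of all preceding bytes, mod 256. */
  public static boolean isValid(byte[] entry) {
    if (entry.length == 0)
      return false;
    byte crc = 0;
    for (int i = 0; i < entry.length - 1; i++)
      crc += entry[i];
    return crc == entry[entry.length - 1];
  }

  public static void main(String[] args) {
    byte[] entry = appendChecksum(new byte[] {(byte) 200, (byte) 100, 5});
    System.out.println(isValid(entry)); // true
    entry[0]++;                         // corrupt one byte
    System.out.println(isValid(entry)); // false
  }
}
```

Such a checksum catches single-byte corruption, but not, for example, two bytes swapped within an entry.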

Depending on the requirements, the ThreadedInputStream class can serve as the source for either a multi-threaded or a sequential algorithm. You can start N threads, each reading one of the N entries. Alternatively, you can iterate through the entries one by one in a single thread. In that case you need N iterations, because each ThreadedInputStream scans the whole source and skips the blocks that belong to other entries.
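If memory allows, a single thread can avoid the N iterations by scanning the interleaved stream once and collecting every entry into its own buffer. This is a variation, not what ThreadedInputStream does. The minimal sketch below assumes the same [index: byte][size: int][data] block layout and assumes that all entries fit in memory. Demultiplexer and split are hypothetical names.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical single-pass reader: splits an interleaved stream into per-entry buffers. */
public class Demultiplexer {

  public static Map<Byte, ByteArrayOutputStream> split(InputStream raw) throws IOException {
    DataInputStream in = new DataInputStream(raw);
    Map<Byte, ByteArrayOutputStream> entries = new TreeMap<Byte, ByteArrayOutputStream>();
    while (true) {
      byte index;
      try {
        index = in.readByte();
      } catch (EOFException e) {
        break; // clean end of stream: no more blocks
      }
      int size = in.readInt();
      byte[] block = new byte[size];
      in.readFully(block); // throws EOFException if the block is truncated
      ByteArrayOutputStream out = entries.get(index);
      if (out == null) {
        out = new ByteArrayOutputStream();
        entries.put(index, out);
      }
      out.write(block, 0, size);
    }
    return entries;
  }

  public static void main(String[] args) throws IOException {
    // build a tiny interleaved stream by hand: entry 0, entry 1, entry 0
    ByteArrayOutputStream raw = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(raw);
    out.writeByte(0); out.writeInt(2); out.write(new byte[] {1, 2});
    out.writeByte(1); out.writeInt(1); out.write(new byte[] {9});
    out.writeByte(0); out.writeInt(1); out.write(new byte[] {3});
    out.flush();

    Map<Byte, ByteArrayOutputStream> entries = split(new ByteArrayInputStream(raw.toByteArray()));
    for (Map.Entry<Byte, ByteArrayOutputStream> e : entries.entrySet())
      System.out.println(e.getKey() + ": " + Arrays.toString(e.getValue().toByteArray()));
  }
}
```

This prints `0: [1, 2, 3]` and `1: [9]`. The trade-off is memory: N file scans become one, at the cost of holding the whole decoded content in RAM.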

I'm presenting this solution as a curiosity, but also as a real, working algorithm that I used in the project. It works well when you want to parallelize a sequential transformation algorithm (in my case, compression) over potentially large input data. Of course, how to chunk the source data is a separate question, and the answer depends on the data itself. For an unstructured file you can, for example, divide it into equal parts and start N reading streams, each with a different starting offset.
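Dividing an unstructured file into N equal parts boils down to computing a start offset and a length for each part. In the minimal sketch below, the first `fileLength % n` parts get one extra byte, so the parts cover the whole file without gaps. EqualChunks and split are hypothetical names. Each worker could then open its own stream and skip to its start offset. Note that ThreadedOutputStream assigns entry indices in the order in which threads first write, not in chunk order. If the original order matters, the mapping from chunk to entry has to be recorded separately.

```java
/** Hypothetical helper: splits [0, fileLength) into n contiguous, nearly equal parts. */
public class EqualChunks {

  /** Returns {start, length} for each of the n parts. */
  public static long[][] split(long fileLength, int n) {
    if (n <= 0)
      throw new IllegalArgumentException("n must be positive");
    long[][] parts = new long[n][2];
    long base = fileLength / n;
    long remainder = fileLength % n;
    long start = 0;
    for (int i = 0; i < n; i++) {
      long length = base + (i < remainder ? 1 : 0); // spread the remainder over the first parts
      parts[i][0] = start;
      parts[i][1] = length;
      start += length;
    }
    return parts;
  }

  public static void main(String[] args) {
    for (long[] part : split(10, 3))
      System.out.println("start=" + part[0] + " length=" + part[1]);
  }
}
```

For a 10-byte file and 3 parts this yields the ranges starting at 0, 4 and 7 with lengths 4, 3 and 3.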
