/*
 * All content copyright (c) Terracotta, Inc., except as may otherwise be noted in a separate copyright notice. All
 * rights reserved.
 */
package org.terracotta.cache.serialization;

import com.tc.object.bytecode.NotClearable;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.OutputStream;

/**
 * {@link SerializationStrategy} built on standard Java serialization, except that each class descriptor written to the
 * stream is replaced by a compact integer code obtained from an {@link ObjectStreamClassSerializer}. When reading, that
 * code is turned back into an {@link ObjectStreamClass} by the same serializer.
 *
 * @param <T> type of the values handled by this strategy
 */
public class DsoSerializationStrategy<T> implements SerializationStrategy<T>, NotClearable {

  /** Flag bit set on the leading byte of a multi-byte {@link #encodeInt} encoding. */
  protected static final byte                 HIGH_BIT = (byte) 0x80;

  /** Maps class names to integer codes, and integer codes back to {@link ObjectStreamClass} instances. */
  protected final ObjectStreamClassSerializer oscSerializer;

  /** Creates a strategy whose {@link ObjectStreamClassSerializer} is constructed with {@code internalLocking = true}. */
  public DsoSerializationStrategy() {
    this(true);
  }

  /**
   * Creates a strategy backed by a new {@link ObjectStreamClassSerializer}.
   *
   * @param internalLocking passed unchanged to the {@link ObjectStreamClassSerializer} constructor
   */
  public DsoSerializationStrategy(boolean internalLocking) {
    oscSerializer = new ObjectStreamClassSerializer(internalLocking);
  }

  /**
   * Deserializes a value produced by {@link #serialize(Object)}. Classes are resolved by the default
   * {@link ObjectInputStream#resolveClass} logic.
   *
   * @param data serialized bytes
   * @return the deserialized value. The cast to {@code T} is unchecked, so a type mismatch surfaces at the caller.
   * @throws IOException if the data is malformed or truncated
   * @throws ClassNotFoundException if a class referenced by the data cannot be resolved
   */
  @SuppressWarnings("unchecked")
  public T deserialize(final byte[] data) throws IOException, ClassNotFoundException {
    return (T) new OIS(new ByteArrayInputStream(data), oscSerializer).readObject();
  }

  /**
   * Deserializes a value produced by {@link #serialize(Object)}, resolving classes through {@code loader}.
   *
   * @param data serialized bytes
   * @param loader loader used for class resolution, or {@code null} for the default {@link ObjectInputStream} behavior
   * @return the deserialized value. The cast to {@code T} is unchecked, so a type mismatch surfaces at the caller.
   * @throws IOException if the data is malformed or truncated
   * @throws ClassNotFoundException if a class referenced by the data cannot be resolved
   */
  @SuppressWarnings("unchecked")
  public T deserialize(final byte[] data, ClassLoader loader) throws IOException, ClassNotFoundException {
    return (T) new OIS(new ByteArrayInputStream(data), oscSerializer, loader).readObject();
  }

  /**
   * Serializes {@code value} using compact class-descriptor codes.
   *
   * @param value the object to serialize
   * @return the serialized bytes
   * @throws IOException if serialization fails, for example because a value is not serializable
   */
  public byte[] serialize(final T value) throws IOException {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    OOS oos = new OOS(baos, oscSerializer);
    oos.writeObject(value);
    // close() flushes the object stream's internal buffer into baos before we copy it out
    oos.close();
    return baos.toByteArray();
  }

  /**
   * Produces a String form of {@code key}.
   *
   * <ul>
   * <li>{@code String} keys are returned unchanged.</li>
   * <li>Any other key is written through {@link #writeStringKey}. Each resulting byte is then mapped one-to-one onto a
   * char in the range U+0000 to U+00FF.</li>
   * </ul>
   *
   * @param key the key to convert
   * @return the String form of the key
   * @throws IOException if the key cannot be serialized
   */
  public String generateStringKeyFor(final Object key) throws IOException {
    if (key instanceof String) { return (String) key; }

    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    ObjectOutputStream oos = new OOS(bos, oscSerializer);

    writeStringKey(key, oos);

    oos.close();

    // ISO-8859-1 maps every byte 0x00-0xFF to the char with the same value. This is the same mapping the deprecated
    // ByteArrayOutputStream.toString(0x00) produced, without relying on that API.
    return new String(bos.toByteArray(), "ISO-8859-1");
  }

  /**
   * Writes the serialized form of a non-String key. Subclasses may override this to customize the key encoding.
   *
   * @param key the key to write
   * @param oos the stream to write to
   * @throws IOException if writing fails
   */
  protected void writeStringKey(final Object key, final ObjectOutputStream oos) throws IOException {
    oos.writeObject(key);
  }

  /** Delegates to {@code ObjectStreamClassSerializer.forceSlowLookup()} on this strategy's serializer. */
  void forceSlowLookup() {
    this.oscSerializer.forceSlowLookup();
  }

  /** Object input stream that reads class descriptors as integer codes written by {@link OOS}. */
  protected static class OIS extends ObjectInputStream {

    /**
     * Primitive types, which {@link Class#forName(String, boolean, ClassLoader)} cannot resolve by name.
     * {@link ObjectInputStream#resolveClass} handles these specially, so this array keeps the explicit-loader path
     * consistent with it.
     */
    private static final Class<?>[]           PRIMITIVE_TYPES = { boolean.class, byte.class, char.class, short.class,
        int.class, long.class, float.class, double.class, void.class };

    private final ObjectStreamClassSerializer oscSerializer;
    private final ClassLoader                 loader;

    public OIS(InputStream in, ObjectStreamClassSerializer oscSerializer) throws IOException {
      this(in, oscSerializer, null);
    }

    public OIS(InputStream in, ObjectStreamClassSerializer oscSerializer, ClassLoader loader) throws IOException {
      super(in);
      this.oscSerializer = oscSerializer;
      this.loader = loader;
    }

    /** Reads an {@link #encodeInt}-encoded code and asks the serializer for the matching descriptor. */
    @Override
    protected ObjectStreamClass readClassDescriptor() throws IOException, ClassNotFoundException {
      int code = decodeInt(this);
      return oscSerializer.getObjectStreamClassFor(code, loader);
    }

    /**
     * Resolves {@code desc} through the supplied loader. If no loader was given, this falls back to
     * {@link ObjectInputStream#resolveClass}. Primitive type names are mapped explicitly because
     * {@code Class.forName} cannot load them.
     */
    @Override
    protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
      if (loader == null) { return super.resolveClass(desc); }
      final String name = desc.getName();
      try {
        return Class.forName(name, false, loader);
      } catch (ClassNotFoundException e) {
        for (Class<?> primitive : PRIMITIVE_TYPES) {
          if (primitive.getName().equals(name)) { return primitive; }
        }
        throw e;
      }
    }
  }

  /**
   * Object output stream that replaces each class descriptor with the integer code the serializer assigns to the
   * class name. Only that code is written; the descriptor's field information is not.
   */
  protected static class OOS extends ObjectOutputStream {

    private final ObjectStreamClassSerializer oscSerializer;

    public OOS(final OutputStream out, final ObjectStreamClassSerializer oscSerializer) throws IOException {
      super(out);
      this.oscSerializer = oscSerializer;
    }

    @Override
    protected void writeClassDescriptor(final ObjectStreamClass desc) throws IOException {
      String name = desc.getName();
      int code = oscSerializer.getMappingFor(name);
      encodeInt(this, code);
    }
  }

  /**
   * Decodes a non-negative int written by {@link #encodeInt}.
   *
   * @param is stream to read from
   * @return the decoded value (at least 0)
   * @throws EOFException if the stream ends before the encoding is complete
   * @throws IOException if the length header or the decoded value is invalid
   */
  protected static final int decodeInt(final InputStream is) throws IOException {
    final int header = readUnsignedByte(is);

    // Single-byte form: the header byte is the value itself (0x00-0x7F).
    if ((header & HIGH_BIT) == 0) { return header; }

    // Multi-byte form: the low 7 bits give the number (1-4) of big-endian bytes that follow.
    final int length = header & ~HIGH_BIT;
    if ((length == 0) || (length > 4)) { throw new IOException("invalid length: " + length); }

    int rv = 0;
    for (int i = 0; i < length; i++) {
      rv = (rv << 8) | readUnsignedByte(is);
    }
    // A 4-byte payload with its top bit set would be negative, which encodeInt never produces.
    if (rv < 0) { throw new IOException("invalid value: " + rv); }
    return rv;
  }

  /**
   * Reads one byte as an unsigned value (0-255). Failing on end of stream here prevents -1 from being misread as
   * 0xFF.
   */
  private static int readUnsignedByte(final InputStream is) throws IOException {
    final int b = is.read();
    if (b < 0) { throw new EOFException("unexpected end of stream while decoding int"); }
    return b;
  }

  /**
   * Encodes a non-negative int in a compact variable-length form.
   *
   * <ul>
   * <li>Values below 0x80 are written as a single byte.</li>
   * <li>Larger values are written as a header byte {@code 0x80 | n} followed by the value's {@code n} (1-4)
   * significant bytes, big-endian.</li>
   * </ul>
   *
   * @param os stream to write to
   * @param value the value to encode
   * @throws IOException if {@code value} is negative or writing fails
   */
  protected static final void encodeInt(final OutputStream os, final int value) throws IOException {
    if (value < 0) {
      throw new IOException("cannot encode negative values");
    } else if (value < 0x80) {
      os.write(value);
    } else if (value <= 0xFF) {
      os.write((0x01 | HIGH_BIT));
      os.write(value);
    } else if (value <= 0xFFFF) {
      os.write(0x02 | HIGH_BIT);
      os.write((value >> 8) & 0xFF);
      os.write(value & 0xFF);
    } else if (value <= 0xFFFFFF) {
      os.write(0x03 | HIGH_BIT);
      os.write((value >> 16) & 0xFF);
      os.write((value >> 8) & 0xFF);
      os.write(value & 0xFF);
    } else {
      os.write(0x04 | HIGH_BIT);
      os.write((value >> 24) & 0xFF);
      os.write((value >> 16) & 0xFF);
      os.write((value >> 8) & 0xFF);
      os.write(value & 0xFF);
    }
  }

}
