/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.soap.internal.rm.store;

import static org.apache.cxf.ws.rm.RMUtils.createReference;
import static org.apache.cxf.ws.rm.RMUtils.getWSRMFactory;
import static org.mule.soap.internal.rm.RMUtils.copyAndClose;
import static org.mule.soap.internal.rm.RMUtils.toByteArray;
import static org.mule.soap.internal.rm.RMUtils.toSequenceAcknowledgement;

import org.mule.soap.api.rm.ReliableMessagingStore;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import org.apache.cxf.io.CachedOutputStream;
import org.apache.cxf.ws.addressing.EndpointReferenceType;
import org.apache.cxf.ws.rm.DestinationSequence;
import org.apache.cxf.ws.rm.ProtocolVariation;
import org.apache.cxf.ws.rm.SourceSequence;
import org.apache.cxf.ws.rm.persistence.RMMessage;
import org.apache.cxf.ws.rm.persistence.RMStore;
import org.apache.cxf.ws.rm.persistence.RMStoreException;
import org.apache.cxf.ws.rm.v200702.Identifier;
import org.apache.cxf.ws.rm.v200702.SequenceAcknowledgement;

/**
 * Implementation of {@link RMStore} backed by a {@link ReliableMessagingStore}.
 * <p>
 * Source and destination sequences share the same underlying store and are keyed by their identifier value. They are
 * persisted as {@link SourceSequenceTransfer} and {@link DestinationSequenceTransfer} snapshots respectively, and the
 * messages of each sequence are kept inside that sequence's snapshot.
 *
 * @since 1.6
 */
public class RMStoreImp implements RMStore {

  private final ReliableMessagingStore<Serializable> reliableMessagingStore;

  /**
   * Creates a new {@link RMStore} on top of the given store.
   *
   * @param reliableMessagingStore the store where sequence snapshots are persisted
   */
  public RMStoreImp(ReliableMessagingStore<Serializable> reliableMessagingStore) {
    this.reliableMessagingStore = reliableMessagingStore;
  }

  /**
   * Persists a snapshot of a newly created source sequence under its identifier value.
   *
   * @param sourceSequence the sequence to persist
   * @throws RMStoreException if the snapshot cannot be built or stored
   */
  @Override
  public void createSourceSequence(SourceSequence sourceSequence) {
    try {
      SourceSequenceTransfer sourceSequenceTransfer = new SourceSequenceTransfer(sourceSequence);
      reliableMessagingStore.store(sourceSequence.getIdentifier().getValue(), sourceSequenceTransfer);
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Persists a snapshot of a newly created destination sequence under its identifier value.
   *
   * @param destinationSequence the sequence to persist
   * @throws RMStoreException if the snapshot cannot be built or stored
   */
  @Override
  public void createDestinationSequence(DestinationSequence destinationSequence) {
    try {
      DestinationSequenceTransfer destinationSequenceTransfer = new DestinationSequenceTransfer(destinationSequence);
      reliableMessagingStore.store(destinationSequence.getIdentifier().getValue(), destinationSequenceTransfer);
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Loads the source sequence stored under the given identifier.
   *
   * @param identifier the sequence identifier
   * @return the rebuilt sequence, or {@code null} if no source sequence is stored under that identifier
   * @throws RMStoreException if the store cannot be read or the sequence cannot be rebuilt
   */
  @Override
  public SourceSequence getSourceSequence(Identifier identifier) {
    try {
      String identifierValue = identifier.getValue();
      Serializable stored = reliableMessagingStore.retrieve(identifierValue);
      // instanceof (instead of a blind cast) also yields null when the key holds a destination sequence.
      return stored instanceof SourceSequenceTransfer
          ? convertToSourceSequence(identifierValue, (SourceSequenceTransfer) stored)
          : null;
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Loads the destination sequence stored under the given identifier.
   *
   * @param identifier the sequence identifier
   * @return the rebuilt sequence, or {@code null} if no destination sequence is stored under that identifier
   * @throws RMStoreException if the store cannot be read or the sequence cannot be rebuilt
   */
  @Override
  public DestinationSequence getDestinationSequence(Identifier identifier) {
    try {
      String identifierValue = identifier.getValue();
      Serializable stored = reliableMessagingStore.retrieve(identifierValue);
      // instanceof (instead of a blind cast) also yields null when the key holds a source sequence.
      return stored instanceof DestinationSequenceTransfer
          ? convertToDestinationSequence(identifierValue, (DestinationSequenceTransfer) stored)
          : null;
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Removes the source sequence (and the messages kept inside it) stored under the given identifier.
   *
   * @param identifier the sequence identifier
   * @throws RMStoreException if the store cannot be updated
   */
  @Override
  public void removeSourceSequence(Identifier identifier) {
    removeSequence(identifier);
  }

  /**
   * Removes the destination sequence (and the messages kept inside it) stored under the given identifier.
   *
   * @param identifier the sequence identifier
   * @throws RMStoreException if the store cannot be updated
   */
  @Override
  public void removeDestinationSequence(Identifier identifier) {
    removeSequence(identifier);
  }

  /**
   * Returns every stored source sequence whose endpoint identifier equals {@code s}.
   *
   * @param s the endpoint identifier to match
   * @return a mutable collection with the matching sequences, possibly empty
   * @throws RMStoreException if the store cannot be read or a sequence cannot be rebuilt
   */
  @Override
  public Collection<SourceSequence> getSourceSequences(String s) {
    try {
      return reliableMessagingStore.retrieveAll().entrySet().stream()
          .filter(entry -> entry.getValue() instanceof SourceSequenceTransfer
              // Objects.equals tolerates snapshots stored without an endpoint identifier.
              && Objects.equals(((SourceSequenceTransfer) entry.getValue()).getEndpointIdentifier(), s))
          .map(entry -> convertToSourceSequence(entry.getKey(), (SourceSequenceTransfer) entry.getValue()))
          .collect(Collectors.toCollection(ArrayList::new));
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Returns every stored destination sequence whose endpoint identifier equals {@code s}.
   *
   * @param s the endpoint identifier to match
   * @return a mutable collection with the matching sequences, possibly empty
   * @throws RMStoreException if the store cannot be read or a sequence cannot be rebuilt
   */
  @Override
  public Collection<DestinationSequence> getDestinationSequences(String s) {
    try {
      return reliableMessagingStore.retrieveAll().entrySet().stream()
          .filter(entry -> entry.getValue() instanceof DestinationSequenceTransfer
              // Objects.equals tolerates snapshots stored without an endpoint identifier.
              && Objects.equals(((DestinationSequenceTransfer) entry.getValue()).getEndpointIdentifier(), s))
          .map(entry -> convertToDestinationSequence(entry.getKey(), (DestinationSequenceTransfer) entry.getValue()))
          .collect(Collectors.toCollection(ArrayList::new));
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Rebuilds the persisted messages of a sequence that travel in the requested direction.
   *
   * @param identifier the sequence identifier
   * @param b {@code true} to return outbound messages, {@code false} for inbound ones
   * @return a mutable collection of messages, empty when nothing is stored under the identifier
   * @throws RMStoreException if the store cannot be read or a message cannot be rebuilt
   */
  @Override
  public Collection<RMMessage> getMessages(Identifier identifier, boolean b) {
    try {
      Collection<RMMessage> rmMessages = new ArrayList<>();
      Serializable stored = reliableMessagingStore.retrieve(identifier.getValue());

      for (MessageTransfer messageTransfer : messageTransfersOf(stored)) {
        if (messageTransfer.isOutbound() != b) {
          continue;
        }

        CachedOutputStream stream = new CachedOutputStream();
        // The content used to be written explicitly with stream.write(...) before this call too, which
        // duplicated the payload. NOTE(review): relies on RMUtils.copyAndClose copying the bytes into
        // the stream - confirm against its implementation.
        copyAndClose(messageTransfer.getContent(), stream);
        stream.flush();

        RMMessage rmMessage = new RMMessage();
        rmMessage.setMessageNumber(messageTransfer.getMessageNumber());
        rmMessage.setTo(messageTransfer.getTo());
        rmMessage.setCreatedTime(messageTransfer.getCreatedTime());
        rmMessage.setContent(stream);
        rmMessage.setContentType(messageTransfer.getContentType());
        rmMessages.add(rmMessage);
      }

      return rmMessages;
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Updates the stored snapshot of a source sequence with its current state and, if present, the outgoing message.
   *
   * @param sourceSequence the sequence whose state is persisted; must have been created beforehand
   * @param rmMessage the outgoing message to append, ignored when it or its content is {@code null}
   * @throws RMStoreException if the sequence is not stored or the store cannot be updated
   */
  @Override
  public void persistOutgoing(SourceSequence sourceSequence, RMMessage rmMessage) {
    try {
      String identifierValue = sourceSequence.getIdentifier().getValue();
      Serializable stored = reliableMessagingStore.retrieve(identifierValue);
      if (!(stored instanceof SourceSequenceTransfer)) {
        // Give a descriptive cause instead of a NullPointerException/ClassCastException.
        throw new IllegalStateException("No source sequence stored for identifier '" + identifierValue + "'");
      }
      SourceSequenceTransfer sourceSequenceTransfer = (SourceSequenceTransfer) stored;

      long currentMessageNumber = sourceSequence.getCurrentMessageNr();
      boolean isLastMessage = sourceSequence.isLastMessage();

      sourceSequenceTransfer.setLastMessage(isLastMessage);
      sourceSequenceTransfer.setCurrentMessageNumber(currentMessageNumber);
      if (rmMessage != null && rmMessage.getContent() != null) {
        sourceSequenceTransfer.addMessageTransfer(new MessageTransfer(rmMessage, true));
      }

      reliableMessagingStore.update(identifierValue, sourceSequenceTransfer);
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Updates the stored snapshot of a destination sequence with its current state and, if present, the incoming
   * message.
   *
   * @param destinationSequence the sequence whose state is persisted; must have been created beforehand
   * @param rmMessage the incoming message to append, ignored when it or its content is {@code null}
   * @throws RMStoreException if the sequence is not stored or the store cannot be updated
   */
  @Override
  public void persistIncoming(DestinationSequence destinationSequence, RMMessage rmMessage) {
    try {
      String identifierValue = destinationSequence.getIdentifier().getValue();
      Serializable stored = reliableMessagingStore.retrieve(identifierValue);
      if (!(stored instanceof DestinationSequenceTransfer)) {
        // Give a descriptive cause instead of a NullPointerException/ClassCastException.
        throw new IllegalStateException("No destination sequence stored for identifier '" + identifierValue + "'");
      }
      DestinationSequenceTransfer destinationSequenceTransfer = (DestinationSequenceTransfer) stored;

      long lastMessageNumber = destinationSequence.getLastMessageNumber();
      boolean terminated = destinationSequence.isTerminated();
      byte[] ack = toByteArray(destinationSequence.getAcknowledgment());

      destinationSequenceTransfer.setLastMessageNumber(lastMessageNumber);
      destinationSequenceTransfer.setTerminate(terminated);
      destinationSequenceTransfer.setAcknowledged(ack);
      if (rmMessage != null && rmMessage.getContent() != null) {
        destinationSequenceTransfer.addMessageTransfer(new MessageTransfer(rmMessage, false));
      }

      reliableMessagingStore.update(identifierValue, destinationSequenceTransfer);
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Removes the given message numbers, travelling in the given direction, from a stored sequence.
   *
   * @param identifier the sequence identifier
   * @param messageNrs message numbers to remove
   * @param b {@code true} to remove outbound messages, {@code false} for inbound ones
   * @throws RMStoreException if the store cannot be read or updated
   */
  @Override
  public void removeMessages(Identifier identifier, Collection<Long> messageNrs, boolean b) {
    try {
      String identifierValue = identifier.getValue();
      Serializable stored = reliableMessagingStore.retrieve(identifierValue);
      if (stored == null) {
        // No sequence stored under this identifier, so there is nothing to remove.
        return;
      }

      // Direction is checked for consistency with getMessages. removeIf mutates the list returned by
      // getMessageTransfers(); this assumes it is the snapshot's backing list (the original code relied on it too).
      messageTransfersOf(stored).removeIf(messageTransfer -> messageTransfer.isOutbound() == b
          && messageNrs.contains(messageTransfer.getMessageNumber()));

      reliableMessagingStore.update(identifierValue, stored);
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Removes whatever is stored under the identifier's value, wrapping failures in {@link RMStoreException}.
   */
  private void removeSequence(Identifier identifier) {
    try {
      reliableMessagingStore.remove(identifier.getValue());
    } catch (Exception e) {
      throw new RMStoreException(e);
    }
  }

  /**
   * Returns the messages held by a stored sequence snapshot, or an empty list when nothing is stored.
   *
   * @throws IllegalStateException if the stored value is not a sequence snapshot
   */
  private static List<MessageTransfer> messageTransfersOf(Serializable stored) {
    if (stored instanceof SourceSequenceTransfer) {
      return ((SourceSequenceTransfer) stored).getMessageTransfers();
    }
    if (stored instanceof DestinationSequenceTransfer) {
      return ((DestinationSequenceTransfer) stored).getMessageTransfers();
    }
    if (stored == null) {
      return new ArrayList<>();
    }
    throw new IllegalStateException("Unexpected value in reliable messaging store: " + stored.getClass().getName());
  }

  /**
   * Rebuilds a CXF {@link SourceSequence} from its stored snapshot.
   *
   * @param identifierValue the key the snapshot is stored under, used as the sequence identifier
   * @param sourceSequenceTransfer the stored snapshot
   */
  private static SourceSequence convertToSourceSequence(String identifierValue,
                                                        SourceSequenceTransfer sourceSequenceTransfer) {
    Identifier identifier = new Identifier();
    identifier.setValue(identifierValue);

    // An expiry of 0 is stored for "no expiry".
    Date expiry = 0 == sourceSequenceTransfer.getExpiry() ? null : new Date(sourceSequenceTransfer.getExpiry());

    String endpointIdentifierValue = sourceSequenceTransfer.getOfferingIdValue();
    Identifier endpointIdentifier = null;
    if (endpointIdentifierValue != null) {
      endpointIdentifier = getWSRMFactory().createIdentifier();
      endpointIdentifier.setValue(endpointIdentifierValue);
    }

    long currentMessageNumber = sourceSequenceTransfer.getCurrentMessageNumber();
    boolean lastMessage = sourceSequenceTransfer.isLastMessage();

    // Decode from the stored protocol version, as convertToDestinationSequence does. This previously read the
    // endpoint identifier, which is the value getSourceSequences filters on, not an encoded protocol version.
    ProtocolVariation protocolVariation = decodeProtocolVersion(sourceSequenceTransfer.getProtocolVersion());

    return new SourceSequence(identifier, expiry, endpointIdentifier, currentMessageNumber, lastMessage, protocolVariation);
  }

  /**
   * Rebuilds a CXF {@link DestinationSequence} from its stored snapshot.
   *
   * @param identifierValue the key the snapshot is stored under, used as the sequence identifier
   * @param destinationSequenceTransfer the stored snapshot
   */
  private static DestinationSequence convertToDestinationSequence(String identifierValue,
                                                                  DestinationSequenceTransfer destinationSequenceTransfer) {
    Identifier identifier = new Identifier();
    identifier.setValue(identifierValue);

    EndpointReferenceType acksTo = createReference(destinationSequenceTransfer.getAddressValue());
    long lastMessageNumber = destinationSequenceTransfer.getLastMessageNumber();
    boolean isTerminate = destinationSequenceTransfer.isTerminate();

    SequenceAcknowledgement ack = null;
    byte[] bytes = destinationSequenceTransfer.getAcknowledged();
    if (null != bytes) {
      ack = toSequenceAcknowledgement(bytes);
    }

    ProtocolVariation protocolVariation = decodeProtocolVersion(destinationSequenceTransfer.getProtocolVersion());

    return new DestinationSequence(identifier, acksTo, lastMessageNumber, isTerminate, ack, protocolVariation);
  }

  /**
   * Decodes a stored protocol version string made of two parts separated by the first space; both parts are passed
   * to {@link ProtocolVariation#findVariant}, whose result is returned as-is (NOTE(review): presumably {@code null}
   * for an unknown combination - confirm).
   *
   * @return the decoded variation, or {@link ProtocolVariation#RM10WSA200408} when {@code pv} is {@code null} or has
   *         no space after its first character
   */
  private static ProtocolVariation decodeProtocolVersion(String pv) {
    if (null != pv) {
      int d = pv.indexOf(' ');
      if (d > 0) {
        return ProtocolVariation.findVariant(pv.substring(0, d), pv.substring(d + 1));
      }
    }
    return ProtocolVariation.RM10WSA200408;
  }

}
