package com.devexperts.mdd.auth.util;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.DateTimeException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalAmount;
import java.util.Base64;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import static java.util.Objects.requireNonNull;

/**
 * Utility class to generate and parse self-signed tokens. Token format:
 * <pre>
 *   token := encoded-payload "." signature
 *   encoded-payload := BASE64(UTF8(payload))
 *   signature := BASE64(HMAC-SHA256(encoded-payload, UTF8(secret)))
 *   payload := issuer "," subject "," not-before-time "," expiration-time "," issued-at-time "," message
 * </pre>
 * where:
 * <ul>
 *     <li>{@code secret} - key used for signature validation</li>
 *     <li>{@code issuer} - identifies principal that issued this token (must not contain commas)</li>
 *     <li>{@code subject} - identifies subject of this token (must not contain commas)</li>
 *     <li>{@code not-before-time} - optional time after which token will start to be valid</li>
 *     <li>{@code expiration-time} - mandatory expiration time of the token, after which token will not be valid</li>
 *     <li>{@code issued-at-time} - optional time at which this token was issued</li>
 *     <li>{@code message} - any string (may contain commas, it is always the last field)</li>
 * </ul>
 * All time variables are specified in seconds since the epoch in UTC; an absent optional time is written
 * as an empty field. {@code BASE64} is the URL-safe Base64 alphabet, written without padding.
 *
 * <p>Parsing a token with {@link #valueOf(String)} does <b>not</b> verify it; call
 * {@link #verifySignature(String)} and {@link #verifyTime(Instant)} before trusting its content.
 */
public class SignedToken {

    /** MAC algorithm used for signing is "HMAC-SHA256". */
    public static final String MAC_ALGORITHM = "HmacSHA256";

    /** Charset used for string encoding is "UTF-8". */
    public static final String MAC_CHARSET = "UTF-8";

    static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding();
    static final Base64.Decoder DECODER = Base64.getUrlDecoder();

    private final String issuer;
    private final String subject;
    private final Instant expiration;
    private final Instant notBefore;
    private final Instant issuedAt;
    private final String message;

    /** Original token string this instance was parsed from, or {@code null} if it was built locally. */
    private final transient String rawToken;

    /**
     * Creates a token. All times are truncated to seconds.
     *
     * @param issuer token issuer, must not be {@code null} or contain commas
     * @param subject token subject, must not be {@code null} or contain commas
     * @param message arbitrary message, may be {@code null}
     * @param notBefore time the token becomes valid at, may be {@code null}
     * @param expiration time the token expires at, must not be {@code null}
     * @param issuedAt time the token was issued at, may be {@code null}
     * @param rawToken token string this instance was parsed from, or {@code null} for a locally built token
     * @throws NullPointerException if issuer, subject or expiration is {@code null}
     * @throws IllegalArgumentException if issuer or subject contains a comma,
     *     or not-before time is after expiration time
     */
    protected SignedToken(String issuer, String subject, String message,
        Instant notBefore, Instant expiration, Instant issuedAt, String rawToken)
    {
        this.issuer = validate(issuer, "issuer");
        this.subject = validate(subject, "subject");
        this.message = message;
        this.notBefore = validateTime(notBefore);
        this.expiration = validateTime(requireNonNull(expiration, "expiration"));
        this.issuedAt = validateTime(issuedAt);
        this.rawToken = rawToken;

        // Compare the truncated values: the token only carries whole seconds, so two instants
        // within the same second must be treated as equal.
        if (this.notBefore != null && this.notBefore.isAfter(this.expiration))
            throw new IllegalArgumentException("Not-before time must not be after expiration time");
    }

    /**
     * Parses the token string. The signature and validity period are NOT checked here,
     * use {@link #verifySignature(String)} and {@link #verifyTime(Instant)} for that.
     *
     * @param token token string in the format described in the class documentation
     * @return parsed token that remembers the original string for signature verification
     * @throws NullPointerException if token is {@code null}
     * @throws IllegalArgumentException if token is malformed: no separator, invalid Base64,
     *     wrong number of fields, unparsable or out-of-range times, or missing expiration time
     */
    public static SignedToken valueOf(String token) {
        requireNonNull(token, "token");

        int separatorIndex = token.indexOf('.');
        if (separatorIndex <= 0)
            throw new IllegalArgumentException("Illegal token: " + token);

        byte[] payload = DECODER.decode(token.substring(0, separatorIndex));
        // The signature bytes are not needed here; decoding only rejects a malformed signature part early.
        DECODER.decode(token.substring(separatorIndex + 1));

        String[] values = new String(payload, StandardCharsets.UTF_8).split(",", 6);
        if (values.length < 6)
            throw new IllegalArgumentException("Illegal token: " + token);

        Instant nbfTime = parseTime(values[2]);
        Instant expTime = parseTime(values[3]);
        Instant iatTime = parseTime(values[4]);
        // Expiration is mandatory; report its absence as a malformed token rather than as a NullPointerException.
        if (expTime == null)
            throw new IllegalArgumentException("Illegal token: " + token);

        return new SignedToken(values[0], values[1], values[5], nbfTime, expTime, iatTime, token);
    }

    /** Creates an empty builder. */
    public static Builder newBuilder() {
        return new Builder();
    }

    /** Creates a builder with issuer, subject and absolute expiration time already set. */
    public static Builder newBuilder(String issuer, String subject, Instant expiration) {
        return new Builder().setIssuer(issuer).setSubject(subject).setExpiration(expiration);
    }

    /** Creates a builder with issuer and subject already set, and expiration set to now plus the given amount. */
    public static Builder newBuilder(String issuer, String subject, TemporalAmount amount) {
        return new Builder().setIssuer(issuer).setSubject(subject).setExpirationFromNow(amount);
    }

    public String getIssuer() {
        return issuer;
    }

    public String getSubject() {
        return subject;
    }

    /** Returns the message, which may be {@code null} for a locally built token. */
    public String getMessage() {
        return message;
    }

    /** Returns not-before time truncated to seconds, or {@code null} if it is not set. */
    public Instant getNotBefore() {
        return notBefore;
    }

    /** Returns issued-at time truncated to seconds, or {@code null} if it is not set. */
    public Instant getIssuedAt() {
        return issuedAt;
    }

    /** Returns expiration time truncated to seconds. */
    public Instant getExpiration() {
        return expiration;
    }

    /**
     * Checks whether this token was issued strictly before the other one.
     *
     * @param token token to compare with
     * @return {@code true} only if both tokens have issued-at time and this one is earlier
     * @throws NullPointerException if token is {@code null}
     */
    public boolean isIssuedBefore(SignedToken token) {
        Instant issued = getIssuedAt();
        Instant otherIssued = requireNonNull(token, "token").getIssuedAt();
        return (issued != null && otherIssued != null && issued.isBefore(otherIssued));
    }

    /**
     * Signs token with the specified secret and returns formatted token string.
     *
     * @param secret secret key the MAC is computed with
     * @return token string in the format described in the class documentation
     * @throws NullPointerException if secret is {@code null}
     */
    public String signToken(String secret) {
        requireNonNull(secret, "secret");
        byte[] payload = ENCODER.encode(createPayload().getBytes(StandardCharsets.UTF_8));
        byte[] signature = computeMac(payload, secret.getBytes(StandardCharsets.UTF_8));
        // Base64 output is pure ASCII; decode it explicitly instead of relying on the platform charset.
        return new String(payload, StandardCharsets.US_ASCII) + "." + ENCODER.encodeToString(signature);
    }

    /**
     * @deprecated Please use {@link #verifySignature} and {@link #verifyTime} instead
     */
    @Deprecated
    public boolean verifyToken(String secret) {
        return verifyToken(secret, Instant.now());
    }

    /**
     * @deprecated Please use {@link #verifySignature} and {@link #verifyTime} instead
     */
    @Deprecated
    public boolean verifyToken(String secret, Instant now) {
        requireNonNull(now, "now");
        return verifyTime(now) && verifySignature(secret);
    }

    /**
     * Check if token is signed with specified secret.
     * The token is re-signed with the secret and the result is compared with the original token string
     * in constant time. A token that was built locally (not parsed from a string) has nothing
     * to compare with and is always reported as valid.
     *
     * @param secret the secret the token is expected to be signed with
     * @return {@code true} if the original token string matches the one produced with this secret
     */
    public boolean verifySignature(String secret) {
        if (rawToken == null)
            return true;
        byte[] expected = signToken(secret).getBytes(StandardCharsets.UTF_8);
        byte[] actual = rawToken.getBytes(StandardCharsets.UTF_8);
        // MessageDigest.isEqual does not short-circuit on the first mismatch, unlike String.equals.
        return MessageDigest.isEqual(expected, actual);
    }

    /**
     * Check if token is active at specified time. Both bounds are inclusive
     * and the time is truncated to seconds before comparison.
     *
     * @param time time to check if token is active at
     * @return {@code true} if time is not before not-before time (when set) and not after expiration time
     * @throws NullPointerException if time is {@code null}
     */
    public boolean verifyTime(Instant time) {
        Instant truncatedTime = requireNonNull(time, "time").truncatedTo(ChronoUnit.SECONDS);
        if (notBefore != null && truncatedTime.isBefore(notBefore))
            return false;
        return !truncatedTime.isAfter(expiration);
    }

    @Override
    public String toString() {
        return "SignedToken{iss=" + issuer + ", sub=" + subject + ", exp=" + expiration
            + ", nbf=" + notBefore + ", iat=" + issuedAt + ", msg=" + message + "}";
    }

    /**
     * Mutable builder of {@link SignedToken}. All times passed to the builder are truncated to seconds.
     */
    public static class Builder {
        private String issuer;
        private String subject;
        private Instant expiration;
        private Instant notBefore;
        private Instant issued;
        private String message;

        protected Builder() {
        }

        /**
         * Creates a token from the current state of this builder.
         *
         * @throws NullPointerException if issuer, subject or expiration is not set
         * @throws IllegalArgumentException if issuer or subject contains a comma,
         *     or not-before time is after expiration time
         */
        public SignedToken toToken() {
            return new SignedToken(issuer, subject, message, notBefore, expiration, issued, null);
        }

        public String getIssuer() {
            return issuer;
        }

        public Builder setIssuer(String issuer) {
            this.issuer = issuer;
            return this;
        }

        public String getSubject() {
            return subject;
        }

        public Builder setSubject(String subject) {
            this.subject = subject;
            return this;
        }

        public Instant getExpiration() {
            return expiration;
        }

        public Builder setExpiration(Instant expiration) {
            this.expiration = validateTime(expiration);
            return this;
        }

        /** Sets expiration to the current time plus the specified amount. */
        public Builder setExpirationFromNow(TemporalAmount amount) {
            this.expiration = validateTime(Instant.now().plus(amount));
            return this;
        }

        public Instant getNotBefore() {
            return notBefore;
        }

        public Builder setNotBefore(Instant notBefore) {
            this.notBefore = validateTime(notBefore);
            return this;
        }

        public Instant getIssued() {
            return issued;
        }

        public Builder setIssued(Instant issued) {
            this.issued = validateTime(issued);
            return this;
        }

        /** Sets issued-at time to the current time, truncated to seconds like every other time. */
        public Builder setIssuedNow() {
            this.issued = validateTime(Instant.now());
            return this;
        }

        public String getMessage() {
            return message;
        }

        public Builder setMessage(String message) {
            this.message = message;
            return this;
        }

        /** Sets the message to the specified user name. */
        public Builder setUser(String user) {
            setMessage(user);
            return this;
        }

        /**
         * Sets the message to {@code user} followed by a comma and the feeds joined with semicolons;
         * when the joined feed string is empty the message is just {@code user}.
         */
        public Builder setUser(String user, Set<String> feeds) {
            String feedString = String.join(";", feeds);
            this.message = user + (feedString.isEmpty() ? "" : "," + feedString);
            return this;
        }
    }

    // Utility methods

    /** Builds the comma-separated payload; absent optional values are written as empty fields. */
    private String createPayload() {
        StringBuilder buff = new StringBuilder();

        buff.append(issuer)
            .append(',')
            .append(subject)
            .append(',');
        if (notBefore != null)
            buff.append(notBefore.getEpochSecond());
        buff.append(',')
            .append(expiration.getEpochSecond())
            .append(',');
        if (issuedAt != null)
            buff.append(issuedAt.getEpochSecond());
        buff.append(',')
            .append(message != null ? message : "");

        return buff.toString();
    }

    /** Checks that the value is not {@code null} and has no commas, which are the payload field separator. */
    private static String validate(String s, String name) {
        requireNonNull(s, name);
        if (s.indexOf(',') >= 0)
            throw new IllegalArgumentException(name + " must not contain commas");
        return s;
    }

    /** Truncates the instant to seconds; {@code null} is passed through. */
    private static Instant validateTime(Instant instant) {
        return (instant != null) ? instant.truncatedTo(ChronoUnit.SECONDS) : null;
    }

    /**
     * Parses epoch seconds; an empty string means "not set" and yields {@code null}.
     *
     * @throws IllegalArgumentException if the value is not a number or is outside of the {@link Instant} range
     */
    private static Instant parseTime(String s) {
        if (s.isEmpty())
            return null;
        try {
            return Instant.ofEpochSecond(Long.parseLong(s));
        } catch (DateTimeException e) {
            // Out-of-range seconds: report like any other malformed field (NumberFormatException is already one).
            throw new IllegalArgumentException("Illegal time value: " + s, e);
        }
    }

    /** Computes the {@link #MAC_ALGORITHM} MAC of the payload with the specified secret key bytes. */
    private static byte[] computeMac(byte[] payload, byte[] secret) {
        try {
            Mac mac = Mac.getInstance(MAC_ALGORITHM);
            mac.init(new SecretKeySpec(secret, MAC_ALGORITHM));
            return mac.doFinal(payload);
        } catch (InvalidKeyException | NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }
}
