/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.sdjwt;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.jboss.logging.Logger;
import org.keycloak.common.VerificationException;
import org.keycloak.crypto.SignatureVerifierContext;
import org.keycloak.sdjwt.IssuerSignedJWT;
import org.keycloak.sdjwt.IssuerSignedJwtVerificationOpts;
import org.keycloak.sdjwt.JwkParsingUtils;
import org.keycloak.sdjwt.SdJwtUtils;
import org.keycloak.sdjwt.consumer.PresentationRequirements;
import org.keycloak.sdjwt.vp.KeyBindingJWT;
import org.keycloak.sdjwt.vp.KeyBindingJwtVerificationOpts;

/**
 * Holds the already-parsed parts of one SD-JWT and runs the verification steps for it, either as
 * received at issuance ({@link #verifyIssuance}) or as a presentation ({@link #verifyPresentation}).
 *
 * <p>Disclosures are held as a map from digest to the raw disclosure string. The raw presentation
 * string and the Key Binding JWT can only be supplied through the four-argument constructor. They
 * are only used when a Key Binding JWT has to be validated.
 *
 * <p>Disclosure processing works on a copy of the Issuer-signed JWT payload obtained via
 * {@code SdJwtUtils.deepClone}, and the fully disclosed copy is what the time-claim and
 * presentation-requirement checks inspect.
 */
public class SdJwtVerificationContext {
    private static final Logger logger = Logger.getLogger(SdJwtVerificationContext.class.getName());

    /** Claim names that a field (object property) disclosure is not allowed to introduce. */
    private static final List<String> DISCLOSURE_CLAIM_NAME_DENYLIST = Arrays.asList("_sd", "...");

    /** Raw SD-JWT presentation string; only needed for the {@code sd_hash} check. May be null. */
    private final String sdJwtVpString;
    private final IssuerSignedJWT issuerSignedJwt;
    /** Raw disclosure strings keyed by their digest. */
    private final Map<String, String> disclosures;
    /** Key Binding JWT of a presentation; null when none was supplied. */
    private final KeyBindingJWT keyBindingJwt;

    /**
     * Creates a context for a presentation.
     *
     * @param sdJwtVpString the raw presentation string, used to recompute {@code sd_hash}
     * @param issuerSignedJwt the parsed Issuer-signed JWT
     * @param disclosures raw disclosure strings keyed by their digest
     * @param keyBindingJwt the parsed Key Binding JWT, or null if none was presented
     */
    public SdJwtVerificationContext(String sdJwtVpString, IssuerSignedJWT issuerSignedJwt, Map<String, String> disclosures, KeyBindingJWT keyBindingJwt) {
        this.sdJwtVpString = sdJwtVpString;
        this.issuerSignedJwt = issuerSignedJwt;
        this.disclosures = disclosures;
        this.keyBindingJwt = keyBindingJwt;
    }

    /**
     * Creates a context without a presentation string or Key Binding JWT.
     *
     * @param issuerSignedJwt the parsed Issuer-signed JWT
     * @param disclosures raw disclosure strings keyed by their digest
     */
    public SdJwtVerificationContext(IssuerSignedJWT issuerSignedJwt, Map<String, String> disclosures) {
        this(null, issuerSignedJwt, disclosures, null);
    }

    /**
     * Creates a context from raw disclosure strings. Each disclosure is hashed with the
     * Issuer-signed JWT's {@code _sd_alg} to build the digest map.
     *
     * <p>Note: {@link Collectors#toMap} throws {@link IllegalStateException} if the same disclosure
     * string (and therefore the same digest) appears twice in {@code disclosureStrings}.
     *
     * @param issuerSignedJwt the parsed Issuer-signed JWT
     * @param disclosureStrings the raw disclosure strings
     */
    public SdJwtVerificationContext(IssuerSignedJWT issuerSignedJwt, List<String> disclosureStrings) {
        this(null, issuerSignedJwt, computeDigestDisclosureMap(issuerSignedJwt, disclosureStrings), null);
    }

    /**
     * Maps each disclosure string to its digest, computed with the Issuer-signed JWT's hash algorithm.
     * Disclosures are base64url text, so encoding them as UTF-8 is explicit rather than relying on
     * the platform default charset.
     */
    private static Map<String, String> computeDigestDisclosureMap(IssuerSignedJWT issuerSignedJwt, List<String> disclosureStrings) {
        return disclosureStrings.stream().collect(Collectors.toMap(
                disclosureString -> SdJwtUtils.hashAndBase64EncodeNoPad(
                        disclosureString.getBytes(StandardCharsets.UTF_8), issuerSignedJwt.getSdHashAlg()),
                disclosureString -> disclosureString));
    }

    /**
     * Verifies the SD-JWT as received at issuance.
     *
     * <p>Checks, in order:
     * <ul>
     *   <li>the {@code _sd_alg} and the signature of the Issuer-signed JWT;</li>
     *   <li>that every disclosure is referenced by a digest, with no digest or salt reused;</li>
     *   <li>the configured time claims of the disclosed payload;</li>
     *   <li>the presentation requirements, if any.</li>
     * </ul>
     *
     * @param issuerVerifyingKeys candidate verifiers; verification succeeds if any one of them accepts the signature
     * @param issuerSignedJwtVerificationOpts selects which of {@code iat}/{@code exp}/{@code nbf} are checked
     * @param presentationRequirements requirements the disclosed payload must satisfy; skipped when null
     * @throws VerificationException if any check fails
     */
    public void verifyIssuance(List<SignatureVerifierContext> issuerVerifyingKeys, IssuerSignedJwtVerificationOpts issuerSignedJwtVerificationOpts, PresentationRequirements presentationRequirements) throws VerificationException {
        validateIssuerSignedJwt(issuerVerifyingKeys);
        JsonNode disclosedPayload = validateDisclosuresDigests();
        validateIssuerSignedJwtTimeClaims(disclosedPayload, issuerSignedJwtVerificationOpts);
        if (presentationRequirements != null) {
            presentationRequirements.checkIfSatisfiedBy(disclosedPayload);
        }
    }

    /**
     * Verifies the SD-JWT as a presentation.
     *
     * <p>Runs all {@link #verifyIssuance} checks. The Key Binding JWT is validated only when
     * {@code keyBindingJwtVerificationOpts.isKeyBindingRequired()} is true. A Key Binding JWT that
     * is present although not required is not inspected.
     *
     * @throws VerificationException if key binding is required but no Key Binding JWT is present,
     *         or if any issuance or key-binding check fails
     */
    public void verifyPresentation(List<SignatureVerifierContext> issuerVerifyingKeys, IssuerSignedJwtVerificationOpts issuerSignedJwtVerificationOpts, KeyBindingJwtVerificationOpts keyBindingJwtVerificationOpts, PresentationRequirements presentationRequirements) throws VerificationException {
        boolean keyBindingRequired = keyBindingJwtVerificationOpts.isKeyBindingRequired();
        if (keyBindingRequired && keyBindingJwt == null) {
            throw new VerificationException("Missing Key Binding JWT");
        }
        verifyIssuance(issuerVerifyingKeys, issuerSignedJwtVerificationOpts, presentationRequirements);
        if (keyBindingRequired) {
            validateKeyBindingJwt(keyBindingJwtVerificationOpts);
        }
    }

    /**
     * Checks the Issuer-signed JWT's hash algorithm, then tries each verifier in turn until one
     * accepts the signature. Failures are logged at debug level.
     *
     * @throws VerificationException if the hash algorithm check fails or no verifier accepts the signature
     *         (including when {@code verifiers} is empty)
     */
    private void validateIssuerSignedJwt(List<SignatureVerifierContext> verifiers) throws VerificationException {
        issuerSignedJwt.verifySdHashAlgorithm();

        Iterator<SignatureVerifierContext> iterator = verifiers.iterator();
        while (iterator.hasNext()) {
            SignatureVerifierContext verifier = iterator.next();
            try {
                issuerSignedJwt.verifySignature(verifier);
                return;
            } catch (VerificationException e) {
                logger.debug("Issuer-signed JWT's signature verification failed against one potential verifying key", e);
                if (iterator.hasNext()) {
                    logger.debug("Retrying Issuer-signed JWT's signature verification with next potential verifying key");
                }
            }
        }
        throw new VerificationException("Invalid Issuer-Signed JWT: Signature could not be verified");
    }

    /**
     * Validates the Key Binding JWT.
     *
     * <p>Checks, in order:
     * <ul>
     *   <li>its {@code typ};</li>
     *   <li>its signature, against the holder key from the Issuer-signed JWT's {@code cnf} claim;</li>
     *   <li>its time claims;</li>
     *   <li>{@code nonce}/{@code aud};</li>
     *   <li>{@code sd_hash}.</li>
     * </ul>
     * Must only be called when {@link #keyBindingJwt} is non-null.
     */
    private void validateKeyBindingJwt(KeyBindingJwtVerificationOpts keyBindingJwtVerificationOpts) throws VerificationException {
        validateKeyBindingJwtTyp();
        JsonNode cnf = issuerSignedJwt.getCnfClaim()
                .orElseThrow(() -> new VerificationException("No cnf claim in Issuer-signed JWT for key binding"));
        SignatureVerifierContext holderVerifier = buildHolderVerifier(cnf);
        try {
            keyBindingJwt.verifySignature(holderVerifier);
        } catch (VerificationException e) {
            throw new VerificationException("Key binding JWT invalid", e);
        }
        validateKeyBindingJwtTimeClaims(keyBindingJwtVerificationOpts);
        preventKeyBindingJwtReplay(keyBindingJwtVerificationOpts);
        validateKeyBindingJwtSdHashIntegrity();
    }

    /**
     * Requires the Key Binding JWT header {@code typ} to be exactly {@code kb+jwt}. A missing
     * {@code typ} is rejected rather than causing a NullPointerException.
     */
    private void validateKeyBindingJwtTyp() throws VerificationException {
        String typ = keyBindingJwt.getHeader().getType();
        if (!"kb+jwt".equals(typ)) {
            throw new VerificationException("Key Binding JWT is not of declared typ kb+jwt");
        }
    }

    /**
     * Builds a signature verifier from the {@code jwk} member of a {@code cnf} claim.
     *
     * @throws UnsupportedOperationException if {@code cnf} has no {@code jwk} member
     * @throws VerificationException if the JWK cannot be converted into a verifier
     */
    private SignatureVerifierContext buildHolderVerifier(JsonNode cnf) throws VerificationException {
        Objects.requireNonNull(cnf, "cnf");
        JsonNode cnfJwk = cnf.get("jwk");
        if (cnfJwk == null) {
            throw new UnsupportedOperationException("Only cnf/jwk claim supported");
        }
        try {
            return JwkParsingUtils.convertJwkNodeToVerifierContext(cnfJwk);
        } catch (Exception e) {
            throw new VerificationException("Could not process cnf/jwk", e);
        }
    }

    /**
     * Checks the enabled time claims of the disclosed payload against the current epoch second.
     * Failures are rejected when:
     * <ul>
     *   <li>{@code iat} is in the future;</li>
     *   <li>{@code exp} is at or before now;</li>
     *   <li>{@code nbf} is in the future.</li>
     * </ul>
     * Any failure, including a claim that {@code SdJwtUtils.readTimeClaim} cannot read, is wrapped
     * in a {@link VerificationException} naming the claim.
     */
    private void validateIssuerSignedJwtTimeClaims(JsonNode payload, IssuerSignedJwtVerificationOpts issuerSignedJwtVerificationOpts) throws VerificationException {
        long now = Instant.now().getEpochSecond();
        try {
            if (issuerSignedJwtVerificationOpts.mustValidateIssuedAtClaim() && now < SdJwtUtils.readTimeClaim(payload, "iat")) {
                throw new VerificationException("JWT issued in the future");
            }
        } catch (VerificationException e) {
            throw new VerificationException("Issuer-Signed JWT: Invalid `iat` claim", e);
        }
        try {
            if (issuerSignedJwtVerificationOpts.mustValidateExpirationClaim() && now >= SdJwtUtils.readTimeClaim(payload, "exp")) {
                throw new VerificationException("JWT has expired");
            }
        } catch (VerificationException e) {
            throw new VerificationException("Issuer-Signed JWT: Invalid `exp` claim", e);
        }
        try {
            if (issuerSignedJwtVerificationOpts.mustValidateNotBeforeClaim() && now < SdJwtUtils.readTimeClaim(payload, "nbf")) {
                throw new VerificationException("JWT is not yet valid");
            }
        } catch (VerificationException e) {
            throw new VerificationException("Issuer-Signed JWT: Invalid `nbf` claim", e);
        }
    }

    /**
     * Checks the Key Binding JWT's time claims:
     * <ul>
     *   <li>{@code iat}, always;</li>
     *   <li>its age against the configured maximum, always;</li>
     *   <li>{@code exp} and {@code nbf}, only when enabled in the options.</li>
     * </ul>
     * Each failure is wrapped with the original cause preserved.
     */
    private void validateKeyBindingJwtTimeClaims(KeyBindingJwtVerificationOpts keyBindingJwtVerificationOpts) throws VerificationException {
        try {
            keyBindingJwt.verifyIssuedAtClaim();
        } catch (VerificationException e) {
            throw new VerificationException("Key binding JWT: Invalid `iat` claim", e);
        }
        try {
            keyBindingJwt.verifyAge(keyBindingJwtVerificationOpts.getAllowedMaxAge());
        } catch (VerificationException e) {
            throw new VerificationException("Key binding JWT is too old", e);
        }
        try {
            if (keyBindingJwtVerificationOpts.mustValidateExpirationClaim()) {
                keyBindingJwt.verifyExpClaim();
            }
        } catch (VerificationException e) {
            throw new VerificationException("Key binding JWT: Invalid `exp` claim", e);
        }
        try {
            if (keyBindingJwtVerificationOpts.mustValidateNotBeforeClaim()) {
                keyBindingJwt.verifyNotBeforeClaim();
            }
        } catch (VerificationException e) {
            throw new VerificationException("Key binding JWT: Invalid `nbf` claim", e);
        }
    }

    /**
     * Replaces every digest in a copy of the Issuer-signed JWT payload with its disclosed content.
     * It also checks that digests and salts are unique and that every supplied disclosure was
     * referenced.
     *
     * @return the fully disclosed payload, with {@code _sd} and {@code _sd_alg} removed
     * @throws VerificationException on any format, uniqueness, or coverage violation
     */
    private JsonNode validateDisclosuresDigests() throws VerificationException {
        Set<String> visitedSalts = new HashSet<>();
        Set<String> visitedDigests = new HashSet<>();
        Set<String> visitedDisclosureStrings = new HashSet<>();
        JsonNode disclosedPayload = validateViaRecursiveDisclosing(
                SdJwtUtils.deepClone(issuerSignedJwt.getPayload()),
                visitedSalts, visitedDigests, visitedDisclosureStrings);
        validateDisclosuresVisits(visitedDisclosureStrings);
        return disclosedPayload;
    }

    /**
     * Recursively discloses {@code currentNode} in place and returns it. Objects have their
     * {@code _sd} digests expanded; arrays have their {@code {"...": digest}} elements expanded or
     * removed. The method then descends into every child, including freshly disclosed values.
     * Scalar nodes are returned unchanged.
     */
    private JsonNode validateViaRecursiveDisclosing(JsonNode currentNode, Set<String> visitedSalts, Set<String> visitedDigests, Set<String> visitedDisclosureStrings) throws VerificationException {
        if (currentNode.isObject()) {
            discloseObjectClaims((ObjectNode) currentNode, visitedSalts, visitedDigests, visitedDisclosureStrings);
        } else if (currentNode.isArray()) {
            discloseArrayElements((ArrayNode) currentNode, visitedSalts, visitedDigests, visitedDisclosureStrings);
        } else {
            return currentNode;
        }
        for (JsonNode childNode : currentNode) {
            validateViaRecursiveDisclosing(childNode, visitedSalts, visitedDigests, visitedDisclosureStrings);
        }
        return currentNode;
    }

    /**
     * Expands the {@code _sd} digests of one object into claims, then strips {@code _sd} and
     * {@code _sd_alg}.
     *
     * <p>A disclosed claim whose name already exists at this level is rejected instead of
     * overwriting the existing value. The existing value may be a plaintext claim or another
     * disclosure.
     */
    private void discloseObjectClaims(ObjectNode objectNode, Set<String> visitedSalts, Set<String> visitedDigests, Set<String> visitedDisclosureStrings) throws VerificationException {
        JsonNode sdArray = objectNode.get("_sd");
        if (sdArray != null && sdArray.isArray()) {
            for (JsonNode element : sdArray) {
                if (!element.isTextual()) {
                    throw new VerificationException("Unexpected non-string element inside _sd array: " + element);
                }
                String digest = element.asText();
                markDigestAsVisited(digest, visitedDigests);
                String disclosure = disclosures.get(digest);
                if (disclosure == null) {
                    // No disclosure for this digest, e.g. an undisclosed claim or a decoy digest.
                    continue;
                }
                visitedDisclosureStrings.add(disclosure);
                DisclosureFields decoded = validateSdArrayDigestDisclosureFormat(disclosure);
                markSaltAsVisited(decoded.getSaltValue(), visitedSalts);
                if (objectNode.has(decoded.getClaimName())) {
                    throw new VerificationException("Disclosure claim name already exists at the level of the _sd key: " + decoded.getClaimName());
                }
                objectNode.set(decoded.getClaimName(), decoded.getClaimValue());
            }
        }
        objectNode.remove("_sd");
        objectNode.remove("_sd_alg");
    }

    /**
     * Replaces each array element of the form {@code {"...": "<digest>"}} with its disclosed value.
     * Elements whose digest has no matching disclosure are removed.
     */
    private void discloseArrayElements(ArrayNode arrayNode, Set<String> visitedSalts, Set<String> visitedDigests, Set<String> visitedDisclosureStrings) throws VerificationException {
        List<Integer> indexesToRemove = new ArrayList<>();
        for (int i = 0; i < arrayNode.size(); ++i) {
            JsonNode itemNode = arrayNode.get(i);
            if (!itemNode.isObject() || itemNode.size() != 1) {
                continue;
            }
            Map.Entry<String, JsonNode> field = itemNode.fields().next();
            if (!"...".equals(field.getKey()) || !field.getValue().isTextual()) {
                continue;
            }
            String digest = field.getValue().asText();
            markDigestAsVisited(digest, visitedDigests);
            String disclosure = disclosures.get(digest);
            if (disclosure == null) {
                indexesToRemove.add(i);
                continue;
            }
            visitedDisclosureStrings.add(disclosure);
            DisclosureFields decoded = validateArrayElementDigestDisclosureFormat(disclosure);
            markSaltAsVisited(decoded.getSaltValue(), visitedSalts);
            arrayNode.set(i, decoded.getClaimValue());
        }
        // Remove from the highest index down: removing a lower index first would shift the
        // remaining elements left and make the later recorded indexes point at the wrong items.
        for (int k = indexesToRemove.size() - 1; k >= 0; --k) {
            arrayNode.remove(indexesToRemove.get(k).intValue());
        }
    }

    /** Records a digest, rejecting any digest seen more than once across the whole payload. */
    private void markDigestAsVisited(String digest, Set<String> visitedDigests) throws VerificationException {
        if (!visitedDigests.add(digest)) {
            throw new VerificationException("A digest was encountered more than once: " + digest);
        }
    }

    /** Records a disclosure salt, rejecting any salt used by more than one disclosure. */
    private void markSaltAsVisited(String salt, Set<String> visitedSalts) throws VerificationException {
        if (!visitedSalts.add(salt)) {
            throw new VerificationException("A salt value was reused: " + salt);
        }
    }

    /**
     * Decodes a field disclosure {@code [salt, claimName, claimValue]}.
     *
     * @throws VerificationException if it does not have exactly three elements, the claim name is
     *         not a string, or the claim name is {@code _sd} or {@code ...}
     */
    private DisclosureFields validateSdArrayDigestDisclosureFormat(String disclosure) throws VerificationException {
        ArrayNode arrayNode = SdJwtUtils.decodeDisclosureString(disclosure);
        if (arrayNode.size() != 3) {
            throw new VerificationException("A field disclosure must contain exactly three elements");
        }
        JsonNode claimNameNode = arrayNode.get(1);
        if (!claimNameNode.isTextual()) {
            throw new VerificationException("Disclosure claim name must be a string");
        }
        String claimName = claimNameNode.asText();
        if (DISCLOSURE_CLAIM_NAME_DENYLIST.contains(claimName)) {
            throw new VerificationException("Disclosure claim name must not be '_sd' or '...'");
        }
        return new DisclosureFields(arrayNode.get(0).asText(), claimName, arrayNode.get(2));
    }

    /**
     * Decodes an array element disclosure {@code [salt, value]}.
     *
     * @throws VerificationException if it does not have exactly two elements
     */
    private DisclosureFields validateArrayElementDigestDisclosureFormat(String disclosure) throws VerificationException {
        ArrayNode arrayNode = SdJwtUtils.decodeDisclosureString(disclosure);
        if (arrayNode.size() != 2) {
            throw new VerificationException("An array element disclosure must contain exactly two elements");
        }
        return new DisclosureFields(arrayNode.get(0).asText(), null, arrayNode.get(1));
    }

    /** Rejects the SD-JWT if any supplied disclosure was not referenced by a digest in the payload. */
    private void validateDisclosuresVisits(Set<String> visitedDisclosureStrings) throws VerificationException {
        if (visitedDisclosureStrings.size() < disclosures.size()) {
            throw new VerificationException("At least one disclosure is not protected by digest");
        }
    }

    /**
     * Requires the Key Binding JWT's {@code nonce} and {@code aud} to be strings equal to the
     * expected values from the options.
     */
    private void preventKeyBindingJwtReplay(KeyBindingJwtVerificationOpts keyBindingJwtVerificationOpts) throws VerificationException {
        JsonNode nonce = keyBindingJwt.getPayload().get("nonce");
        if (nonce == null || !nonce.isTextual() || !nonce.asText().equals(keyBindingJwtVerificationOpts.getNonce())) {
            throw new VerificationException("Key binding JWT: Unexpected `nonce` value");
        }
        JsonNode aud = keyBindingJwt.getPayload().get("aud");
        if (aud == null || !aud.isTextual() || !aud.asText().equals(keyBindingJwtVerificationOpts.getAud())) {
            throw new VerificationException("Key binding JWT: Unexpected `aud` value");
        }
    }

    /**
     * Recomputes the digest of the presentation up to and including its last {@code ~}, using the
     * Issuer-signed JWT's hash algorithm, and compares it with the Key Binding JWT's {@code sd_hash}.
     *
     * @throws NullPointerException if the context was built without a presentation string
     * @throws VerificationException if {@code sd_hash} is missing, not a string, or does not match
     */
    private void validateKeyBindingJwtSdHashIntegrity() throws VerificationException {
        Objects.requireNonNull(sdJwtVpString, "SD-JWT presentation string is required to verify sd_hash");
        JsonNode sdHash = keyBindingJwt.getPayload().get("sd_hash");
        if (sdHash == null || !sdHash.isTextual()) {
            throw new VerificationException("Key binding JWT: Claim `sd_hash` missing or not a string");
        }
        int lastDelimiterIndex = sdJwtVpString.lastIndexOf("~");
        String toHash = sdJwtVpString.substring(0, lastDelimiterIndex + 1);
        String digest = SdJwtUtils.hashAndBase64EncodeNoPad(toHash.getBytes(StandardCharsets.UTF_8), issuerSignedJwt.getSdHashAlg());
        if (!digest.equals(sdHash.asText())) {
            throw new VerificationException("Key binding JWT: Invalid `sd_hash` digest");
        }
    }

    /** Decoded content of one disclosure; {@code claimName} is null for array element disclosures. */
    private static class DisclosureFields {
        private final String saltValue;
        private final String claimName;
        private final JsonNode claimValue;

        public DisclosureFields(String saltValue, String claimName, JsonNode claimValue) {
            this.saltValue = saltValue;
            this.claimName = claimName;
            this.claimValue = claimValue;
        }

        public String getSaltValue() {
            return saltValue;
        }

        public String getClaimName() {
            return claimName;
        }

        public JsonNode getClaimValue() {
            return claimValue;
        }
    }
}

