/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package com.opensymphony.xwork.util;

import com.opensymphony.xwork.config.Configuration;
import com.opensymphony.xwork.config.ConfigurationManager;
import ognl.MemberAccess;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static java.lang.reflect.Modifier.isStatic;
import static java.text.MessageFormat.format;
import static java.util.Collections.unmodifiableSet;

/**
 * Allows access decisions to be made on the basis of a number of factors.
 */
public class SecurityMemberAccess implements MemberAccess {

    private static final Log LOG = LogFactory.getLog(SecurityMemberAccess.class);

    private Set<String> excludedClasses;
    private Set<String> excludedPackageNames;
    private Set<String> excludedPackageExemptClasses;
    private boolean allowStaticMethodAccess;
    private boolean allowStaticFieldAccess;
    private Set<String> staticMemberAllowedClasses;
    private boolean enforceAllowlistEnabled;
    private Set<String> allowlistClasses;
    private Set<String> allowlistPackageNames;
    private boolean disallowProxyMemberAccess;
    private boolean disallowDefaultPackageAccess;

    /**
     * Creates an instance whose settings are loaded from
     * {@link ConfigurationManager#getConfiguration()}.
     */
    public SecurityMemberAccess() {
        loadConfig(ConfigurationManager.getConfiguration());
    }

    /**
     * (Re)loads every access-control setting from the given configuration.
     * <p>
     * The configured excluded classes are copied into an unmodifiable set that always also
     * contains {@link Object} and {@link Class}; all other sets are stored as returned by the
     * configuration.
     *
     * @param configuration source of the access-control settings
     */
    public void loadConfig(Configuration configuration) {
        excludedClasses = initExcludedClasses(configuration.getExcludedClasses());
        excludedPackageNames = configuration.getExcludedPackageNames();
        excludedPackageExemptClasses = configuration.getExcludedPackageExemptClasses();
        allowStaticMethodAccess = configuration.isAllowStaticMethodAccess();
        allowStaticFieldAccess = configuration.isAllowStaticFieldAccess();
        disallowProxyMemberAccess = configuration.isDisallowProxyMemberAccess();
        disallowDefaultPackageAccess = configuration.isDisallowDefaultPackageAccess();
        staticMemberAllowedClasses = configuration.getStaticMemberAllowedClasses();
        enforceAllowlistEnabled = configuration.isEnforceAllowlistEnabled();
        allowlistClasses = configuration.getAllowlistClasses();
        allowlistPackageNames = configuration.getAllowlistPackageNames();
    }

    /**
     * Prepares a member for reflective use by OGNL.
     * <p>
     * If the member passes {@link #isAccessible} and is not already reflectively accessible, it
     * is made accessible and {@link Boolean#FALSE} is returned so that {@link #restore} can
     * revert the change.
     *
     * @return {@link Boolean#FALSE} when accessibility was switched on, otherwise {@code null}
     */
    @Override
    public Object setup(Map context, Object target, Member member, String propertyName) {
        Object result = null;
        if (isAccessible(context, target, member, propertyName)) {
            final AccessibleObject accessible = (AccessibleObject) member;
            if (!accessible.isAccessible()) {
                result = Boolean.FALSE;
                accessible.setAccessible(true);
            }
        }
        return result;
    }

    /**
     * Reverts the change made by {@link #setup}.
     * <p>
     * A {@code null} state means nothing was changed. {@link Boolean#FALSE} means the member
     * was originally inaccessible, so its accessibility is reset to {@code false}.
     *
     * @throws IllegalArgumentException if {@code state} is {@link Boolean#TRUE}, which
     *                                  {@link #setup} never produces
     */
    @Override
    public void restore(Map context, Object target, Member member, String propertyName, Object state) {
        if (state == null) {
            return;
        }
        if ((Boolean) state) {
            throw new IllegalArgumentException(format(
                    "Improper restore state [true] for target [{0}], member [{1}], propertyName [{2}]",
                    target,
                    member,
                    propertyName));
        }
        ((AccessibleObject) member).setAccessible(false);
    }

    /**
     * Decides whether OGNL may access {@code member} on {@code target}.
     * <p>
     * The checks run in this order, and the first failing check denies access with a warning:
     * <ol>
     *   <li>proxy member access, when disallowed;</li>
     *   <li>public modifier;</li>
     *   <li>static field and static method rules;</li>
     *   <li>excluded classes, for the declaring class and the target class;</li>
     *   <li>default-package restriction, when enabled;</li>
     *   <li>excluded packages, honouring exempt classes;</li>
     *   <li>allowlist, when enforced.</li>
     * </ol>
     * For static members the declaring class is used as the target class, so {@code target}
     * may be {@code null}.
     *
     * @throws IllegalArgumentException if {@code member} is non-static and {@code target} is
     *                                  {@code null}, or if the target is not an instance of the
     *                                  member's declaring class
     */
    @Override
    public boolean isAccessible(Map context, Object target, Member member, String propertyName) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(format("Checking access for [target: {0}, member: {1}, property: {2}]", target, member, propertyName));
        }

        final int memberModifiers = member.getModifiers();
        final Class<?> memberClass = member.getDeclaringClass();
        final Class<?> targetClass;
        if (isStatic(memberModifiers)) {
            targetClass = memberClass;
        } else if (target == null) {
            // Fail fast with a descriptive error instead of an anonymous NullPointerException.
            throw new IllegalArgumentException(format("Target is null for non-static member [{0}]!", member));
        } else {
            targetClass = target.getClass();
        }
        if (!memberClass.isAssignableFrom(targetClass)) {
            throw new IllegalArgumentException("Target does not match member!");
        }

        if (disallowProxyMemberAccess && ProxyUtil.isProxyMember(member, target)) {
            LOG.warn(format("Access to proxy is blocked! Target class [{0}] of target [{1}], member [{2}]", targetClass, target, member));
            return false;
        }

        if (!checkPublicMemberAccess(memberModifiers)) {
            LOG.warn(format("Access to non-public [{0}] is blocked!", member));
            return false;
        }

        if (!checkStaticFieldAccess(member)) {
            LOG.warn(format("Access to static field [{0}] is blocked!", member));
            return false;
        }

        if (!checkStaticMethodAccess(member)) {
            LOG.warn(format("Access to static method [{0}] is blocked!", member));
            return false;
        }

        if (isClassExcluded(memberClass)) {
            LOG.warn(format("Declaring class of member type [{0}] is excluded!", member));
            return false;
        }

        if (targetClass != memberClass && isClassExcluded(targetClass)) {
            LOG.warn(format("Target class [{0}] of target [{1}] is excluded!", targetClass, target));
            return false;
        }

        if (disallowDefaultPackageAccess) {
            if (isInDefaultPackage(targetClass)) {
                LOG.warn(format("Class [{0}] from the default package is excluded!", targetClass));
                return false;
            }
            if (isInDefaultPackage(memberClass)) {
                LOG.warn(format("Class [{0}] from the default package is excluded!", memberClass));
                return false;
            }
        }

        if (isPackageExcluded(targetClass)) {
            LOG.warn(format("Package [{0}] of target class [{1}] of target [{2}] is excluded!",
                    targetClass.getPackage(),
                    targetClass,
                    target));
            return false;
        }

        if (targetClass != memberClass && isPackageExcluded(memberClass)) {
            LOG.warn(format("Package [{0}] of member [{1}] are excluded!", memberClass.getPackage(), member));
            return false;
        }

        if (enforceAllowlistEnabled) {
            if (!isClassAllowlisted(targetClass)) {
                LOG.warn(format("Target class [{0}] of target [{1}] is not allowlisted!", targetClass, target));
                return false;
            }
            if (targetClass != memberClass && !isClassAllowlisted(memberClass)) {
                LOG.warn(format("Declaring class [{0}] of member [{1}] is not allowlisted!", memberClass, member));
                return false;
            }
        }

        return true;
    }

    /**
     * Returns {@code true} if the class has no package information or is in the unnamed
     * (default) package.
     */
    private static boolean isInDefaultPackage(Class<?> clazz) {
        return toPackageName(clazz).isEmpty();
    }

    /**
     * Returns {@code true} if the class is allowlisted, either by its fully qualified name or
     * because it belongs to an allowlisted package or one of that package's sub-packages.
     */
    private boolean isClassAllowlisted(Class<?> clazz) {
        return allowlistClasses.contains(clazz.getName()) || isClassBelongsToPackages(clazz, allowlistPackageNames);
    }

    /**
     * Checks access for a static method, based on its modifiers.
     * <p>
     * {@code Enum#values()} is always allowed (see {@link #checkEnumAccess(Member)}). Any other
     * static non-field member is allowed only when static method access is enabled and the
     * declaring class is in the static-member allowed classes.
     * <p>
     * Note: the result is always {@code true} for fields and for non-static members.
     *
     * @param member the member being accessed
     * @return {@code false} if the member is a blocked static method, {@code true} otherwise
     */
    protected boolean checkStaticMethodAccess(Member member) {
        if (checkEnumAccess(member)) {
            if (LOG.isDebugEnabled()) {
                LOG.debug(format("Allowing access to Enum#values() of class [{0}]", member.getDeclaringClass()));
            }
            return true;
        }
        if (!(member instanceof Field) && isStatic(member.getModifiers())) {
            if (!allowStaticMethodAccess) {
                return false;
            }
            return staticMemberAllowedClasses.contains(member.getDeclaringClass().getName());
        }
        return true;
    }

    /**
     * Checks access for a static field, based on its modifiers.
     * <p>
     * A static field is allowed only when static field access is enabled and the declaring
     * class is in the static-member allowed classes.
     * <p>
     * Note: the result is always {@code true} for non-field and non-static members.
     *
     * @param member the member being accessed
     * @return {@code false} if the member is a blocked static field, {@code true} otherwise
     */
    protected boolean checkStaticFieldAccess(Member member) {
        if (member instanceof Field && isStatic(member.getModifiers())) {
            if (!allowStaticFieldAccess) {
                return false;
            }
            return staticMemberAllowedClasses.contains(member.getDeclaringClass().getName());
        }
        return true;
    }

    /**
     * Checks access for public members, based on their modifiers.
     * <p>
     * Returns {@code true} if and only if the member is public.
     *
     * @param memberModifiers the modifier bits of the member, as returned by
     *                        {@link Member#getModifiers()}
     * @return whether the {@code public} modifier is present
     */
    protected boolean checkPublicMemberAccess(int memberModifiers) {
        return Modifier.isPublic(memberModifiers);
    }

    /**
     * Returns {@code true} if the member is the static, parameterless {@code values()} method
     * of an enum type.
     */
    protected boolean checkEnumAccess(Member member) {
        return member.getDeclaringClass().isEnum()
                && isStatic(member.getModifiers())
                && member instanceof Method
                && member.getName().equals("values")
                && ((Method) member).getParameterCount() == 0;
    }

    /**
     * Returns {@code true} if the class belongs to an excluded package and is not listed as an
     * exempt class.
     */
    protected boolean isPackageExcluded(Class<?> clazz) {
        return !excludedPackageExemptClasses.contains(clazz.getName()) && isExcludedPackageNames(clazz);
    }

    /**
     * Returns the package name of the class, or {@code ""} when {@link Class#getPackage()}
     * returns {@code null}.
     *
     * @param clazz the class to inspect
     * @return the package name, never {@code null}
     */
    public static String toPackageName(Class<?> clazz) {
        if (clazz.getPackage() == null) {
            return "";
        }
        return clazz.getPackage().getName();
    }

    /**
     * Returns {@code true} if the class's package, or any parent package, is configured as
     * excluded.
     */
    protected boolean isExcludedPackageNames(Class<?> clazz) {
        return isClassBelongsToPackages(clazz, excludedPackageNames);
    }

    /**
     * Returns {@code true} if the class's package, or any of its parent packages, is contained
     * in {@code matchingPackages}.
     * <p>
     * For a class in {@code a.b.c}, the candidates {@code "a"}, {@code "a.b"} and
     * {@code "a.b.c"} are checked in that order. A class without a package (see
     * {@link #toPackageName(Class)}) matches only if {@code ""} is in the set.
     *
     * @param clazz            the class to test
     * @param matchingPackages package names to match against
     * @return whether any package prefix of the class is in {@code matchingPackages}
     */
    public static boolean isClassBelongsToPackages(Class<?> clazz, Set<String> matchingPackages) {
        final String packageName = toPackageName(clazz);
        // Walk the dot positions so each parent prefix is tested without building
        // intermediate lists or joined strings; this runs on every OGNL member access.
        int dot = packageName.indexOf('.');
        while (dot >= 0) {
            if (matchingPackages.contains(packageName.substring(0, dot))) {
                return true;
            }
            dot = packageName.indexOf('.', dot + 1);
        }
        return matchingPackages.contains(packageName);
    }

    /**
     * Returns {@code true} if the fully qualified class name is in the excluded classes.
     */
    protected boolean isClassExcluded(Class<?> clazz) {
        return excludedClasses.contains(clazz.getName());
    }

    /**
     * Builds the effective excluded-class set: a copy of the configured names plus
     * {@link Object} and {@link Class}, wrapped as unmodifiable.
     */
    private static Set<String> initExcludedClasses(Set<String> excludedClasses) {
        Set<String> newExcludedClasses = new HashSet<>(excludedClasses);
        newExcludedClasses.add(Object.class.getName());
        newExcludedClasses.add(Class.class.getName());
        return unmodifiableSet(newExcludedClasses);
    }
}
