package org.springframework.security.oauth2.client.token;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.token.auth.ClientAuthenticationHandler;
import org.springframework.security.oauth2.client.token.auth.DefaultClientAuthenticationHandler;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.http.converter.FormOAuth2AccessTokenMessageConverter;
import org.springframework.security.oauth2.http.converter.FormOAuth2ExceptionHttpMessageConverter;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.HttpMessageConverterExtractor;
import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.ResponseExtractor;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;

import java.io.IOException;
import java.net.HttpURLConnection;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Base support logic for obtaining access tokens.
 *
 * <p>
 * The {@link RestOperations} used to call the token endpoint is created lazily by the first call to
 * {@link #getRestTemplate()}, from the request factory, interceptors and error handler configured at that moment.
 * Calling {@link #setRequestFactory(ClientHttpRequestFactory)} or {@link #setInterceptors(List)} afterwards does not
 * reconfigure the instance that was already created.
 *
 * @deprecated See the <a href="https://github.com/spring-projects/spring-security/wiki/OAuth-2.0-Migration-Guide">OAuth 2.0 Migration Guide</a> for Spring Security 5.
 *
 * @author Ryan Heaton
 * @author Dave Syer
 */
@Deprecated
public abstract class OAuth2AccessTokenSupport {

	protected final Log logger = LogFactory.getLog(getClass());

	private static final FormHttpMessageConverter FORM_MESSAGE_CONVERTER = new FormHttpMessageConverter();

	/** Form parameters whose values are credentials and must never reach the log in clear text. */
	private static final List<String> SENSITIVE_FORM_PARAMETERS = Arrays.asList("password", "client_secret",
			"refresh_token");

	/** Placeholder written to the log in place of a sensitive form value. */
	private static final String MASKED_VALUE = "[PROTECTED]";

	private volatile RestOperations restTemplate;

	// volatile: written by setMessageConverters and read by the response extractor and the error handler, which are
	// not guaranteed to run on the thread that configured this instance.
	private volatile List<HttpMessageConverter<?>> messageConverters;

	private ClientAuthenticationHandler authenticationHandler = new DefaultClientAuthenticationHandler();

	private ResponseErrorHandler responseErrorHandler = new AccessTokenErrorHandler();

	private List<ClientHttpRequestInterceptor> interceptors = new ArrayList<ClientHttpRequestInterceptor>();

	private RequestEnhancer tokenRequestEnhancer = new DefaultRequestEnhancer();

	// Default factory: never follow redirects and never use cached responses for token requests.
	private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory() {
		@Override
		protected void prepareConnection(HttpURLConnection connection, String httpMethod)
				throws IOException {
			super.prepareConnection(connection, httpMethod);
			connection.setInstanceFollowRedirects(false);
			connection.setUseCaches(false);
		}
	};

	/**
	 * Sets the request interceptors that this accessor should use.
	 * <p>
	 * The list is only read when the rest template is first created by {@link #getRestTemplate()}.
	 *
	 * @param interceptors the interceptors to install on the rest template
	 */
	public void setInterceptors(List<ClientHttpRequestInterceptor> interceptors) {
		this.interceptors = interceptors;
	}

	/**
	 * A custom enhancer for the access token request. It is invoked by
	 * {@link #retrieveToken(AccessTokenRequest, OAuth2ProtectedResourceDetails, MultiValueMap, HttpHeaders)} after
	 * client authentication has been applied to the form and headers.
	 *
	 * @param tokenRequestEnhancer the enhancer to apply to the form and headers of each token request
	 */
	public void setTokenRequestEnhancer(RequestEnhancer tokenRequestEnhancer) {
		this.tokenRequestEnhancer = tokenRequestEnhancer;
	}

	/**
	 * Returns the rest template used to call the token endpoint, creating it on first use, and makes sure the
	 * message converters have been initialized.
	 *
	 * @return the shared {@link RestOperations} instance
	 */
	protected RestOperations getRestTemplate() {
		RestOperations template = this.restTemplate;
		if (template == null) {
			synchronized (this) {
				template = this.restTemplate;
				if (template == null) {
					RestTemplate created = new RestTemplate();
					created.setErrorHandler(getResponseErrorHandler());
					created.setRequestFactory(requestFactory);
					created.setInterceptors(interceptors);
					this.restTemplate = created;
					template = created;
				}
			}
		}
		resolveMessageConverters();
		return template;
	}

	/**
	 * Sets the handler that adds the client's credentials to the token request.
	 *
	 * @param authenticationHandler the handler applied to the form and headers before the request is sent
	 */
	public void setAuthenticationHandler(ClientAuthenticationHandler authenticationHandler) {
		this.authenticationHandler = authenticationHandler;
	}

	/**
	 * Sets the converters used to read the token response and error responses. The given list is copied and a
	 * {@link FormOAuth2AccessTokenMessageConverter} and a {@link FormOAuth2ExceptionHttpMessageConverter} are
	 * appended to the copy.
	 *
	 * @param messageConverters the base converters; the list itself is not modified
	 */
	public void setMessageConverters(List<HttpMessageConverter<?>> messageConverters) {
		// Assemble the complete list before publishing it, so a concurrent reader never iterates over a list that
		// is still being appended to.
		List<HttpMessageConverter<?>> converters = new ArrayList<HttpMessageConverter<?>>(messageConverters);
		converters.add(new FormOAuth2AccessTokenMessageConverter());
		converters.add(new FormOAuth2ExceptionHttpMessageConverter());
		this.messageConverters = converters;
	}

	/**
	 * Returns the configured message converters, falling back to the defaults of a fresh {@link RestTemplate}
	 * (plus the form converters added by {@link #setMessageConverters(List)}) when none have been set yet.
	 */
	private List<HttpMessageConverter<?>> resolveMessageConverters() {
		if (this.messageConverters == null) {
			// A race between two threads here is benign: each publishes an equivalent, fully built list.
			setMessageConverters(new RestTemplate().getMessageConverters());
		}
		return this.messageConverters;
	}

	/**
	 * Sends the token request and converts the response into an access token.
	 * <p>
	 * The form and headers are first passed to the authentication handler and then to the token request enhancer.
	 * If the response carries a <code>Set-Cookie</code> header, its first value is stored on the given request.
	 *
	 * @param request the current access token request; receives the response cookie, if any
	 * @param resource the protected resource whose token endpoint is called
	 * @param form the form parameters to send; may be modified by the authentication handler and the enhancer
	 * @param headers the request headers to send; may be modified by the authentication handler and the enhancer
	 * @return the access token extracted from the response
	 * @throws OAuth2AccessDeniedException if the call fails with an {@link OAuth2Exception} or a
	 * {@link RestClientException}; the original exception is kept as the cause
	 */
	protected OAuth2AccessToken retrieveToken(AccessTokenRequest request, OAuth2ProtectedResourceDetails resource,
			MultiValueMap<String, String> form, HttpHeaders headers) throws OAuth2AccessDeniedException {

		try {
			// Prepare headers and form before going into rest template call in case the URI is affected by the result
			authenticationHandler.authenticateTokenRequest(resource, form, headers);
			// Opportunity to customize form and headers
			tokenRequestEnhancer.enhance(request, resource, form, headers);

			// Final aliases so that the anonymous extractor below can capture them.
			final AccessTokenRequest tokenRequest = request;
			final ResponseExtractor<OAuth2AccessToken> delegate = getResponseExtractor();

			ResponseExtractor<OAuth2AccessToken> cookieAwareExtractor = new ResponseExtractor<OAuth2AccessToken>() {
				@Override
				public OAuth2AccessToken extractData(ClientHttpResponse response) throws IOException {
					if (response.getHeaders().containsKey("Set-Cookie")) {
						tokenRequest.setCookie(response.getHeaders().getFirst("Set-Cookie"));
					}
					return delegate.extractData(response);
				}
			};

			// The single-value view of the form supplies the values for the {key} placeholders that
			// getAccessTokenUri appends when the HTTP method is GET.
			return getRestTemplate().execute(getAccessTokenUri(resource, form), getHttpMethod(),
					getRequestCallback(resource, form, headers), cookieAwareExtractor, form.toSingleValueMap());

		}
		catch (OAuth2Exception oe) {
			throw new OAuth2AccessDeniedException("Access token denied.", resource, oe);
		}
		catch (RestClientException rce) {
			throw new OAuth2AccessDeniedException("Error requesting access token.", resource, rce);
		}

	}

	/**
	 * The HTTP method used for the token request.
	 *
	 * @return {@link HttpMethod#POST} unless overridden
	 */
	protected HttpMethod getHttpMethod() {
		return HttpMethod.POST;
	}

	/**
	 * Builds the URI (template) of the token endpoint.
	 * <p>
	 * When {@link #getHttpMethod()} is GET, one <code>key={key}</code> query parameter is appended per form key,
	 * using <code>&amp;</code> instead of <code>?</code> as the first separator if the configured URI already
	 * contains a query string. For any other method the configured URI is returned unchanged.
	 *
	 * @param resource the protected resource providing the access token URI
	 * @param form the form whose keys become query parameter placeholders for GET requests
	 * @return the URI, possibly containing <code>{key}</code> placeholders
	 */
	protected String getAccessTokenUri(OAuth2ProtectedResourceDetails resource, MultiValueMap<String, String> form) {

		String accessTokenUri = resource.getAccessTokenUri();

		if (logger.isDebugEnabled()) {
			logger.debug("Retrieving token from " + accessTokenUri);
		}

		StringBuilder builder = new StringBuilder(accessTokenUri);

		if (getHttpMethod() == HttpMethod.GET) {
			String separator = accessTokenUri.contains("?") ? "&" : "?";
			for (String key : form.keySet()) {
				builder.append(separator).append(key).append("={").append(key).append('}');
				separator = "&";
			}
		}

		return builder.toString();

	}

	/**
	 * The error handler installed on the rest template when it is created.
	 *
	 * @return the response error handler
	 */
	protected ResponseErrorHandler getResponseErrorHandler() {
		return responseErrorHandler;
	}

	/**
	 * Set the request factory that this template uses for obtaining {@link ClientHttpRequest HttpRequests}.
	 * <p>
	 * The factory is only read when the rest template is first created by {@link #getRestTemplate()}.
	 *
	 * @param requestFactory the factory to use; must not be null
	 */
	public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
		Assert.notNull(requestFactory, "'requestFactory' must not be null");
		this.requestFactory = requestFactory;
	}

	/**
	 * Creates the extractor that converts the token endpoint response into an {@link OAuth2AccessToken} using the
	 * configured message converters.
	 *
	 * @return a new response extractor
	 */
	protected ResponseExtractor<OAuth2AccessToken> getResponseExtractor() {
		getRestTemplate(); // force initialization
		return new HttpMessageConverterExtractor<OAuth2AccessToken>(OAuth2AccessToken.class,
				resolveMessageConverters());
	}

	/**
	 * Creates the callback that writes the headers and the form to the outgoing token request.
	 *
	 * @param resource the protected resource (not used by the default implementation)
	 * @param form the form to send as the request body
	 * @param headers the headers to copy onto the request
	 * @return the request callback
	 */
	protected RequestCallback getRequestCallback(OAuth2ProtectedResourceDetails resource,
			MultiValueMap<String, String> form, HttpHeaders headers) {
		return new OAuth2AuthTokenCallback(form, headers);
	}

	/**
	 * Renders the form for the debug log in the usual <code>{key=[values], ...}</code> map layout, replacing the
	 * values of credential parameters with a placeholder so that secrets are not written to the log.
	 */
	private static String describeFormForLogging(MultiValueMap<String, String> form) {
		if (form == null) {
			return "null";
		}
		StringBuilder description = new StringBuilder("{");
		String separator = "";
		for (String key : form.keySet()) {
			description.append(separator).append(key).append('=');
			if (SENSITIVE_FORM_PARAMETERS.contains(key)) {
				description.append(MASKED_VALUE);
			}
			else {
				description.append(form.get(key));
			}
			separator = ", ";
		}
		return description.append('}').toString();
	}

	/**
	 * Request callback implementation that writes the given object to the request stream.
	 */
	private class OAuth2AuthTokenCallback implements RequestCallback {

		private final MultiValueMap<String, String> form;

		private final HttpHeaders headers;

		private OAuth2AuthTokenCallback(MultiValueMap<String, String> form, HttpHeaders headers) {
			this.form = form;
			this.headers = headers;
		}

		/**
		 * Copies the configured headers onto the request, sets the Accept header to JSON and form-encoded, and
		 * writes the form as a form-encoded body.
		 */
		@Override
		public void doWithRequest(ClientHttpRequest request) throws IOException {
			request.getHeaders().putAll(this.headers);
			request.getHeaders().setAccept(
					Arrays.asList(MediaType.APPLICATION_JSON, MediaType.APPLICATION_FORM_URLENCODED));
			if (logger.isDebugEnabled()) {
				// The form may carry the client secret or the user's password; log a masked rendering only.
				logger.debug("Encoding and sending form: " + describeFormForLogging(this.form));
			}
			FORM_MESSAGE_CONVERTER.write(this.form, MediaType.APPLICATION_FORM_URLENCODED, request);
		}
	}

	/**
	 * Error handler that first tries to read the error response as an {@link OAuth2Exception} with each configured
	 * message converter, and throws it if one succeeds. Otherwise the default error handling applies.
	 */
	private class AccessTokenErrorHandler extends DefaultResponseErrorHandler {

		@SuppressWarnings("unchecked")
		@Override
		public void handleError(ClientHttpResponse response) throws IOException {
			MediaType contentType = response.getHeaders().getContentType();
			for (HttpMessageConverter<?> converter : OAuth2AccessTokenSupport.this.resolveMessageConverters()) {
				if (!converter.canRead(OAuth2Exception.class, contentType)) {
					continue;
				}
				OAuth2Exception ex;
				try {
					// Safe cast: canRead just confirmed that this converter can produce an OAuth2Exception.
					ex = ((HttpMessageConverter<OAuth2Exception>) converter).read(OAuth2Exception.class, response);
				}
				catch (Exception e) {
					// Best effort: a converter that cannot parse the body is skipped, but leave a trace of why.
					if (OAuth2AccessTokenSupport.this.logger.isDebugEnabled()) {
						OAuth2AccessTokenSupport.this.logger.debug(
								"Could not read OAuth2Exception from error response using "
										+ converter.getClass().getName(), e);
					}
					continue;
				}
				throw ex;
			}
			super.handleError(response);
		}

	}

}
