package com.instabug.library.networkinterception.urlconnection

import android.os.Build
import androidx.annotation.RequiresApi
import com.instabug.library.diagnostics.IBGDiagnostics
import com.instabug.library.networkinterception.NetworkLogProcessor
import com.instabug.library.networkinterception.delegate.NetworkInterceptorDelegate
import com.instabug.library.networkinterception.model.NetworkClient
import com.instabug.library.networkinterception.model.NetworkLogModel
import com.instabug.library.networkinterception.model.NetworkLogRequestModel
import com.instabug.library.networkinterception.model.NetworkLogResponseModel
import com.instabug.library.util.TimeUtils
import com.instabug.library.util.extenstions.alsoSafely
import com.instabug.library.util.extenstions.runOrLogError
import java.io.InputStream
import java.io.OutputStream
import java.lang.ref.WeakReference
import java.net.HttpURLConnection
import java.net.URL
import java.security.Permission
import java.util.concurrent.TimeUnit

/**
 * Wraps an [HttpURLConnection] ([connectionDelegate]) and forwards every call to it while
 * capturing request and response metadata into a [NetworkLogModel.Builder].
 *
 * Lifecycle of the captured log:
 * - It is "started" (start timestamp recorded, [NetworkLogProcessor.onNetworkRequestStarted]
 *   invoked) the first time a connection-initiating call is made.
 * - It is "ended" (duration computed, [NetworkLogProcessor.processNetworkLog] invoked) at most
 *   once: on [disconnect], when a wrapped call throws, or when the intercepted response body
 *   callback fires. After that, nothing more is written to the log builders.
 *
 * NOTE(review): none of the mutable state below is synchronized; this appears to assume the
 * connection is driven from one thread at a time — confirm.
 *
 * @param url the URL the connection targets; used as the log's URL.
 * @param connectionDelegate the real connection that all calls are forwarded to.
 * @param interceptionDelegate wraps request/response body streams and may contribute extra
 *   request headers at construction time.
 * @param processor receives the start and end notifications for the captured log.
 */
open class HttpUrlConnectionInterceptor<T : HttpURLConnection>(
    val url: URL,
    protected val connectionDelegate: T,
    private val interceptionDelegate: NetworkInterceptorDelegate<OutputStream, InputStream>,
    private val processor: NetworkLogProcessor
) {

    private val logBuilder: NetworkLogModel.Builder
    private val requestBuilder: NetworkLogRequestModel.Builder = NetworkLogRequestModel.Builder()
    private val responseBuilder: NetworkLogResponseModel.Builder = NetworkLogResponseModel.Builder()

    // Handed to the builders in init and mutated afterwards; presumably the builders keep the
    // reference rather than copying it — confirm against NetworkLogRequestModel/ResponseModel.
    private val requestHeaders: MutableMap<String, String> = HashMap()
    private val responseHeaders: MutableMap<String, String> = HashMap()

    // `started`: start has been dispatched. `ended`: the log has been handed to the processor;
    // once true, runIfCanModify blocks any further writes to the builders.
    private var ended: Boolean = false
    private var started: Boolean = false
    private var startTimeNano: Long? = null

    // Last delegate stream that was wrapped and its wrapper, so repeated stream getters that
    // return the same delegate stream also return the same wrapper.
    private var delegateInputStream: InputStream? = null
    private var inputStreamDecorator: InputStream? = null
    private var delegateOutputStream: OutputStream? = null
    private var outputStreamDecorator: OutputStream? = null

    init {
        logBuilder =
            NetworkLogModel
                .Builder(NetworkClient.URL_CONNECTION, url.toString())
                .setRequestBuilder(requestBuilder.setHeaders(requestHeaders))
                .setResponseBuilder(responseBuilder.setHeaders(responseHeaders))
        invalidateRequestHeaders()
    }

    override fun toString(): String = connectionDelegate.toString()

    fun connect() = connectionInitiatingCallback { connect() }

    fun setConnectTimeout(timeout: Int) {
        connectionDelegate.connectTimeout = timeout
    }

    fun getConnectTimeout(): Int = connectionDelegate.connectTimeout

    fun setReadTimeout(timeout: Int) {
        connectionDelegate.readTimeout = timeout
    }

    fun getReadTimeout(): Int = connectionDelegate.readTimeout

    fun getURL(): URL? = connectionDelegate.url

    fun getContentLength(): Int =
        connectionInitiatingCallback { connectionDelegate.getContentLength() }
            .alsoSafely { runIfCanModify { responseBuilder.setContentLength(it.toLong()) } }

    @RequiresApi(Build.VERSION_CODES.N)
    fun getContentLengthLong(): Long =
        connectionInitiatingCallback { connectionDelegate.contentLengthLong }
            .alsoSafely { runIfCanModify { responseBuilder.setContentLength(it) } }

    fun getContentType(): String? = connectionInitiatingCallback { connectionDelegate.contentType }

    fun getContentEncoding(): String? =
        connectionInitiatingCallback { connectionDelegate.contentEncoding }

    fun getExpiration(): Long = connectionInitiatingCallback { connectionDelegate.expiration }

    fun getDate(): Long = connectionInitiatingCallback { connectionDelegate.date }

    fun getLastModified(): Long = connectionInitiatingCallback { connectionDelegate.lastModified }

    fun getHeaderField(n: Int): String? =
        connectionInitiatingCallback { connectionDelegate.getHeaderField(n) }

    fun getHeaderField(name: String?): String? =
        connectionInitiatingCallback { connectionDelegate.getHeaderField(name) }
            ?.alsoSafely { value ->
                runIfCanModify { name?.let { key -> responseHeaders[key] = value } }
            }

    fun getHeaderFields(): Map<String?, List<String?>?>? =
        connectionInitiatingCallback { connectionDelegate.headerFields }
            ?.alsoSafely { headers ->
                runIfCanModify {
                    headers.forEach { e ->
                        e.key?.let { key ->
                            e.joinHeaderValues()?.let { value -> responseHeaders[key] = value }
                        }
                    }
                }
            }

    /**
     * Forwards to [HttpURLConnection.getHeaderFieldInt]. The log records the raw header text,
     * not the parsed result, because the result is [default] whenever the header is missing or
     * unparsable. That value was never sent by the server.
     */
    fun getHeaderFieldInt(name: String?, default: Int): Int =
        connectionInitiatingCallback { connectionDelegate.getHeaderFieldInt(name, default) }
            .alsoSafely { captureRawResponseHeader(name) }

    /**
     * Forwards to [HttpURLConnection.getHeaderFieldLong]. Like [getHeaderFieldInt], it logs the
     * raw header text instead of the parsed value, which may be [default].
     */
    @RequiresApi(Build.VERSION_CODES.N)
    fun getHeaderFieldLong(name: String?, default: Long): Long =
        connectionInitiatingCallback { connectionDelegate.getHeaderFieldLong(name, default) }
            .alsoSafely { captureRawResponseHeader(name) }

    /**
     * Forwards to [HttpURLConnection.getHeaderFieldDate]. It logs the raw header text (the
     * original date string) instead of the parsed epoch millis, which may be [default].
     */
    fun getHeaderFieldDate(name: String?, default: Long): Long =
        connectionInitiatingCallback { connectionDelegate.getHeaderFieldDate(name, default) }
            .alsoSafely { captureRawResponseHeader(name) }

    fun getHeaderFieldKey(n: Int): String? =
        connectionInitiatingCallback { connectionDelegate.getHeaderFieldKey(n) }

    fun getContent(): Any? = connectionInitiatingCallback { connectionDelegate.getContent() }
        ?.extractBody()

    fun getContent(classes: Array<out Class<*>>?): Any? =
        connectionInitiatingCallback { connectionDelegate.getContent(classes) }?.extractBody()

    fun getPermission(): Permission? =
        connectionTerminationOnExceptionCallback { connectionDelegate.getPermission() }

    fun getInputStream(): InputStream? =
        connectionInitiatingCallback { connectionDelegate.inputStream }?.extractInputStreamBody()

    // Does not dispatch the start; the log is only ended here if obtaining the stream throws.
    fun getOutputStream(): OutputStream? =
        connectionTerminationOnExceptionCallback { connectionDelegate.outputStream }?.extractOutputStreamBody()

    fun setDoInput(doInput: Boolean) {
        connectionDelegate.doInput = doInput
    }

    fun getDoInput(): Boolean = connectionDelegate.doInput

    fun setDoOutput(doOutput: Boolean) {
        connectionDelegate.doOutput = doOutput
    }

    fun getDoOutput(): Boolean = connectionDelegate.doOutput

    fun setAllowUserInteraction(allowUserInteraction: Boolean) {
        connectionDelegate.allowUserInteraction = allowUserInteraction
    }

    fun getAllowUserInteraction(): Boolean = connectionDelegate.allowUserInteraction

    fun setUseCaches(useCaches: Boolean) {
        connectionDelegate.useCaches = useCaches
    }

    fun getUseCaches(): Boolean = connectionDelegate.useCaches

    fun setIfModifiedSince(ifModifiedSince: Long) {
        connectionDelegate.ifModifiedSince = ifModifiedSince
    }

    fun getIfModifiedSince(): Long = connectionDelegate.getIfModifiedSince()

    fun setDefaultUseCaches(defaultUseCaches: Boolean) {
        connectionDelegate.defaultUseCaches = defaultUseCaches
    }

    fun getDefaultUseCaches(): Boolean = connectionDelegate.defaultUseCaches

    /**
     * Forwards to [HttpURLConnection.setRequestProperty], then records the header (replacing any
     * previous value). The delegate is called first, so if it rejects the call (e.g. the
     * connection is already established), the exception propagates and nothing is logged.
     */
    fun setRequestProperty(key: String?, value: String?) {
        connectionDelegate.setRequestProperty(key, value)
        runSafelyIfCanModify { key?.let { k -> value?.let { v -> requestHeaders[k] = v } } }
    }

    /**
     * Forwards to [HttpURLConnection.addRequestProperty], then records the header. The real
     * connection keeps every value added for a key, so the logged value is the comma-joined list
     * of all values, the same format [joinHeaderValues] produces. Nothing is logged if the
     * delegate rejects the call.
     */
    fun addRequestProperty(key: String?, value: String?) {
        connectionDelegate.addRequestProperty(key, value)
        runSafelyIfCanModify {
            if (key != null && value != null) {
                requestHeaders[key] = requestHeaders[key]?.let { existing -> "$existing,$value" } ?: value
            }
        }
    }

    fun getRequestProperty(key: String?): String? = connectionDelegate.getRequestProperty(key)

    fun getRequestProperties(): Map<String?, List<String?>?>? =
        connectionDelegate.getRequestProperties()

    fun setFixedLengthStreamingMode(contentLength: Int) =
        connectionDelegate.setFixedLengthStreamingMode(contentLength)

    fun setFixedLengthStreamingMode(contentLength: Long) =
        connectionDelegate.setFixedLengthStreamingMode(contentLength)

    fun setChunkedStreamingMode(chunkLen: Int) =
        connectionDelegate.setChunkedStreamingMode(chunkLen)

    fun setInstanceFollowRedirects(followRedirects: Boolean) {
        connectionDelegate.instanceFollowRedirects = followRedirects
    }

    fun getInstanceFollowRedirects(): Boolean = connectionDelegate.instanceFollowRedirects

    /**
     * Forwards to [HttpURLConnection.setRequestMethod], then records the method. A method the
     * delegate rejects (it throws) is therefore never written to the log.
     */
    fun setRequestMethod(method: String?) {
        connectionDelegate.requestMethod = method
        runSafelyIfCanModify { requestBuilder.setMethod(method) }
    }

    fun getRequestMethod(): String? = connectionDelegate.requestMethod?.alsoSafely {
        runIfCanModify { requestBuilder.setMethod(it) }
    }

    fun getResponseCode(): Int =
        connectionInitiatingCallback { connectionDelegate.getResponseCode() }
            .alsoSafely { runIfCanModify { responseBuilder.setResponseCode(it) } }

    fun getResponseMessage(): String? =
        connectionInitiatingCallback { connectionDelegate.getResponseMessage() }

    // Ends the log without forcing a start, so disconnecting a never-used connection logs nothing.
    fun disconnect() = connectionTerminatingCallback(forceStartIfPossible = false) { disconnect() }

    fun usingProxy(): Boolean = connectionDelegate.usingProxy()

    fun getErrorStream(): InputStream? =
        connectionDelegate.errorStream?.extractInputStreamBody()

    /** Wraps [InputStream] content so its body is captured; any other content type passes through. */
    private fun Any.extractBody(): Any? =
        when (this) {
            is InputStream -> extractInputStreamBody()
            else -> this
        }

    /**
     * Returns a wrapper around this response stream, created by [interceptionDelegate], that
     * reports the consumed body and length back into [responseBuilder] and then ends the log.
     * The same wrapper is reused if this exact stream was already wrapped. On any failure,
     * the unwrapped stream is returned.
     */
    private fun InputStream.extractInputStreamBody(): InputStream =
        runOrLogError {
            inputStreamDecorator?.takeIf { delegateInputStream == this }
                ?: let {
                    // Held weakly, presumably so the wrapper's callback does not keep the
                    // interceptor alive.
                    val weakInterceptor = WeakReference(this@HttpUrlConnectionInterceptor)
                    interceptionDelegate.consumeResponseBody(
                        responseHeaders,
                        responseBuilder.contentLength ?: 0L,
                        { this },
                        { body, length ->
                            weakInterceptor.get()?.apply {
                                runIfCanModify {
                                    responseBuilder.setBody(body)
                                        .setContentLength(length)
                                }
                                dispatchEndIfPossible()
                            }
                        }
                    ).also { wrapper ->
                        inputStreamDecorator = wrapper
                        delegateInputStream = this
                    }
                }
        }.getOrNull() ?: this

    /**
     * Request-side counterpart of [extractInputStreamBody]. The callback records the body and
     * length into [requestBuilder] but does not end the log.
     */
    private fun OutputStream.extractOutputStreamBody(): OutputStream =
        runOrLogError {
            outputStreamDecorator?.takeIf { delegateOutputStream == this }
                ?: let {
                    val weakInterceptor = WeakReference(this@HttpUrlConnectionInterceptor)
                    interceptionDelegate.consumeRequestBody(
                        requestHeaders,
                        requestBuilder.contentLength ?: 0L,
                        { this },
                        { body, length ->
                            weakInterceptor.get()?.apply {
                                runIfCanModify {
                                    requestBuilder.setBody(body)
                                        .setContentLength(length)
                                }
                            }
                        }
                    ).also { wrapper ->
                        outputStreamDecorator = wrapper
                        delegateOutputStream = this
                    }
                }
        }.getOrNull() ?: this

    /**
     * Seeds [requestHeaders] from the headers already set on the delegate. It then applies any
     * headers returned by [interceptionDelegate] that are not already present, both on the
     * connection and in the log.
     */
    private fun invalidateRequestHeaders() {
        runOrLogError {
            getRequestProperties()?.forEach {
                it.key?.let { key ->
                    it.joinHeaderValues()?.let { value ->
                        requestHeaders[key] = value
                    }
                }
            }
            interceptionDelegate.invalidateRequestHeaders(requestHeaders, logBuilder)
                ?.forEach {
                    if (!requestHeaders.containsKey(it.key)) {
                        setRequestProperty(it.key, it.value)
                    }
                }
        }
    }

    /**
     * Records the raw value of response header [name], as reported by the delegate, into
     * [responseHeaders]. Does nothing when [name] is null, the header is absent, or the log
     * has already ended.
     */
    private fun captureRawResponseHeader(name: String?) {
        runIfCanModify {
            name?.let { key ->
                connectionDelegate.getHeaderField(key)?.let { raw -> responseHeaders[key] = raw }
            }
        }
    }

    /**
     * Dispatches the start (if not yet started), then runs [callback] on the delegate. When
     * [terminateOnException] is true, a throwing callback ends the log with that throwable
     * before the exception is rethrown.
     */
    private inline fun <RET> connectionInitiatingCallback(
        terminateOnException: Boolean = true,
        callback: T.() -> RET
    ): RET {
        dispatchStartIfPossible()
        return if (terminateOnException) {
            connectionTerminationOnExceptionCallback(callback)
        } else {
            connectionDelegate.callback()
        }
    }

    /** Runs [callback] on the delegate and always ends the log afterwards, even if it throws. */
    private inline fun <RET> connectionTerminatingCallback(
        forceStartIfPossible: Boolean = true,
        callback: T.() -> RET
    ): RET {
        try {
            return connectionDelegate.callback()
        } finally {
            dispatchEndIfPossible(null, forceStartIfPossible)
        }
    }

    /** Runs [callback] on the delegate; on failure, ends the log with the throwable and rethrows it. */
    private inline fun <RET> connectionTerminationOnExceptionCallback(callback: T.() -> RET): RET =
        runCatching {
            connectionDelegate.callback()
        }.onFailure {
            dispatchEndIfPossible(it)
        }.getOrThrow()

    /** Marks the log as started (timestamps and processor notification) exactly once, unless already ended. */
    private fun dispatchStartIfPossible() {
        runOrLogError {
            if (ended) {
                return
            }
            if (!started) {
                started = true
                startTimeNano = TimeUtils.nanoTime()
                logBuilder.setStartTimeStampMs(TimeUtils.currentTimeMillis())
                processor.onNetworkRequestStarted(connectionDelegate)
            }
        }
    }

    /**
     * Ends the log at most once: it records the duration and the optional [withThrowable], then
     * hands the builder to [processor]. With [forceStartIfPossible], a not-yet-started log is
     * started first. Otherwise, an unstarted log is left untouched.
     */
    private fun dispatchEndIfPossible(
        withThrowable: Throwable? = null,
        forceStartIfPossible: Boolean = true
    ) {
        runOrLogError {
            if (ended) {
                return
            }
            if (forceStartIfPossible) {
                dispatchStartIfPossible()
            }
            if (started) {
                ended = true
                startTimeNano
                    ?.let { TimeUtils.nanoTime() - it }
                    ?.let(TimeUnit.NANOSECONDS::toMicros)
                    ?.also(logBuilder::setDurationMus)
                withThrowable?.let { responseBuilder.setClientSideThrowable(it) }
                runOrLogError { IBGDiagnostics.logEvent(CAPTURE_EVENT) }
                processor.processNetworkLog(connectionDelegate, logBuilder)
            }
        }
    }

    /** Runs [callBack] only while the log has not been ended. */
    private inline fun runIfCanModify(callBack: () -> Unit) {
        if (!ended) {
            callBack.invoke()
        }
    }

    /** Same as [runIfCanModify], but any error thrown by [callBack] is passed to runOrLogError. */
    private inline fun runSafelyIfCanModify(callBack: () -> Unit) =
        runOrLogError { runIfCanModify(callBack) }

    /** Joins a multi-valued header's non-null values with ","; null when there are none. */
    private fun Map.Entry<String?, List<String?>?>.joinHeaderValues() =
        value?.takeIf { it.isNotEmpty() }
            ?.filterNotNull()
            ?.takeIf { it.isNotEmpty() }
            ?.joinToString(",")

    companion object {
        private const val CAPTURE_EVENT = "urlconnection_log_collected"
    }
}
