package com.instabug.library.apm_okhttp_event_listener

import androidx.annotation.VisibleForTesting
import com.instabug.apm.model.EventTimeMetricCapture
import com.instabug.library.apm_network_log_repository.APMNetworkLogRepository
import com.instabug.library.apmokhttplogger.model.OkHttpAPMNetworkLog
import com.instabug.library.diagnostics.IBGDiagnostics
import com.instabug.library.map.Mapper
import okhttp3.Call
import java.util.Collections
import java.util.WeakHashMap

/**
 * Sink for the per-stage lifecycle callbacks of an OkHttp [Call], used to build network-latency data.
 *
 * The method names mirror OkHttp's `EventListener` callbacks (plus [onStageFailed]); presumably they
 * are forwarded from such a listener — confirm at the call site. Apart from [callStart] and [callEnd],
 * each callback carries the [EventTimeMetricCapture] taken for that stage.
 */
interface NetworkLatencyEventCaptor {
    /** The call has begun; no time metric is supplied for this stage. */
    fun callStart(call: Call)

    // DNS resolution
    fun dnsStart(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun dnsEnd(call: Call, eventTimeMetric: EventTimeMetricCapture)

    // Connection establishment (TLS handshake nested inside connect)
    fun connectStart(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun secureConnectStart(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun secureConnectEnd(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun connectEnd(call: Call, eventTimeMetric: EventTimeMetricCapture)

    // Outgoing request
    fun requestHeadersStart(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun requestHeadersEnd(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun requestBodyStart(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun requestBodyEnd(call: Call, eventTimeMetric: EventTimeMetricCapture)

    // Incoming response
    fun responseHeadersStart(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun responseHeadersEnd(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun responseBodyStart(call: Call, eventTimeMetric: EventTimeMetricCapture)
    fun responseBodyEnd(call: Call, eventTimeMetric: EventTimeMetricCapture)

    /** The call completed; no time metric is supplied for this stage. */
    fun callEnd(call: Call)

    /** A single stage failed; [eventTimeMetric] is the time of that failure. */
    fun onStageFailed(call: Call, eventTimeMetric: EventTimeMetricCapture)

    /** The call as a whole failed; [eventTimeMetric] is the time of that failure. */
    fun callFailed(call: Call, eventTimeMetric: EventTimeMetricCapture)
}

/**
 * Default [NetworkLatencyEventCaptor].
 *
 * For each OkHttp [Call] it keeps one [EventTimeMetricCapture] slot per [LatencyEvent] constant. When the
 * call ends or fails, it copies the collected timings into the call's [OkHttpAPMNetworkLog] and then ends
 * the log in [networkLogRepository].
 *
 * Every callback is dispatched through [networkLogExecutor]. Any [Throwable] raised while handling a
 * callback is reported via [IBGDiagnostics.reportNonFatal] and is not propagated.
 *
 * @param networkLogExecutor runs the work for each callback.
 * @param networkLatencySpansMapper turns a call's collected event array into the value stored in
 *   `latencySpansJsonString`.
 * @param networkLogRepository started on [callStart]; queried for the call's log and ended on
 *   [callEnd] / [callFailed].
 */
class NetworkLatencyEventCaptorImpl(
    private val networkLogExecutor: (Runnable) -> Unit,
    private val networkLatencySpansMapper: Mapper<Array<EventTimeMetricCapture?>, String?>,
    private val networkLogRepository: APMNetworkLogRepository
) : NetworkLatencyEventCaptor {

    /**
     * Per-call event timings, each an array of size [LatencyEvent.TOTAL_COUNT] indexed by [LatencyEvent].
     *
     * Keys are weak, so an entry for a call that never reaches [callEnd] / [callFailed] does not keep the
     * [Call] alive.
     *
     * Per-call arrays are mutated under `synchronized(call)`. Those locks differ from call to call, so they
     * do not protect the map itself, and a bare [WeakHashMap] is not thread-safe (it even mutates on reads
     * while expunging stale keys). The map is therefore wrapped with [Collections.synchronizedMap]. That keeps
     * concurrent tasks for different calls safe if [networkLogExecutor] uses more than one thread.
     */
    @VisibleForTesting
    val eventTimesMap: MutableMap<Call, Array<EventTimeMetricCapture?>> =
        Collections.synchronizedMap(WeakHashMap<Call, Array<EventTimeMetricCapture?>>())

    // --- NetworkLatencyEventCaptor callbacks -------------------------------------------------------
    // The first stage to arrive among dnsStart / connectStart / requestHeadersStart / requestBodyStart
    // also records REQUEST_STARTED. Later stages never overwrite it.

    override fun callStart(call: Call) = executeSafely {
        networkLogRepository.start(call)
        createEventTimesArray(call)
    }

    override fun dnsStart(call: Call, eventTimeMetric: EventTimeMetricCapture) = executeSafely {
        captureRequestStartedIfPossible(call, eventTimeMetric)
        captureEventTimeMetric(call, LatencyEvent.DNS_START, eventTimeMetric)
    }

    override fun dnsEnd(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely { captureEventTimeMetric(call, LatencyEvent.DNS_END, eventTimeMetric) }

    override fun connectStart(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureRequestStartedIfPossible(call, eventTimeMetric)
            captureEventTimeMetric(call, LatencyEvent.CONNECT_START, eventTimeMetric)
        }

    override fun secureConnectStart(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureEventTimeMetric(call, LatencyEvent.SECURE_CONNECT_START, eventTimeMetric)
        }

    override fun secureConnectEnd(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureEventTimeMetric(call, LatencyEvent.SECURE_CONNECT_END, eventTimeMetric)
        }

    override fun connectEnd(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely { captureEventTimeMetric(call, LatencyEvent.CONNECT_END, eventTimeMetric) }

    override fun requestHeadersStart(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureRequestStartedIfPossible(call, eventTimeMetric)
            captureEventTimeMetric(call, LatencyEvent.REQUEST_HEADERS_START, eventTimeMetric)
        }

    override fun requestHeadersEnd(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureEventTimeMetric(call, LatencyEvent.REQUEST_HEADERS_END, eventTimeMetric)
        }

    override fun requestBodyStart(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureRequestStartedIfPossible(call, eventTimeMetric)
            captureEventTimeMetric(call, LatencyEvent.REQUEST_BODY_START, eventTimeMetric)
        }

    override fun requestBodyEnd(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureEventTimeMetric(call, LatencyEvent.REQUEST_BODY_END, eventTimeMetric)
        }

    override fun responseHeadersStart(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureEventTimeMetric(call, LatencyEvent.RESPONSE_HEADERS_START, eventTimeMetric)
        }

    override fun responseHeadersEnd(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureEventTimeMetric(call, LatencyEvent.RESPONSE_HEADERS_END, eventTimeMetric)
        }

    override fun responseBodyStart(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureEventTimeMetric(call, LatencyEvent.RESPONSE_BODY_START, eventTimeMetric)
        }

    override fun responseBodyEnd(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely {
            captureEventTimeMetric(call, LatencyEvent.RESPONSE_BODY_END, eventTimeMetric)
        }

    override fun callEnd(call: Call) = executeSafely { end(call) }

    // A stage failure always (re)writes REQUEST_FAILED ...
    override fun onStageFailed(call: Call, eventTimeMetric: EventTimeMetricCapture) =
        executeSafely { captureEventTimeMetric(call, LatencyEvent.REQUEST_FAILED, eventTimeMetric) }

    // ... whereas callFailed only fills REQUEST_FAILED when no stage failure was recorded first.
    override fun callFailed(call: Call, eventTimeMetric: EventTimeMetricCapture) = executeSafely {
        captureRequestFailedIfPossible(call, eventTimeMetric)
        end(call)
    }

    // --- internals ---------------------------------------------------------------------------------

    /**
     * Finishes tracking [call]. If the repository still holds a log for it, this copies the captured
     * timings into that log, ends the log in the repository, and then drops the call's event array.
     */
    private fun end(call: Call) = synchronized(call) {
        networkLogRepository[call]?.collectCapturedData(call)
        networkLogRepository.end(call)
        eventTimesMap.remove(call)
    }

    /**
     * Copies [call]'s captured timings into this log:
     * - the REQUEST_STARTED timestamps become the start times (0 if absent);
     * - the largest nano time over all slots becomes `endTimeNanos` (empty slots count as 0);
     * - the whole array goes through [networkLatencySpansMapper].
     *
     * Does nothing if no array exists for [call].
     */
    private fun OkHttpAPMNetworkLog.collectCapturedData(call: Call) = synchronized(call) {
        eventTimesMap[call]?.let { eventTimes ->
            val requestStarted = eventTimes[LatencyEvent.REQUEST_STARTED]
            startTime = requestStarted?.getTimeStampMicro() ?: 0
            startTimeNanos = requestStarted?.getNanoTime() ?: 0
            endTimeNanos = eventTimes.maxOfOrNull { timeMetric -> timeMetric?.getNanoTime() ?: 0 } ?: 0
            latencySpansJsonString = networkLatencySpansMapper.map(eventTimes)
        }
    }

    /**
     * Stores [eventTimeMetric] in slot [event] of [call]'s array, overwriting any previous value.
     *
     * If the slot was already filled, slots [event]..last are cleared first (see
     * [clearRedirectDataIfPossible]). Does nothing when [callStart] has not created an array for [call].
     */
    private fun captureEventTimeMetric(
        call: Call,
        @LatencyEvent event: Int,
        eventTimeMetric: EventTimeMetricCapture,
    ) = synchronized(call) {
        getEventTimesArray(call)?.apply {
            clearRedirectDataIfPossible(event, call)
            this[event] = eventTimeMetric
        }
    }

    /** Records REQUEST_STARTED only if it has not been recorded yet (first write wins). */
    private fun captureRequestStartedIfPossible(
        call: Call,
        eventTimeMetric: EventTimeMetricCapture
    ) = captureNonOverrideEventTimeMetric(call, LatencyEvent.REQUEST_STARTED, eventTimeMetric)

    /** Records REQUEST_FAILED only if it has not been recorded yet (first write wins). */
    private fun captureRequestFailedIfPossible(
        call: Call,
        eventTimeMetric: EventTimeMetricCapture
    ) = captureNonOverrideEventTimeMetric(call, LatencyEvent.REQUEST_FAILED, eventTimeMetric)

    /** Stores [eventTimeMetric] in slot [event] only when that slot is still empty. */
    private fun captureNonOverrideEventTimeMetric(
        call: Call,
        @LatencyEvent event: Int,
        eventTimeMetric: EventTimeMetricCapture,
    ) = synchronized(call) {
        getEventTimesArray(call)?.apply {
            if (this[event] == null) this[event] = eventTimeMetric
        }
    }

    /**
     * If slot [event] is already filled, the stage is being seen again. Per the name, this is meant for
     * redirects. In that case every slot from [event] through the last index is nulled, so the array only
     * reflects the latest pass.
     *
     * NOTE(review): this assumes [LatencyEvent] indices follow stage order — confirm against LatencyEvent.
     */
    private fun Array<EventTimeMetricCapture?>.clearRedirectDataIfPossible(event: Int, call: Call) =
        synchronized(call) { if (get(event) != null) for (i in event..lastIndex) this[i] = null }

    /** Returns [call]'s event array, or null if [callStart] has not created one (or it was removed). */
    private fun getEventTimesArray(call: Call): Array<EventTimeMetricCapture?>? =
        eventTimesMap[call]

    /** Creates an empty event array of size [LatencyEvent.TOTAL_COUNT] for [call], replacing any existing one. */
    private fun createEventTimesArray(call: Call) =
        Array<EventTimeMetricCapture?>(LatencyEvent.TOTAL_COUNT) { null }
            .let { eventTimesMap[call] = it }

    /**
     * Submits [action] to [networkLogExecutor], catching any [Throwable] it throws and reporting it as a
     * non-fatal so that capture failures never reach the caller.
     */
    private inline fun executeSafely(crossinline action: () -> Unit) = networkLogExecutor {
        try {
            action()
        } catch (t: Throwable) {
            IBGDiagnostics.reportNonFatal(
                t, "Error occurred while capturing network latency spans: ${t.message}"
            )
        }
    }
}