package com.transsin.networkmonitor.trace

import android.os.Process
import android.util.Log
import okhttp3.Interceptor
import okhttp3.Response
import java.io.IOException
import java.net.*
import java.util.*
import java.util.concurrent.atomic.AtomicInteger

/**
 * OkHttp interceptor that attaches a W3C Trace Context `traceparent` header to each request.
 *
 * Header layout produced by [genTraceId]: `00-<trace-id>-<parent-id>-01`, where
 * - `trace-id` (32 chars) = client IPv4 as 8 hex chars + epoch millis (currently 13 digits)
 *   + process id padded to 5 digits + a 6-digit rolling counter,
 * - `parent-id` (16 chars) = a random non-zero 64-bit value in lowercase hex ([genSpanId]),
 * - `01` = the "sampled" trace flag.
 *
 * Per OkHttp's `Interceptor.Chain.connection()` contract, a connection is only available to
 * network interceptors, so the proxy-address branch of [getClientIpFromChain] only applies
 * when this interceptor is installed with `addNetworkInterceptor`.
 */
class TraceIdInterceptor : Interceptor {
    // Rolling counter; getNextIndex() keeps the values it hands out within 6 digits.
    private val index = AtomicInteger(INCR_MIN)

    companion object {
        private const val TAG = "TraceIdInterceptor"

        private val random = Random()

        // Digit alphabet for radixes up to 36; array index == digit value.
        private val digits = charArrayOf(
            '0', '1', '2', '3', '4', '5',
            '6', '7', '8', '9', 'a', 'b',
            'c', 'd', 'e', 'f', 'g', 'h',
            'i', 'j', 'k', 'l', 'm', 'n',
            'o', 'p', 'q', 'r', 's', 't',
            'u', 'v', 'w', 'x', 'y', 'z'
        )

        // Bounds of the rolling counter; both ends are 6 decimal digits.
        private const val INCR_MIN = 100000
        private const val INCR_MAX = 999999

        // The pid is rendered with exactly PID_WIDTH digits so the trace-id is always 32 chars.
        private const val PID_WIDTH = 5
        private const val PID_MODULUS = 100000

        // Minimum output width of toUnsignedString0: 16 hex digits == one 64-bit value.
        private const val MIN_UNSIGNED_WIDTH = 16

        // Hex form used when the client address is not a dotted-quad IPv4 literal ("0.0.0.0").
        private const val UNKNOWN_IP_HEX = "00000000"

        /** Formats [i] as its unsigned 64-bit value in lowercase hex, zero-padded to 16 chars. */
        fun toHexStr(i: Long): String {
            return toUnsignedString0(i, 4)
        }

        /**
         * Formats [val] as an unsigned number in radix `2^shift`, left-padded with '0' to at
         * least 16 characters (wider when more digits are needed, e.g. for small shifts).
         *
         * @param shift bits per digit; must be in 1..5 (radix 2..32).
         * @throws IllegalArgumentException if [shift] is outside 1..5.
         */
        fun toUnsignedString0(`val`: Long, shift: Int): String {
            require(shift in 1..5) { "Illegal shift value: $shift" }
            val mag = java.lang.Long.SIZE - java.lang.Long.numberOfLeadingZeros(`val`)
            // Number of significant digits (at least one, so 0 renders as "0").
            val chars = maxOf((mag + shift - 1) / shift, 1)
            // A fixed 16-char buffer would overflow for shift < 4 (up to 64 binary digits).
            val width = maxOf(MIN_UNSIGNED_WIDTH, chars)
            val buf = CharArray(width) { '0' }

            formatUnsignedLong(`val`, shift, buf, width - chars, chars)
            return String(buf)
        }

        /**
         * Writes the low-order radix-`2^shift` digits of [val] right-aligned into
         * `buf[offset until offset + len]`, stopping when the value is exhausted or the
         * window is full (higher-order digits are then dropped).
         *
         * @return the number of leading positions in the window that were left unwritten.
         */
        fun formatUnsignedLong(
            `val`: Long,
            shift: Int,
            buf: CharArray,
            offset: Int,
            len: Int
        ): Int {
            var charPos = len
            var value = `val`
            val radix = 1 shl shift
            val mask = radix - 1
            do {
                buf[offset + --charPos] = digits[(value.toInt() and mask)]
                // Unsigned shift so negative inputs are treated as unsigned 64-bit values.
                value = value ushr shift
            } while (value != 0L && charPos > 0)

            return charPos
        }
    }

    /** Adds a freshly generated `traceparent` header and proceeds with the request. */
    @Throws(IOException::class)
    override fun intercept(chain: Interceptor.Chain): Response {
        val clientIp = getClientIpFromChain(chain)
        val traceId = genTraceId(clientIp)
        Log.d(TAG, "generate trace id:$traceId")

        // header() rather than addHeader(): the request must carry a single traceparent value.
        val request = chain.request().newBuilder()
            .header("traceparent", traceId)
            .build()

        return chain.proceed(request)
    }

    /**
     * Builds a version-00 `traceparent` value for [clientIp].
     *
     * A non-IPv4 [clientIp] is encoded as `00000000` instead of failing.
     */
    fun genTraceId(clientIp: String): String = buildString {
        append("00-")
        append(convertIPToHex(clientIp))
        append(System.currentTimeMillis())
        append(getCurrentProcessID())
        append(getNextIndex())
        append('-')
        append(genSpanId())
        append("-01")
    }

    /** Returns a random, non-zero 64-bit parent-id as 16 lowercase hex chars. */
    fun genSpanId(): String {
        // W3C Trace Context treats an all-zero parent-id as invalid; redraw in that case.
        var value = random.nextLong()
        while (value == 0L) {
            value = random.nextLong()
        }
        return toHexStr(value)
    }

    /**
     * Returns the host string of the route's proxy when the connection goes through a
     * non-DIRECT proxy, otherwise (or on any failure) the first non-loopback local IPv4.
     * The proxy host string may be a hostname or IPv6 literal; convertIPToHex tolerates that.
     */
    private fun getClientIpFromChain(chain: Interceptor.Chain): String {
        return try {
            val route = chain.connection()?.route()
            if (route != null) {
                val proxy = route.proxy()
                val proxyAddress = proxy.address() as? InetSocketAddress
                if (proxy.type() != Proxy.Type.DIRECT && proxyAddress != null) {
                    return proxyAddress.hostString ?: getLocalIpAddress()
                }
            }
            getLocalIpAddress()
        } catch (e: Exception) {
            getLocalIpAddress()
        }
    }

    /**
     * Returns the first non-loopback IPv4 address found across all network interfaces,
     * or "0.0.0.0" when none is found or enumeration fails.
     */
    private fun getLocalIpAddress(): String {
        return try {
            for (networkInterface in Collections.list(NetworkInterface.getNetworkInterfaces())) {
                for (inetAddress in Collections.list(networkInterface.inetAddresses)) {
                    if (!inetAddress.isLoopbackAddress && inetAddress is Inet4Address) {
                        return inetAddress.hostAddress
                    }
                }
            }
            "0.0.0.0"
        } catch (e: Exception) {
            "0.0.0.0"
        }
    }

    /**
     * Encodes a dotted-quad IPv4 string as 8 lowercase hex chars (e.g. "10.0.0.1" -> "0a000001").
     * Anything that is not exactly four decimal octets in 0..255 (hostnames, IPv6 literals)
     * yields [UNKNOWN_IP_HEX] rather than throwing inside the interceptor.
     */
    private fun convertIPToHex(clientIp: String): String {
        val octets = clientIp.split('.')
        if (octets.size != 4) return UNKNOWN_IP_HEX

        val sb = StringBuilder(8)
        for (part in octets) {
            val octet = part.toIntOrNull()
            if (octet == null || octet !in 0..255) return UNKNOWN_IP_HEX
            sb.append(octet.toString(16).padStart(2, '0'))
        }
        return sb.toString()
    }

    /**
     * Returns the next counter value in [INCR_MIN, INCR_MAX], wrapping back to [INCR_MIN]
     * after [INCR_MAX]. Up to three attempts are made to reset the shared counter via CAS;
     * if all lose the race, the value is folded into [INCR_MIN, 2 * INCR_MIN) so it still
     * has 6 digits.
     */
    protected fun getNextIndex(): Int {
        var c = 3
        while (c-- > 0) {
            val num = index.incrementAndGet()
            if (num > INCR_MAX) {
                // Only the thread whose CAS succeeds performs the wrap; others retry.
                if (index.compareAndSet(num, INCR_MIN)) {
                    return INCR_MIN
                }
                continue
            }
            return num
        }
        val lastNum = index.incrementAndGet()
        return if (lastNum > INCR_MAX) {
            lastNum % INCR_MIN + INCR_MIN
        } else {
            lastNum
        }
    }

    /**
     * Returns this process's id as exactly [PID_WIDTH] digits (zero-padded; pids of 100000
     * or more keep their low 5 digits) so the trace-id always has the required 32 chars.
     */
    private fun getCurrentProcessID(): String {
        return (Process.myPid() % PID_MODULUS).toString().padStart(PID_WIDTH, '0')
    }
}