package com.unity3d.services.core.network.core

import com.unity3d.ads.core.data.model.exception.UnityAdsNetworkException
import com.unity3d.services.core.domain.ISDKDispatchers
import com.unity3d.services.core.network.mapper.toOkHttpProtoRequest
import com.unity3d.services.core.network.mapper.toOkHttpRequest
import com.unity3d.services.core.network.model.HttpRequest
import com.unity3d.services.core.network.model.HttpResponse
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.withContext
import okhttp3.Call
import okhttp3.Callback
import okhttp3.OkHttpClient
import okhttp3.Response
import okio.Okio
import java.io.IOException
import java.net.SocketTimeoutException
import java.util.concurrent.TimeUnit
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException


/**
 * An implementation of [HttpClient] based on OkHttp.
 * Supports HTTP/2.
 *
 * Each request runs on a per-request copy of [client] (derived via [OkHttpClient.newBuilder]) that is
 * configured with the request's connect/read/write timeouts, interpreted as milliseconds.
 *
 * @property dispatchers supplies the IO dispatcher that requests are executed on.
 * @property client base OkHttp client that every per-request client is derived from.
 */
class OkHttp3Client(
    private val dispatchers: ISDKDispatchers,
    private val client: OkHttpClient,
) : HttpClient {

    /**
     * Helper method that blocks the calling thread (via `runBlocking`) until [execute] completes.
     * Intended for interaction with Java callers.
     *
     * @param request [HttpRequest] to be executed on the network
     * @return [HttpResponse] of the passed in [HttpRequest]
     */
    override fun executeBlocking(request: HttpRequest): HttpResponse = runBlocking(dispatchers.io) {
        execute(request)
    }

    /**
     * Executes an HTTP network request on the IO dispatcher.
     *
     * When [HttpRequest.downloadDestination] points to an existing file, the response body is streamed
     * into that file and the returned [HttpResponse] carries an empty body. Otherwise the body is read
     * into memory as bytes for protobuf requests, or as a string for all other requests.
     *
     * @param request [HttpRequest] to be executed on the network
     * @return [HttpResponse] of the passed in [HttpRequest]
     * @throws UnityAdsNetworkException with [MSG_CONNECTION_TIMEOUT] on a socket timeout, or with
     *   [MSG_CONNECTION_FAILED] on any other I/O failure.
     */
    override suspend fun execute(request: HttpRequest): HttpResponse = withContext(dispatchers.io) {
        try {
            val response = makeRequest(
                request,
                request.connectTimeout.toLong(),
                request.readTimeout.toLong(),
                request.writeTimeout.toLong(),
            )

            // When the body was streamed to the download file (same condition as in makeRequest), its
            // source has already been consumed and closed; reading it again would throw
            // IllegalStateException("closed").
            val bodyStreamedToFile = request.downloadDestination?.exists() == true
            val responseBody = when {
                bodyStreamedToFile -> null
                request.isProtobuf -> response.body()?.bytes()
                else -> response.body()?.string()
            }

            HttpResponse(
                statusCode = response.code(),
                headers = response.headers().toMultimap(),
                urlString = response.request().url().toString(),
                body = responseBody ?: "",
                protocol = response.protocol().toString(),
                client = NETWORK_CLIENT_OKHTTP
            )
        } catch (e: SocketTimeoutException) {
            throw UnityAdsNetworkException(
                message = MSG_CONNECTION_TIMEOUT,
                url = request.baseURL,
                client = NETWORK_CLIENT_OKHTTP
            )
        } catch (e: IOException) {
            throw UnityAdsNetworkException(
                message = MSG_CONNECTION_FAILED,
                url = request.baseURL,
                client = NETWORK_CLIENT_OKHTTP
            )
        }
    }

    /**
     * Wraps the OkHttp asynchronous call in a cancellable coroutine.
     *
     * Cancelling the calling coroutine cancels the underlying OkHttp [Call]. If the request's
     * download destination exists when the response arrives, the body is written to it before
     * resuming. Timeouts are in milliseconds.
     */
    private suspend fun makeRequest(
        request: HttpRequest,
        connectTimeout: Long,
        readTimeout: Long,
        writeTimeout: Long,
    ): Response = suspendCancellableCoroutine { continuation ->
        val okHttpRequest = if (request.isProtobuf) request.toOkHttpProtoRequest() else request.toOkHttpRequest()
        val configuredClient = client.newBuilder()
            .connectTimeout(connectTimeout, TimeUnit.MILLISECONDS)
            .readTimeout(readTimeout, TimeUnit.MILLISECONDS)
            .writeTimeout(writeTimeout, TimeUnit.MILLISECONDS)
            .build()

        val okHttpCall = configuredClient.newCall(okHttpRequest)
        // Propagate coroutine cancellation to OkHttp so the request does not keep running unobserved.
        continuation.invokeOnCancellation { okHttpCall.cancel() }

        okHttpCall.enqueue(object : Callback {
            override fun onResponse(call: Call, response: Response) {
                try {
                    val file = request.downloadDestination
                    if (file?.exists() == true) {
                        val sink = Okio.buffer(Okio.sink(file))
                        sink.use {
                            response.body()?.source()?.let {
                                it.use { st ->
                                    sink.writeAll(st)
                                }
                            }
                        }
                    }
                    continuation.resume(response)
                } catch (e: Exception) {
                    // Release the connection held by a body that may not have been consumed/closed.
                    response.body()?.close()
                    continuation.resumeWithException(e)
                }
            }

            override fun onFailure(call: Call, e: IOException) {
                // OkHttp reports timeouts through this callback, so map them here; otherwise the
                // timeout message would never be produced for connect/read timeouts.
                val message = if (e is SocketTimeoutException) MSG_CONNECTION_TIMEOUT else MSG_CONNECTION_FAILED
                val exception = UnityAdsNetworkException(
                    message = message,
                    url = call.request().url().toString(),
                    client = NETWORK_CLIENT_OKHTTP
                )
                continuation.resumeWithException(exception)
            }
        })
    }

    companion object {
        const val MSG_CONNECTION_TIMEOUT = "Network request timeout"
        const val MSG_CONNECTION_FAILED = "Network request failed"
        const val NETWORK_CLIENT_OKHTTP = "okhttp"
    }
}
