--
-- (C) 2013-24 - ntop.org
--

package.path = dirs.installdir .. "/scripts/lua/modules/?.lua;" .. package.path
package.path = dirs.installdir .. "/scripts/lua/modules/vulnerability_scan/?.lua;" .. package.path


require "lua_utils"
local format_utils = require "format_utils"

-- Module table for the vulnerability-scan REST helpers.
local vs_rest_utils = {
    -- Result codes of the "scanned ports vs ntopng-detected ports" comparison
    ports_diff_case = {
        no_diff          = 2, -- cases 1 and 2 combined: same ports
        ntopng_more_t_vs = 3, -- ntopng detected ports the scan did not report
        vs_more_t_ntopng = 4, -- the scan reported ports ntopng did not detect
    },
}


-- **********************************************************
-- ************** TO SCAN LIST REST FUNCTIONS ***************

-- Search port in ports list
-- Return true when `port_to_find` (a number) equals, numerically, one of the
-- entries of the `port_list` sequence (entries may be strings or numbers).
local function find_port(port_to_find, port_list)
    local idx = 1
    local candidate = port_list[idx]

    while candidate ~= nil do
        if tonumber(candidate) == port_to_find then
            return true
        end
        idx = idx + 1
        candidate = port_list[idx]
    end

    return false
end


-- Compare vs ports and ntopng detected ports
-- Compare the ports reported by a vulnerability scan with the ports ntopng
-- detected for the host.
--
-- @param vs_scan_port_string_list comma-separated string of scanned ports
-- @param ntopng_ports array of ports detected by ntopng
-- @return ports_unused   scanned ports that ntopng did not detect (as split from the string)
-- @return filtered_ports ntopng-detected ports that the scan did not report
-- @return diff_case      one of vs_rest_utils.ports_diff_case
function vs_rest_utils.compare_ports(vs_scan_port_string_list, ntopng_ports)
    local vs_scan_ports = split(vs_scan_port_string_list, ",")

    local ports_unused = {}
    local filtered_ports = {}

    -- Scanned ports unknown to ntopng
    for _, vs_port in ipairs(vs_scan_ports) do
        if not find_port(tonumber(vs_port), ntopng_ports) then
            ports_unused[#ports_unused + 1] = vs_port
        end
    end

    -- ntopng ports missing from the scan. Always computed: two lists of the
    -- same length can still contain different ports.
    for _, ntop_port in ipairs(ntopng_ports) do
        if not find_port(tonumber(ntop_port), vs_scan_ports) then
            filtered_ports[#filtered_ports + 1] = ntop_port
        end
    end

    -- Decide the case from the actual set differences, not from list lengths.
    -- When both sides have extra ports, ntopng_more_t_vs wins (as before).
    local diff_case
    if #filtered_ports > 0 then
        diff_case = vs_rest_utils.ports_diff_case.ntopng_more_t_vs
    elseif #ports_unused > 0 then
        diff_case = vs_rest_utils.ports_diff_case.vs_more_t_ntopng
    else
        diff_case = vs_rest_utils.ports_diff_case.no_diff
    end

    return ports_unused, filtered_ports, diff_case
end

-- Compare the ports detected by ntopng with the ports discovered by the vulnerability scan (nmap)
-- Classify the difference between the ports reported by a scan
-- (`ports_string_list`, comma-separated string) and the ports detected by
-- ntopng (`ports_detected`, table). Cases:
--   1: no host traffic, same scan and ntopng ports  -> ports_diff_case.no_diff
--   2: host traffic, same scan and ntopng ports      -> ports_diff_case.no_diff
--   3: different ports, scan ports < ntopng ports    -> ports_diff_case.ntopng_more_t_vs
--   4: different ports, scan ports > ntopng ports    -> ports_diff_case.vs_more_t_ntopng
-- Returns a table with `ports_case` and, when relevant, `ports_unused` and
-- `ports_filtered`. The `rsp` argument is not used; it is kept for callers.
local function get_ports_comparison_result(rsp, ports_string_list, ports_detected)
    local cases = vs_rest_utils.ports_diff_case
    local scan_is_empty = isEmptyString(ports_string_list)
    local detected_is_empty = (next(ports_detected) == nil)

    if scan_is_empty and detected_is_empty then
        -- nothing scanned and nothing detected: no badge
        return { ports_case = cases.no_diff }
    end

    if detected_is_empty then
        -- every scanned port is unknown to ntopng (case 4)
        return {
            ports_case = cases.vs_more_t_ntopng,
            ports_unused = split(ports_string_list, ","),
        }
    end

    if scan_is_empty then
        -- every detected port is missing from the scan (case 3)
        return {
            ports_case = cases.ntopng_more_t_vs,
            ports_filtered = ports_detected,
        }
    end

    -- both sides have ports: could be case 1/2 (same) or case 3/4 (different)
    local unused, filtered, diff_case = vs_rest_utils.compare_ports(ports_string_list, ports_detected)
    return {
        ports_unused = unused,
        ports_filtered = filtered,
        ports_case = diff_case,
    }
end
 

-- **********************************************************
-- Retrieves detected ports by ntopng

-- Collect the local server ports ntopng has recorded for `host`.
-- Keys of used_ports.local_server_ports are split on ":" into
-- "<l4 proto>" and "<port>"; only "tcp" and "udp" entries are kept.
-- @return tcp ports (array of strings), whether the host info with
--         local server ports was found, udp ports (array of strings)
function vs_rest_utils.retrieve_detected_ports(host)
    interface.select(interface.getId())

    local host_info = interface.getHostInfo(host)
    local detected = { tcp = {}, udp = {} }
    local found_in_memory = false

    local server_ports = host_info and host_info.used_ports and host_info.used_ports.local_server_ports
    if server_ports then
        for port_key, _ in pairs(server_ports) do
            local fields = split(port_key, ":")
            local bucket = detected[fields[1]]
            if bucket then
                bucket[#bucket + 1] = fields[2]
            end
        end
        found_in_memory = true
    end

    return detected.tcp, found_in_memory, detected.udp
end

 -- **********************************************************
 
 -- Compare the ports reported by a scan of `host` with the ports ntopng
 -- detected. Only one protocol is compared: TCP when scan_type is
 -- "tcp_portscan", UDP for any other scan type.
 -- Returns a table with <proto>_ports_unused, <proto>_ports_filtered,
 -- <proto>_ports_case and host_in_mem.
 function vs_rest_utils.compare_scan_info_ntopng_info(host, scan_type, tcp_ports_string_list, udp_ports_string_list)
    local tcp_detected, host_in_mem, udp_detected = vs_rest_utils.retrieve_detected_ports(host)

    local proto, ports_string_list, ports_detected = "udp", udp_ports_string_list, udp_detected
    if scan_type == "tcp_portscan" then
        proto, ports_string_list, ports_detected = "tcp", tcp_ports_string_list, tcp_detected
    end

    local rsp = {}
    local comparison = get_ports_comparison_result(rsp, ports_string_list, ports_detected)
    rsp[proto .. "_ports_unused"] = comparison.ports_unused
    rsp[proto .. "_ports_filtered"] = comparison.ports_filtered
    rsp[proto .. "_ports_case"] = comparison.ports_case
    rsp.host_in_mem = host_in_mem

    return rsp
 end

-- Convert an IPv4 or IPv6 address into a normalized hexadecimal string key (used for sorting)
-- Split an IPv6 address fragment on ":" into its hex groups.
-- A dotted-quad group (embedded IPv4, e.g. "192.0.2.1" in "::ffff:192.0.2.1")
-- is turned into the two 16-bit hex groups it represents.
local function split_ipv6_groups(fragment)
    local groups = {}
    for group in string.gmatch(fragment, "[^:]+") do
        local o1, o2, o3, o4 = string.match(group, "^(%d+)%.(%d+)%.(%d+)%.(%d+)$")
        if o1 then
            groups[#groups + 1] = string.format("%02x%02x", tonumber(o1), tonumber(o2))
            groups[#groups + 1] = string.format("%02x%02x", tonumber(o3), tonumber(o4))
        else
            groups[#groups + 1] = group
        end
    end
    return groups
end

-- Convert an IP address into a string key whose lexical order follows the
-- address order within each family.
--   IPv4: each octet as 2 uppercase hex digits, joined by "." ("C0.A8.01.0A")
--   IPv6: "::" expanded, 8 groups of 4 lowercase hex digits, joined by ":"
-- Note: "%04s" cannot be used for zero padding. Lua 5.4.4+ rejects it as an
-- invalid conversion, and C printf pads %s with spaces.
local function ipv_to_hex(ip)
    if string.find(ip, ":", 1, true) then
        local groups
        local gap = string.find(ip, "::", 1, true)
        if gap then
            -- Expand the "::" shorthand into the missing all-zero groups
            local head = split_ipv6_groups(string.sub(ip, 1, gap - 1))
            local tail = split_ipv6_groups(string.sub(ip, gap + 2))
            groups = head
            for _ = 1, 8 - #head - #tail do
                groups[#groups + 1] = "0"
            end
            for _, group in ipairs(tail) do
                groups[#groups + 1] = group
            end
        else
            groups = split_ipv6_groups(ip)
        end

        for i, group in ipairs(groups) do
            local normalized = string.lower(group)
            if #normalized < 4 then
                normalized = string.rep("0", 4 - #normalized) .. normalized
            end
            groups[i] = normalized
        end
        return table.concat(groups, ":")
    end

    -- IPv4 address
    local octets = {}
    for part in string.gmatch(ip, "[^.]+") do
        local num = tonumber(part)
        if num then
            octets[#octets + 1] = string.format("%02X", num)
        else
            -- keep non-numeric parts as-is instead of raising in string.format
            octets[#octets + 1] = part
        end
    end
    return table.concat(octets, ".")
end

-- ##################################################################

-- Check whether a specific port is listed in a comma-separated ports list string (an empty port always matches)
-- Return true when `port` is empty (no port filter) or when it appears, as
-- an exact string, among the comma-separated entries of `ports_list`.
local function portCheck(ports_list, port)
    if isEmptyString(port) then
        return true
    end

    for _, candidate in ipairs(split(ports_list, ",")) do
        if candidate == port then
            return true
        end
    end

    return false
end

-- ##################################################################

-- Function to format epoch
-- Return a short "time ago" label for value.last_scan.epoch. Falls back to
-- value.last_scan.time when no epoch is stored, and returns nil when there
-- is no last_scan at all. Previously this branch indexed a nil last_scan.
local function format_epoch(value)
    local last_scan = value.last_scan
    if last_scan == nil then
        return nil
    end

    if last_scan.epoch ~= nil then
        return format_utils.formatPastEpochShort(last_scan.epoch)
    end

    return last_scan.time
end

-- Convert value.last_scan.duration to seconds via format_utils.timeToSeconds.
-- Returns nil when there is no last_scan or no duration. Previously a
-- missing last_scan raised an error.
local function format_back_duration_epoch(value)
    local last_scan = value.last_scan
    if last_scan == nil or last_scan.duration == nil then
        return nil
    end

    return format_utils.timeToSeconds(last_scan.duration)
end

-- ##################################################################

-- Build a "<port>/<protocol>" label, suffixed with " (<service_name>)"
-- when a service name is available.
function vs_rest_utils.format_port_label(port, service_name, protocol)
    local label = string.format("%s/%s", port, protocol)
    if not isEmptyString(service_name) then
        label = string.format("%s (%s)", label, service_name)
    end
    return label
end

-- ##################################################################

-- Function to format port_list string with service names
-- Turn a comma-separated ports string into a comma-separated list of port
-- labels ("80/tcp (http),443/tcp (https)") using mapServiceName.
-- Labels are collected and joined with table.concat, which avoids the
-- quadratic cost of repeated string concatenation.
local function format_port_list(ports_string_list, protocol)
    local labels = {}

    for _, port in ipairs(split(ports_string_list, ',')) do
        local service_name = mapServiceName(port, protocol)
        labels[#labels + 1] = vs_rest_utils.format_port_label(port, service_name, protocol)
    end

    return table.concat(labels, ",")
end

-- ##################################################################

-- Function compare for sort
-- table.sort comparator: order results by normalized IP address. Entries for
-- the same host are ordered by last_scan.epoch when both have a last_scan.
-- A missing epoch counts as 0 so that table.sort does not raise on nil.
local function compare_host(a, b)
    local a_key = ipv_to_hex(a.host)
    local b_key = ipv_to_hex(b.host)

    if a_key == b_key and a.last_scan and b.last_scan then
        return (a.last_scan.epoch or 0) < (b.last_scan.epoch or 0)
    end

    return a_key < b_key
end



-- ##################################################################

-- Function to format result
-- Return true when `value` passes the optional "was down" and
-- "netscan report" filters.
local function overview_passes_filters(value, was_down, netscan_report)
    if was_down and not value.was_down then
        return false
    end
    if netscan_report and value.scan_type ~= 'ipv4_netscan' then
        return false
    end
    return true
end

-- Return true when `value` matches the free-text search: exact host, or a
-- substring of host / host_name. An empty search matches everything.
-- search_map is presumably user input (REST filter). It is matched as plain
-- text, because characters like "(" or "%" would break a Lua pattern.
local function overview_matches_search(value, search_map)
    if isEmptyString(search_map) then
        return true
    end
    return value.host == search_map
        or string.find(value.host, search_map, 1, true) ~= nil
        or string.find(value.host_name or "", search_map, 1, true) ~= nil
end

-- Format the counters, last-scan info and port lists of a selected entry in place.
local function decorate_overview_entry(entry, tcp_ports_string_list, udp_ports_string_list)
    entry.num_vulnerabilities_found = format_high_num_value_for_tables(entry, "num_vulnerabilities_found")
    entry.num_open_ports = format_high_num_value_for_tables(entry, "num_open_ports")
    entry.tcp_ports = format_high_num_value_for_tables(entry, "tcp_ports")
    entry.udp_ports = format_high_num_value_for_tables(entry, "udp_ports")
    -- no per-protocol counters: report the aggregate count as TCP ports
    if entry.tcp_ports == 0 and entry.udp_ports == 0 then
        entry.tcp_ports = entry.num_open_ports
    end

    if entry.last_scan then
        entry.last_scan.time = format_epoch(entry)
        entry.last_scan.duration_epoch = format_back_duration_epoch(entry)
    end

    if not isEmptyString(tcp_ports_string_list) then
        entry.tcp_ports_list = format_port_list(tcp_ports_string_list, "tcp")
    end
    if not isEmptyString(udp_ports_string_list) then
        entry.udp_ports_list = format_port_list(udp_ports_string_list, "udp")
    end
end

-- Attach the scan-vs-ntopng port comparison for the entry's scan type.
local function attach_ports_comparison(entry, tcp_ports_string_list, udp_ports_string_list)
    local cmp_result = vs_rest_utils.compare_scan_info_ntopng_info(entry.host, entry.scan_type,
        tcp_ports_string_list, udp_ports_string_list)

    if entry.scan_type == "tcp_portscan" then
        entry.tcp_ports_unused = cmp_result.tcp_ports_unused
        entry.tcp_ports_filtered = cmp_result.tcp_ports_filtered
        entry.tcp_ports_case = cmp_result.tcp_ports_case
    else
        entry.udp_ports_unused = cmp_result.udp_ports_unused
        entry.udp_ports_filtered = cmp_result.udp_ports_filtered
        entry.udp_ports_case = cmp_result.udp_ports_case
    end
    entry.host_in_mem = cmp_result.host_in_mem
end

-- Filter, decorate and optionally sort the scan results for the overview table.
--
-- @param result         array of scan result tables (entries are modified in place)
-- @param search_map     free-text filter on host / host_name (empty = no filter)
-- @param sort           'ip' sorts the output by address
-- @param port           keep only entries listing this port (empty = no filter)
-- @param was_down       keep only entries whose was_down flag is set
-- @param netscan_report keep only 'ipv4_netscan' entries
-- @return array of the selected entries
--
-- Only entries actually added to the response are formatted and compared.
-- Previously a skipped entry overwrote the previously added one with its
-- own port lists and comparison.
function vs_rest_utils.format_overview_result(result, search_map, sort, port, was_down, netscan_report)
    local rsp = {}
    if not result then
        return rsp
    end

    for _, value in ipairs(result) do
        if overview_passes_filters(value, was_down, netscan_report) then
            local tcp_ports_string_list = value.tcp_ports_list
            local udp_ports_string_list = value.udp_ports_list

            -- FIX for early development versions
            if value.scan_type == "tcp_openports" then
                value.scan_type = "tcp_portscan"
            elseif value.scan_type == "udp_openports" then
                value.scan_type = "udp_portscan"
            end

            -- the port filter matches a port found in either the TCP or the UDP list
            if (portCheck(tcp_ports_string_list, port) or portCheck(udp_ports_string_list, port))
                and overview_matches_search(value, search_map) then
                rsp[#rsp + 1] = value
                decorate_overview_entry(value, tcp_ports_string_list, udp_ports_string_list)
                attach_ports_comparison(value, tcp_ports_string_list, udp_ports_string_list)
            end
        end
    end

    if sort == 'ip' then
        table.sort(rsp, compare_host)
    end

    return rsp
end
-- ************** TO SCAN LIST REST FUNCTIONS *************

-- ********************************************************

-- ************** TO SCAN PORTS LIST REST FUNCTIONS *******
 
-- ####################################################

-- Function to find if the port is already present in the rest response
-- Return the index of the entry whose `port` field equals `port`
-- (e.g. "80/tcp"), or false when the response has no such entry.
local function find_in_response(rest_response, port)
    local index = 1
    while rest_response[index] ~= nil do
        if rest_response[index].port == port then
            return index
        end
        index = index + 1
    end
    return false
end

-- ####################################################

-- Function compare for sort on cves 
-- table.sort comparator: more CVEs first; on equal CVE counts, lower port
-- number first. A nil CVE count is treated as 0.
local function compare_cve(a, b)
    local function cve_count(item)
        local count = item.cves
        if count == nil then
            return 0
        end
        return count
    end

    local cves_a, cves_b = cve_count(a), cve_count(b)
    if cves_a ~= cves_b then
        return cves_a > cves_b
    end

    return tonumber(a.port_number) < tonumber(b.port_number)
end


-- ####################################################

-- Function to format epoch
-- Return the last-scan date label: a short "time ago" string from
-- last_scan.epoch, else last_scan.time, or "" when there is no last_scan.
local function add_date(value)
    local last_scan = value.last_scan
    if last_scan == nil then
        return ""
    end

    if last_scan.epoch == nil then
        return last_scan.time
    end

    return format_utils.formatPastEpochShort(last_scan.epoch)
end

-- ####################################################

-- Function to add a new element to the rest_response or update it
-- Build the "host|scan_type|date|is_ipv4|epoch[|host_name]" descriptor
-- stored in the `hosts` field of a port entry. A missing epoch is rendered
-- as "" instead of raising on concatenation.
local function build_host_entry(value)
    local epoch = (value.last_scan and value.last_scan.epoch) or ""
    local host_entry = value.host .. "|" .. value.scan_type .. "|" .. add_date(value) .. "|" ..
        tostring(isIPv4(value.host)) .. "|" .. epoch

    if not isEmptyString(value.host_name) then
        host_entry = host_entry .. "|" .. value.host_name
    end

    return host_entry
end

-- Return true when the ", "-separated host descriptors in `hosts` already
-- contain an entry for exactly `host`. The check matches the "<host>|" prefix
-- of a descriptor, so 10.0.0.1 is not mistaken for 10.0.0.10.
local function hosts_list_contains(hosts, host)
    local prefix = host .. "|"
    if string.sub(hosts, 1, #prefix) == prefix then
        return true
    end
    return string.find(hosts, ", " .. prefix, 1, true) ~= nil
end

-- Add `value` (a scan result) to the per-port entry of rest_response.
-- `id` is the index of the existing entry for this port, or false to create
-- a new entry; any other value leaves rest_response untouched.
-- Host descriptors are counted once per host. CVE counts are summed over
-- every scan result that lists the port.
local function handle_element(rest_response, id, value, port, protocol, sort)
    if id then
        local item = rest_response[id]
        if not hosts_list_contains(item.hosts, value.host) then
            item.count_host = item.count_host + 1
            item.hosts = item.hosts .. ", " .. build_host_entry(value)
        end
        item.cves = (item.cves or 0) + (value.num_vulnerabilities_found or 0)

    elseif id == false then
        local service_name = mapServiceName(port, protocol)
        local port_id = string.format("%s/%s", port, protocol)

        local port_label = service_name
        if (not isEmptyString(sort)) and isEmptyString(port_label) then
            port_label = port_id
        end

        -- host_name is appended only once, inside build_host_entry
        rest_response[#rest_response + 1] = {
            count_host = 1,
            hosts = build_host_entry(value),
            port = port_id,
            port_number = port,
            port_label = port_label,
            service_name = service_name,
            cves = value.num_vulnerabilities_found
        }
    end

    return rest_response
end
-- ####################################################

-- Function to verify if there's a cve value (for report)
-- Return true when at least one entry of rest_response has a non-nil,
-- non-zero `cves` value.
local function search_cve(rest_response)
    for _, entry in ipairs(rest_response) do
        local cves = entry.cves
        if cves ~= nil and cves ~= 0 then
            return true
        end
    end
    return false
end


-- ####################################################

-- Function to format rest result
-- Group scan results by port ("<port>/<proto>") for the ports list view.
--
-- @param result      array of scan results with tcp_ports_list / udp_ports_list strings
-- @param l4_protocol "tcp", "udp" or empty for both
-- @param sort        'port' (by port number), 'cve' (by CVE count; returns an
--                    empty list when no entry has CVEs) or empty for no sorting
-- @return array of per-port entries with formatted count_host and cves
function vs_rest_utils.format_scan_port_list_result(result, l4_protocol, sort)
    local rest_response = {}
    -- protocols are processed in this order for every scan result
    local protocols = {
        { "tcp", "tcp_ports_list" },
        { "udp", "udp_ports_list" },
    }

    for _, value in ipairs(result or {}) do
        for _, proto_info in ipairs(protocols) do
            local proto, field = proto_info[1], proto_info[2]
            local ports_list = value[field]

            if ports_list ~= nil and (isEmptyString(l4_protocol) or l4_protocol == proto) then
                for _, port in ipairs(split(ports_list, ",")) do
                    if not isEmptyString(port) then
                        local id = find_in_response(rest_response, port .. "/" .. proto)
                        rest_response = handle_element(rest_response, id, value, port, proto, sort)
                    end
                end
            end
        end
    end

    if sort == 'port' then
        table.sort(rest_response, function(first, second)
            return tonumber(first.port_number) < tonumber(second.port_number)
        end)
    elseif sort == 'cve' then
        if search_cve(rest_response) then
            table.sort(rest_response, compare_cve)
        else
            rest_response = {}
        end
    end

    for _, entry in ipairs(rest_response) do
        entry.count_host = format_high_num_value_for_tables(entry, "count_host")
        entry.cves = format_high_num_value_for_tables(entry, "cves")
    end

    return rest_response
end

-- ####################################################

-- Export the module table
return vs_rest_utils