#!/usr/bin/env python3
#
# netplan-esxi - A netplan implementation for VMware ESXi
#
# Author: Lee Trager <lee.trager@canonical.com>
#
# Copyright (C) 2019-2025 Canonical
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import re
import sys
from argparse import ArgumentParser
from subprocess import check_call, check_output

import ipaddr
import yaml


def warn(msg):
    """Write *msg* to stderr with the standard ``WARNING:`` prefix."""
    text = "WARNING: %s" % msg
    print(text, file=sys.stderr)


def help(parser):
    """Show the parser's help text on stderr, then exit with EX_USAGE."""
    parser.print_help(sys.stderr)
    raise SystemExit(os.EX_USAGE)


def get_list_names(command):
    """Collect every ``Name:`` value printed by an esxcli network listing.

    ``command`` is the list of sub-command words that go between
    ``esxcli network`` and the trailing ``list``.
    """
    argv = ["esxcli", "network"] + command
    argv.append("list")
    listing = check_output(argv).decode()
    name_re = re.compile(r"^\s+Name:\s+(\w+)$", re.M)
    return name_re.findall(listing)


def wipe():
    """Wipe the existing network config from the running system.

    Removes, in this order: every interface reported by the ``ip interface``
    listing, every standard vSwitch, all DNS servers and every DNS search
    domain. Each step shells out to ``esxcli``; a command exiting non-zero
    raises ``subprocess.CalledProcessError``.
    """
    print("Wiping current network configuration...")

    for vmk in get_list_names(["ip", "interface"]):
        check_call(
            ["esxcli", "network", "ip", "interface", "remove", "-i", vmk]
        )

    for vswitch in get_list_names(["vswitch", "standard"]):
        check_call(
            [
                "esxcli",
                "network",
                "vswitch",
                "standard",
                "remove",
                "-v",
                vswitch,
            ]
        )

    check_call(["esxcli", "network", "ip", "dns", "server", "remove", "--all"])
    output = check_output(["esxcli", "network", "ip", "dns", "search", "list"])
    # The domains are taken from the text following the first colon of the
    # listing. If the output carries no colon at all there is nothing to
    # remove, so fall back to an empty string instead of raising IndexError.
    parts = output.decode().split(":", 2)
    domains = parts[1] if len(parts) > 1 else ""
    for search in domains.split(","):
        search = search.strip()
        if not search:
            continue
        check_call(
            [
                "esxcli",
                "network",
                "ip",
                "dns",
                "search",
                "remove",
                "-d",
                search,
            ]
        )

    print("Done wiping network configuration!")


def get_vmnics():
    """Build a MAC address -> vmnic name lookup from the esxcli NIC listing."""
    nic_line = re.compile(
        r"^(?P<vmnic>vmnic\d+).*(?P<mac>(\w{2}:){5}\w{2}).*$", re.I
    )
    listing = check_output(["esxcli", "network", "nic", "list"]).decode()
    # Lines that do not describe a vmnic (headers, separators) yield None.
    matches = (nic_line.search(row) for row in listing.splitlines())
    return {m.group("mac"): m.group("vmnic") for m in matches if m}


def map_id_to_mac(configs):
    """Return ``{nic id: MAC address}`` for every physical device entry.

    Only entries whose ``type`` is ``"physical"`` are considered; a missing
    ``id`` or ``mac_address`` shows up as ``None``.
    """
    return {
        entry.get("id"): entry.get("mac_address")
        for entry in configs
        if entry.get("type") == "physical"
    }


def create_vswitch(mtu, counters):
    """Add the next numbered standard vSwitch and return its name.

    The name is ``vSwitch<N>`` where N is ``counters["vswitch"]``; the
    counter is incremented once the switch has been added. The MTU is only
    set when a truthy ``mtu`` is given.
    """
    base = ["esxcli", "network", "vswitch", "standard"]
    name = "vSwitch%s" % counters["vswitch"]
    check_call(base + ["add", "-v", name])
    counters["vswitch"] += 1
    if not mtu:
        return name
    check_call(base + ["set", "-m", str(mtu), "-v", name])
    return name


def nic_to_vmnic(nic, mappings):
    """Resolve a nic id (e.g. ens3) to its vmnic name (e.g. vmnic0).

    Goes nic id -> MAC via ``mappings["ids"]`` and MAC -> vmnic via
    ``mappings["vmnics"]``; returns ``None`` when either step has no entry.
    """
    hw_address = mappings["ids"].get(nic)
    vmnic_by_mac = mappings["vmnics"]
    return vmnic_by_mac.get(hw_address)


def add_nic_to_vswitch(nic, vswitch, mappings):
    """Attach the vmnic behind a nic id (e.g. ens3) to a vSwitch as uplink.

    Returns the vmnic name on success. When the nic id cannot be mapped to
    a vmnic a warning is printed and ``None`` is returned.
    """
    uplink = nic_to_vmnic(nic, mappings)
    if uplink is not None:
        check_call(
            [
                "esxcli",
                "network",
                "vswitch",
                "standard",
                "uplink",
                "add",
                "-u",
                uplink,
                "-v",
                vswitch,
            ]
        )
        return uplink
    warn("Unable to map %s!" % nic)
    return None


def add_vmk(port_group, mtu, mac_address, counters):
    """Add the next numbered vmk interface on the given port group.

    The interface is named ``vmk<N>`` where N is ``counters["vmk"]``; the
    counter is incremented after esxcli succeeded.

    ``mtu`` and ``mac_address`` are optional: the corresponding esxcli flag
    is only passed when a value is available.

    Returns the name of the new vmk.
    """
    vmk = "vmk%s" % counters["vmk"]
    cmd = [
        "esxcli",
        "network",
        "ip",
        "interface",
        "add",
        "-i",
        vmk,
        "-p",
        port_group,
    ]
    # Devices without an "mtu" key reach this function with mtu=None; do not
    # hand the literal string "None" to esxcli (same guard create_vswitch
    # applies).
    if mtu:
        cmd += ["-m", str(mtu)]
    if mac_address is not None:
        cmd += ["-M", mac_address]
    check_call(cmd)
    counters["vmk"] += 1
    return vmk


def _route_exists(network, gateway, ip_ver):
    """Check if the given network is already being routed to.

    Scans the output of ``esxcli network ip route <ip_ver> list`` for a line
    that contains the network address, its netmask and the gateway.

    Values are compared against whole whitespace-separated columns rather
    than as substrings of the line, so that e.g. gateway ``10.0.0.1`` is not
    mistaken for an existing route via ``10.0.0.10``.

    NOTE(review): the netmask is compared in address form
    (``str(network.netmask)``); confirm this matches how esxcli prints the
    IPv6 route table.
    """
    wanted = {str(network.ip), str(network.netmask), gateway}
    output = check_output(["esxcli", "network", "ip", "route", ip_ver, "list"])
    for line in output.decode().splitlines():
        if wanted.issubset(line.split()):
            # Static route has already been added
            return True
    return False


def _apply_static_routes(routes):
    """Apply all defined static routes."""
    for route in routes:
        for key in ["destination", "gateway"]:
            if key not in route:
                warn('Missing key "%s" for static route!' % key)
                break
        network = ipaddr.IPNetwork(route["destination"])
        if network.ip.version == 4:
            ip_ver = "ipv4"
        elif network.ip.version == 6:
            ip_ver = "ipv6"
        else:
            warn("Unknown IP version %s" % network.ip.version)
            continue
        if not _route_exists(network, route["gateway"], ip_ver):
            print(
                "Adding static route to %s via %s"
                % (route["destination"], route["gateway"])
            )
            check_call(
                [
                    "esxcli",
                    "network",
                    "ip",
                    "route",
                    ip_ver,
                    "add",
                    "-g",
                    route["gateway"],
                    "-n",
                    route["destination"],
                ]
            )


def _apply_subnets(config, vswitch, counters):
    """IP configuration for the given subnets."""
    mtu = config.get("mtu")
    port_group_name = config.get("name")
    if port_group_name is None:
        warn("No name given!")
        return
    vlan_id = config.get("vlan_id")
    mac_address = config.get("mac_address")

    for i, subnet in enumerate(config.get("subnets", [])):
        if i == 0:
            port_group = port_group_name
        else:
            # If there is more then one subnet the others are alias, use the
            # same naming convention MAAS does.
            port_group = "%s:%s" % (port_group_name, i)
            print("Creating alias %s" % port_group)
        check_call(
            [
                "esxcli",
                "network",
                "vswitch",
                "standard",
                "portgroup",
                "add",
                "-p",
                port_group,
                "-v",
                vswitch,
            ]
        )
        if vlan_id is not None:
            check_call(
                [
                    "esxcli",
                    "network",
                    "vswitch",
                    "standard",
                    "portgroup",
                    "set",
                    "-p",
                    port_group,
                    "-v",
                    str(vlan_id),
                ]
            )

        vmk = add_vmk(port_group, mtu, mac_address, counters)
        cmd = ["esxcli", "network", "ip", "interface"]
        stype = subnet.get("type")
        if stype == "static":
            ip_net = ipaddr.IPNetwork(subnet["address"])
            if ip_net.version == 4:
                print(
                    "Configuring static IPv4 address(%s)" % subnet["address"]
                )
                cmd += [
                    "ipv4",
                    "set",
                    "-i",
                    vmk,
                    "-t",
                    "static",
                    "-I",
                    str(ip_net.ip),
                    "-N",
                    str(ip_net.netmask),
                ]
                gateway = subnet.get("gateway")
                check_call(cmd)
                if gateway:
                    check_call(
                        [
                            "esxcli",
                            "network",
                            "ip",
                            "route",
                            "ipv4",
                            "add",
                            "--gateway=%s" % gateway,
                            "--network=default",
                        ]
                    )
            elif ip_net.version == 6:
                print(
                    "Configuring static IPv6 address(%s)" % subnet["address"]
                )
                check_call(
                    cmd
                    + [
                        "ipv6",
                        "address",
                        "add",
                        "-i",
                        vmk,
                        "-I",
                        str(ip_net.ip),
                    ]
                )
                gateway = subnet.get("gateway")
                if gateway:
                    check_call(cmd + ["ipv6", "set", "-i", vmk, "-g", gateway])
                    check_call(
                        [
                            "esxcli",
                            "network",
                            "ip",
                            "route",
                            "ipv6",
                            "add",
                            "--gateway=%s" % gateway,
                            "--network=default",
                        ]
                    )
            else:
                warn("Unknown IP version IPv%s, skipping" % ip_net.version)
        elif subnet["type"] == "dhcp4":
            print("Configuring DHCPv4")
            cmd += ["ipv4", "set", "-i", vmk, "-t", "dhcp"]
            gateway = subnet.get("gateway")
            if gateway:
                cmd += ["-g", gateway]
            check_call(cmd)
        elif subnet["type"] == "dhcp6":
            print("Configuring DHCPv6")
            # Should IPv6 router advertisement be enabled?(-r true)
            cmd += ["ipv6", "set", "-i", vmk, "-e", "true", "-d", "true"]
            gateway = subnet.get("gateway")
            if gateway:
                cmd += ["-g", gateway]
            check_call(cmd)
        else:
            warn("Unknown subnet type %s, skipping" % subnet["type"])
            continue
        _apply_static_routes(subnet.get("routes", []))


def _apply_physical(config, subnets, mappings, counters):
    """Configure a physical device."""
    # Don't do anything if no subnets are configured. This device is either
    # unconfigured or used for a bond/VLAN.
    if subnets == []:
        return

    print("Configuring physical(%s)..." % config.get("name"))

    vswitch = create_vswitch(config.get("mtu"), counters)
    mappings["vswitches"][config["id"]] = vswitch
    add_nic_to_vswitch(config.get("id"), vswitch, mappings)

    _apply_subnets(config, vswitch, counters)


def _apply_bond(config, subnets, mappings, counters):
    """Configure a bond using a vSwitch with NIC teaming.

    All ``bond_interfaces`` are resolved to vmnics and added as uplinks of a
    new vSwitch, the failover policy is set according to
    ``params["bond-mode"]`` (``balance-rr``, ``active-backup`` or
    ``802.3ad``) and the bond's subnets are applied.

    Nothing is created when any uplink cannot be mapped to a vmnic. An
    unsupported bond mode is only detected after the vSwitch and its uplinks
    were created; in that case the policy and subnets are skipped.

    ``subnets`` is unused; it is accepted so that all ``_apply_*`` handlers
    share one signature.
    """
    print("Configuring bond(%s)..." % config.get("name"))

    bond_interfaces = config.get("bond_interfaces", [])
    mac_address = config.get("mac_address")
    # When using explicit NIC teaming ESXi uses the order of the given
    # uplinks to determine which is the active and which are the backup
    # NICs. Netplan V1 specifies which NIC is active by MAC address.
    #
    # The uplinks are collected in a list (not a comma separated string) so
    # that duplicates are detected by exact name - "vmnic1" must not be
    # treated as already present because of "vmnic10" - and so that an
    # unmappable NIC (None) reaches the warning below.
    uplinks = []
    if mac_address is not None:
        uplinks.append(mappings["vmnics"].get(mac_address))
    for nic in bond_interfaces:
        vmnic = nic_to_vmnic(nic, mappings)
        if vmnic not in uplinks:
            uplinks.append(vmnic)
    if not uplinks or None in uplinks:
        warn("Unable to map all uplinks!")
        return

    vswitch = create_vswitch(config.get("mtu"), counters)
    mappings["vswitches"][config["id"]] = vswitch
    for bond_interface in bond_interfaces:
        add_nic_to_vswitch(bond_interface, vswitch, mappings)

    cmd = [
        "esxcli",
        "network",
        "vswitch",
        "standard",
        "policy",
        "failover",
        "set",
        "-v",
        vswitch,
        "-a",
        ",".join(uplinks),
    ]
    bond_mode = config.get("params", {}).get("bond-mode")
    if bond_mode == "balance-rr":
        cmd += ["-l", "portid"]
    elif bond_mode == "active-backup":
        cmd += ["-l", "explicit"]
    elif bond_mode == "802.3ad":
        warn("LACP rate and XMIT hash policy ignored.")
        cmd += ["-l", "iphash"]
    else:
        warn("Unsupported bond mode(%s)" % bond_mode)
        return
    check_call(cmd)

    _apply_subnets(config, vswitch, counters)


def _apply_vlan(config, subnets, mappings, counters):
    """Set up a VLAN on top of the device named by ``vlan_link``.

    Reuses the vSwitch already recorded for the parent device in
    ``mappings["vswitches"]``; otherwise a new vSwitch is created, recorded
    under the parent's id and given the parent nic as uplink.
    """
    print("Configuring VLAN(%s)..." % config.get("name"))

    parent = config.get("vlan_link")
    known_switches = mappings["vswitches"]
    switch_name = known_switches.get(parent)
    if switch_name is None:
        # No earlier device registered a vSwitch for the parent yet.
        switch_name = create_vswitch(config.get("mtu"), counters)
        known_switches[parent] = switch_name
        add_nic_to_vswitch(parent, switch_name, mappings)

    _apply_subnets(config, switch_name, counters)


def _apply_nameserver(config, subnets, mappings, counters):
    """Configure DNS."""
    print("Configuring nameservers...")

    for address in config.get("address", []):
        check_call(
            ["esxcli", "network", "ip", "dns", "server", "add", "-s", address]
        )

    for search in config.get("search", []):
        check_call(
            ["esxcli", "network", "ip", "dns", "search", "add", "-d", search]
        )


def apply(configs):
    """Configure the running system from a list of Netplan V1 entries.

    Entries are processed in the order given. Each one is dispatched on its
    ``type`` to the matching ``_apply_*`` handler together with its
    non-manual subnets; unknown types only produce a warning.
    """
    print("Applying network configuration...")
    mappings = dict(
        vmnics=get_vmnics(),
        ids=map_id_to_mac(configs),
        vswitches={},
    )
    counters = dict.fromkeys(("vswitch", "vmk"), 0)
    handlers = {
        "physical": _apply_physical,
        "bond": _apply_bond,
        "vlan": _apply_vlan,
        "nameserver": _apply_nameserver,
    }

    for entry in configs:
        # Subnets without a type count as "manual" and are filtered out.
        active_subnets = []
        for subnet in entry.get("subnets", []):
            if subnet.get("type", "manual") != "manual":
                active_subnets.append(subnet)
        kind = entry.get("type")
        handler = handlers.get(kind)
        if not handler:
            warn("Unknown configuration type %s" % kind)
            continue
        handler(entry, active_subnets, mappings, counters)

    print("Done applying network configuration!")


def main():
    """Command line entry point.

    Parses the arguments, loads the YAML file given with ``-c/--config`` and
    runs the requested sub-command (``apply``, ``wipe`` or ``help``).

    Exits with ``os.EX_USAGE`` after printing the help (explicit ``help``
    command or no command at all) and with ``os.EX_CONFIG`` when the file
    does not hold a usable Netplan V1 configuration.
    """
    parser = ArgumentParser(
        description=(
            "Apply a Netplan V1 config from MAAS on the running VMware ESXi "
            "6+ system"
        )
    )
    parser.add_argument(
        "-c", "--config", help="Path to the netplan YAML file", required=True
    )
    subparsers = parser.add_subparsers(
        title="Available commands", dest="subcommand"
    )
    subparsers.add_parser(
        "help",
        description="Show this help message",
        help="Show this help message",
    )
    subparsers.add_parser(
        "apply",
        description="Apply the specified netplan config to running system",
        help="Apply specified netplan config to running system",
    )
    subparsers.add_parser(
        "wipe",
        description="Wipe the current network configuration on the system.",
        help="Wipe the current network configuration on the system.",
    )

    args = parser.parse_args()

    # Neither of these paths needs the config file, so deal with them before
    # reading it. help() does not return.
    if args.subcommand == "help":
        help(parser)
    elif args.subcommand not in ("apply", "wipe"):
        warn("You must specify a command")
        help(parser)

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    # Allows either a full Curtin config or just the network section to be
    # passed.
    if isinstance(config, dict) and "network" in config:
        config = config["network"]

    # An empty file loads as None and other documents may be lists/scalars;
    # report those as a config error rather than crashing on .get().
    if not isinstance(config, dict):
        print(
            "ERROR: %s does not contain a network config!" % args.config,
            file=sys.stderr,
        )
        sys.exit(os.EX_CONFIG)

    if config.get("version") != 1:
        print("ERROR: Only V1 config is supported!", file=sys.stderr)
        sys.exit(os.EX_CONFIG)

    if args.subcommand == "apply":
        # Check the entries before wiping so that a broken file cannot
        # leave the host without any network configuration.
        entries = config.get("config")
        if not isinstance(entries, list):
            print(
                'ERROR: Missing or invalid "config" section!', file=sys.stderr
            )
            sys.exit(os.EX_CONFIG)
        # Start with a clean config
        wipe()
        apply(entries)
    else:
        wipe()


# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
