#!/usr/bin/env python3
#
# storage-esxi - Apply storage configuration from MAAS on VMware ESXi
#
# Author: Lee Trager <lee.trager@canonical.com>
#
# Copyright (C) 2019-2024 Canonical
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import re
import sys
import json
from argparse import ArgumentParser
from functools import lru_cache
from subprocess import PIPE, Popen, check_call, check_output

import yaml


def info(msg):
    """Write *msg* to standard output behind an ``INFO:`` prefix."""
    text = "INFO: %s" % msg
    print(text)


def warn(msg):
    """Write *msg* to standard error behind a ``WARNING:`` prefix."""
    text = "WARNING: %s" % msg
    print(text, file=sys.stderr)


def error(msg):
    """Write *msg* to standard error behind an ``ERROR:`` prefix."""
    text = "ERROR: %s" % msg
    print(text, file=sys.stderr)


@lru_cache(maxsize=1)
def get_esxi_version():
    """Return the output of ``uname -r`` as a tuple of version components.

    The release string is stripped of surrounding whitespace and split on
    ".". Each component that parses as an integer is stored as an int;
    anything else is kept as a string so unexpected suffixes do not raise.
    The result is cached, so ``uname`` runs at most once per process.
    """
    # Strip first so a non-numeric final component does not carry the
    # trailing newline printed by uname.
    release = check_output(["uname", "-r"]).decode().strip()
    components = []
    for field in release.split("."):
        # The version is currently a tuple of ints. Allow strings
        # in case VMware decides to add them.
        try:
            components.append(int(field))
        except ValueError:
            components.append(field)
    return tuple(components)


def is_esxi67():
    """Tell whether the detected version starts with major 6, minor 7."""
    major_minor = get_esxi_version()[:2]
    return major_minor == (6, 7)


def has_esx_os_data():
    """Tell whether the detected major version is 7 or newer."""
    major = get_esxi_version()[0]
    return major >= 7

def get_nvme_serial_numbers():
    """Return a list describing the NVMe devices reported by esxcli.

    Every line of ``esxcli nvme device list`` that starts with "vmh" is
    treated as a device row. For each row ``esxcli nvme device get -A`` is
    queried for its "Serial Number:" field.

    Returns a list of dicts with the keys:
      adapter -- first column of the device row (the adapter name)
      devname -- "nvme<X>n1", where <X> is the last character of the
                 third column of the device row
      serial  -- text following "Serial Number:" in the device details

    Rows with fewer than three columns, and devices for which no serial
    number is reported, are skipped with a warning so that every returned
    entry is guaranteed to carry a non-empty "serial".
    """
    nvmes = []

    listing = check_output(["esxcli", "nvme", "device", "list"]).decode()
    for deviceline in listing.splitlines():
        if not deviceline.startswith("vmh"):
            continue
        fields = deviceline.split()
        if len(fields) < 3:
            warn("Unable to parse NVMe device line: %s" % deviceline)
            continue
        nvmehba, signature = fields[0], fields[2]

        details = check_output(
            ["esxcli", "nvme", "device", "get", "-A", nvmehba]
        ).decode()
        serial = ""
        for hbaline in details.splitlines():
            hbaline = hbaline.strip()
            if hbaline.startswith("Serial Number:"):
                # Split once only so a serial containing ":" is kept whole.
                serial = hbaline.split(":", 1)[1].strip()

        if not serial:
            warn("No serial number found for NVMe adapter %s." % nvmehba)
            continue

        nvmes.append(
            {
                # Attach the adapter name to each nvme drive
                "adapter": nvmehba,
                # Convert nvme names to Linux equivalents
                "devname": f"nvme{signature[-1]}n1",
                "serial": serial,
            }
        )
    return nvmes

def extract_serial_from_other_names(line):
    """Extract a serial number from a ``t10.NVMe____`` device name.

    Two patterns are tried in order:
      1. the run of ``[A-Z0-9]`` characters following the first group of
         two or more underscores after the ``t10.NVMe____`` prefix;
      2. as a fallback for names that use a single underscore separator
         (e.g. Dell drives), the first ``[A-Z0-9]`` run that follows an
         underscore and is itself followed by underscores or the end of
         the line.

    Returns the serial as a string, or "" when *line* does not start with
    "t10." or neither pattern matches. The result is always a ``str`` so
    callers can safely use it in truthiness and substring tests.
    """
    if not line.startswith("t10."):
        return ""

    patterns = (
        r"t10\.NVMe____.*?_{2,}(?P<serial>[A-Z0-9]+)",
        # Fallback for lines with a single underscore separator
        r"t10\.NVMe____.*?_(?P<serial>[A-Z0-9]+)(?:_+|$)",
    )
    for pattern in patterns:
        match = re.search(pattern, line)
        if match:
            return match.group("serial")

    return ""

def _get_disk_block_sizes():
    """Map each device name to the integer esxcli lists beside it.

    Runs ``esxcli storage core device capacity list`` and, skipping the two
    header lines, stores the integer found in the second column under the
    device name found in the first column. Rows with fewer than two
    columns (such as blank lines) are ignored.
    """
    sizes = {}
    output = check_output(
        ["esxcli", "storage", "core", "device", "capacity", "list"]
    ).decode()
    for row in output.splitlines()[2:]:
        fields = row.split()
        if len(fields) < 2:
            continue
        sizes[fields[0]] = int(fields[1])
    return sizes


def _parse_scsi_devices(output, disk_block_sizes):
    """Turn the text printed by ``esxcfg-scsidevs -l`` into disk dicts."""
    # Regex to read name, model, and serial numbers
    serial_regex = re.compile(
        r"^(?P<name>t10\.(?:ATA|NVMe)____)"    # The device names (either ATA or NVMe)
        r"(?P<model>[A-Z0-9_]+?)__+"           # The model names until multiple underscores
        r"(?P<serial>[A-Z0-9]+)_*$"            # The serial numbers after multiple underscores
    )

    vendor_model_regex = re.compile(
        r"^Vendor:\s+(?P<vendor>.*)\s*"
        r"Model:\s+(?P<model>.*)\s*"
        r"Revis:\s+(?P<revis>.*)\s*$"
    )

    disks = []
    disk = None
    in_other_names = False

    for line in output.splitlines():
        if not line.strip():
            # A blank line is neither a device name nor a field.
            continue
        if not line.startswith(" " * 3):
            # Each section starts with the device name,
            # all fields are defined below and start with 3 spaces.
            if disk:
                disks.append(disk)
            in_other_names = False
            name = line.strip()
            m = serial_regex.search(name)
            serial = (
                m.group("serial") if m
                else extract_serial_from_other_names(name)
            )
            disk = {
                "name": name,
                "other_names": [],
                "blocksize": disk_block_sizes.get(name, 0),
                # Always a string so later substring tests cannot raise.
                "serial": serial or "",
                "model": m.group("model") if m else "",
                "path": "",
                "adapter": "",
            }
        elif disk:
            # Other names is a list of alias.
            # Entries must start with 6 spaces.
            if in_other_names and not line.startswith(" " * 6):
                in_other_names = False
            line = line.strip()
            if line.startswith("Size"):
                _, size = line.split(":", 1)
                disk["size"] = size.strip()
            elif line.startswith("Devfs Path"):
                _, path = line.split(":", 1)
                disk["path"] = path.strip()
            elif line.startswith("Vendor"):
                m = vendor_model_regex.search(line)
                if m:
                    disk["model"] = m.group("model").strip()
            elif line.startswith("Other Names"):
                in_other_names = True
            elif in_other_names:
                disk["other_names"].append(line)
                on_serial = extract_serial_from_other_names(line)
                if on_serial and not disk["serial"]:
                    disk["serial"] = on_serial

    if disk:
        disks.append(disk)
    return disks


def _assign_adapters(disks, nvmes, output):
    """Fill in each disk's "adapter" from ``esxcli storage core path list``.

    Disks are matched to path entries by serial number. Disks without a
    serial are never matched: an empty string is a substring of every
    line and would otherwise match the first line examined.
    """
    current_adapter = None
    # NOTE(review): matching_serial is never cleared between path entries,
    # so a value found for one entry stays in effect until another
    # "Device Display Name:" line matches -- confirm this is intended.
    matching_serial = ""

    for line in output.splitlines():
        line = line.strip()

        if line.startswith("Adapter:"):
            _, adapter = line.split(":", 1)
            current_adapter = adapter.strip()

        elif line.startswith("Device Display Name:"):
            for candidate in disks:
                if candidate["serial"] and candidate["serial"] in line:
                    matching_serial = candidate["serial"]
                    break

        elif line.startswith("Controller:"):
            for candidate in disks:
                serial = candidate["serial"]
                if not serial:
                    continue
                if serial in line:
                    candidate["adapter"] = current_adapter
                    break
                if candidate["adapter"] != "":
                    continue
                for nvme in nvmes:
                    nvme_serial = nvme.get("serial")
                    if (
                        nvme_serial
                        and nvme_serial in line
                        and serial == matching_serial
                        and nvme["adapter"] == current_adapter
                    ):
                        candidate["adapter"] = current_adapter
                        # Override Disk's serial number with the one we get from NVME
                        # For NVME disks, serial numbers reported in Linux and ESXi may not match
                        candidate["serial"] = nvme_serial
                        break


def get_disks(nvmes):
    """Parse all disks on the system.

    *nvmes* is the list produced by ``get_nvme_serial_numbers``; it is used
    to attach adapters to, and override the serials of, NVMe disks.

    Returns a list of dicts with the keys "name", "other_names",
    "blocksize", "serial", "model", "path" and "adapter", plus "size" when
    the device listing contains a Size field. "serial" is always a string
    and is "" when no serial could be derived.
    """
    disk_block_sizes = _get_disk_block_sizes()

    # Get the list of devices
    output = check_output(["esxcfg-scsidevs", "-l"]).decode()
    disks = _parse_scsi_devices(output, disk_block_sizes)

    # Map adapter names using serial numbers
    output = check_output(["esxcli", "storage", "core", "path", "list"]).decode()
    _assign_adapters(disks, nvmes, output)

    return disks

def get_disk(disks, nvmes, serial):
    """Find a detected disk by serial number.

    Returns the first entry of *disks* whose "serial" equals *serial*, or
    None when there is no such entry. *nvmes* is accepted but not used.
    """
    return next(
        (candidate for candidate in disks if serial == candidate["serial"]),
        None,
    )


def parse_config(config):
    """Pull the disks out of the config and map them to VMware devices.

    *config* is an iterable of storage entries (dicts). Entries of type
    "disk" are matched to a detected disk by serial number (spaces replaced
    with underscores); matched entries gain "path" and "blocksize" and are
    stored by "id". Disk entries that cannot be matched, including entries
    with no serial, produce a warning and are left out. Entries of type
    "partition" and "vmfs6" are stored by "id" unchanged.

    The raw config, the detected disks, the detected NVMe devices and the
    matched disks are also dumped as JSON under /var/log.

    Returns a ``(disks, partitions, vmfs_datastores)`` tuple of dicts.
    """
    disks = {}
    partitions = {}
    vmfs_datastores = {}

    nvmes = get_nvme_serial_numbers()
    detected_disks = get_disks(nvmes)

    # Log config data
    with open("/var/log/storage-esxi-config.log", "w") as log_file:
        json.dump(config, log_file, indent=2)

    # Log detected disks data
    with open("/var/log/storage-esxi-disks.log", "w") as log_file:
        json.dump(detected_disks, log_file, indent=2)

    with open("/var/log/storage-esxi-nvmes.log", "w") as log_file:
        json.dump(nvmes, log_file, indent=2)

    for entry in config:
        entry_type = entry.get("type")
        if entry_type == "disk":
            # Test the value, not just the key, to agree with the other
            # grub_device checks in this file.
            if entry.get("grub_device"):
                entry["partitioned"] = True
            # Model and serial may be absent from a disk entry.
            serial = (entry.get("serial") or "").replace(" ", "_")

            # Log parsed config data
            with open("/var/log/storage-esxi-parse-config.log", "a") as log_file:
                json.dump(entry, log_file, indent=2)

            # An empty serial must never match a detected disk.
            disk = get_disk(detected_disks, nvmes, serial) if serial else None
            if disk:
                entry["path"] = disk["path"]
                entry["blocksize"] = disk["blocksize"]
                disks[entry["id"]] = entry
            else:
                warn(
                    "Disk %s %s not found!"
                    % (entry.get("model", ""), entry.get("serial", ""))
                )
        elif entry_type == "partition":
            partitions[entry["id"]] = entry
        elif entry_type == "vmfs6":
            vmfs_datastores[entry["id"]] = entry

    # Log parsed disks data
    with open("/var/log/storage-esxi-disks-parsed.log", "w") as log_file:
        json.dump(disks, log_file, indent=2)

    return disks, partitions, vmfs_datastores


def process_disk_wipes(disks):
    """Process wiping the disks.

    For the disk marked "grub_device" the GPT is fixed up and, when the
    default datastore partition exists (partition 3 on ESXi 6.7, partition
    8 on ESXi 7+), all filesystems are unmounted, that partition is
    deleted and filesystems are rescanned. That disk is never wiped.

    Every other disk with a "wipe" value of "superblock", "zero" or
    "random" is overwritten with dd. Any other wipe value produces a
    warning and the disk is left untouched. ``sync`` is run at the end.

    Raises subprocess.CalledProcessError when a checked command fails.
    """
    # dd arguments that select the data source for each wipe algorithm.
    dd_sources = {
        "superblock": ["if=/dev/zero", "count=1"],
        # NOTE(review): without count= dd runs until the device is full;
        # confirm ESXi's dd exits 0 then, otherwise check_call raises.
        "zero": ["if=/dev/zero"],
        "random": ["if=/dev/urandom"],
    }

    for disk in disks.values():
        if disk.get("grub_device"):
            # The grub_device is the disk ESXi was installed to and is
            # currently running on. Fix the partition table to use the
            # full size of the disk. The exit status is not checked.
            p = Popen(["partedUtil", "fixGpt", disk["path"]], stdin=PIPE)
            p.communicate(input=b"Y\nFix\n")
            # Remove the default datastore partition if it exists. It is the
            # only partition included in the base layout that can be
            # customized.
            if is_esxi67() and os.path.exists("%s:3" % disk["path"]):
                datastore_partition = "3"
            elif has_esx_os_data() and os.path.exists("%s:8" % disk["path"]):
                datastore_partition = "8"
            else:
                datastore_partition = None
            if datastore_partition is not None:
                info("Removing the default datastore.")
                check_call(
                    ["esxcli", "storage", "filesystem", "unmount", "-a"]
                )
                check_call(
                    ["partedUtil", "delete", disk["path"], datastore_partition]
                )
                check_call(["esxcli", "storage", "filesystem", "rescan"])
            continue

        wipe = disk.get("wipe")
        if not wipe:
            continue
        source_args = dd_sources.get(wipe)
        if source_args is None:
            warn(
                "Unsupported wipe algorithm %s for %s, not wiping."
                % (wipe, disk["path"])
            )
            continue

        info("Wiping %s using the %s algorithm." % (disk["path"], wipe))
        cmd = ["dd", "conv=notrunc", "of=%s" % disk["path"]]
        # A block size of 0 means it could not be determined; let dd use
        # its default rather than passing the invalid "bs=0".
        if disk.get("blocksize"):
            cmd.append("bs=%s" % disk["blocksize"])
        check_call(cmd + source_args)
    check_call(["sync"])


def get_starting_sector(path):
    """Work out the sector a new partition on *path* should start at.

    Reads ``partedUtil getptbl`` and returns one more than the largest
    third-column value seen, taken from every six-column row and from a
    four-column row only while nothing larger than 0 has been seen yet.
    """
    table = check_output(["partedUtil", "getptbl", path]).decode()
    highest = 0
    for row in table.split("\n"):
        fields = row.split()
        # The first line is the partition table type
        # The second line is 4 columns and contains the total disk size
        # All other lines give the partition information.
        is_first_geometry_row = len(fields) == 4 and highest == 0
        if is_first_geometry_row or len(fields) == 6:
            highest = max(highest, int(fields[2]))

    return highest + 1


def get_ending_sector(path):
    """Compute the last sector to use for a new partition on *path*.

    Takes the first line printed by ``partedUtil get`` and returns its
    fourth field minus its first field, both read as integers.
    """
    # NOTE(review): presumably the fields are the sector total and the
    # cylinder count, leaving slack at the end of the disk -- confirm.
    first_line = check_output(["partedUtil", "get", path]).decode().splitlines()[0]
    fields = first_line.split()
    return int(fields[3]) - int(fields[0])

def partition_disks(disks, partitions):
    """Create the partitions listed in *partitions* that do not exist yet.

    A partition table is written first on any disk not yet flagged
    "partitioned". Partitions on the grub_device that belong to the
    installer's layout are skipped, and a partition whose device node
    already exists is only logged.
    """
    # See https://kb.vmware.com/s/article/1036609
    for part in partitions.values():
        disk = disks[part["device"]]
        # The grub_device is the disk which Curtin installed the OS to. The
        # official VMware ESXi 6.7 installer defines 8 partitions and skips
        # partition 4. Partition 3 is the datastore which can be extended.
        # It needs to be recreated. On VMware ESXi 7.0 the official installer
        # defines 5 partitions and skips partitions 2-4. Partition 8 is the
        # datastore which can be extended.
        #
        # When vmfs6 storage layout is selected, MAAS nonetheless sets
        # partition 3 for datastores on grub_device.
        #
        # On non-grub-device disks, the datastore remains partition 1.
        if disk.get("grub_device"):
            installer_owned = (
                is_esxi67() and part["number"] != 3 and part["number"] <= 9
            ) or (has_esx_os_data() and part["number"] != 3)
            if installer_owned:
                continue

        if not disk.get("partitioned"):
            info(
                "Creating a %s partition table on %s"
                % (disk["ptable"], disk["path"])
            )
            check_call(["partedUtil", "mklabel", disk["path"], disk["ptable"]])
            disk["partitioned"] = True

        partition_dev = "%s:%s" % (disk["path"], part["number"])
        if os.path.exists(partition_dev):
            info(
                 "Skip creating partition %s on %s"
                 % (part["number"], disk["path"])
            )
            continue

        first_sector = get_starting_sector(disk["path"])
        last_sector = get_ending_sector(disk["path"])

        info(
             "Creating partition %s on %s (Start: %s End: %s)"
             % (part["number"], disk["path"], first_sector, last_sector)
            )

        # partedUtil expects the partition description as one argument
        description = "%s %s %s AA31E02A400F11DB9590000C2911D1B8 0" % (
            part["number"],
            first_sector,
            last_sector,
        )
        check_call(
            ["partedUtil", "add", disk["path"], disk["ptable"], description]
        )


def get_partition_dev(disks, partitions, id):
    """Build the "<disk path>:<partition number>" string for partition *id*.

    The partition is looked up in *partitions*, and its "device" value is
    used to look up the owning disk in *disks*.
    """
    part = partitions[id]
    parent_path = disks[part["device"]]["path"]
    return "%s:%s" % (parent_path, part["number"])


def mkvmfs(disks, partitions, vmfs_datastores):
    """Create the defined VMFS datastores.

    For each datastore the first entry of its "devices" list becomes the
    head partition on which a VMFS6 filesystem is created; every further
    entry is added to it as an extent.

    Raises subprocess.CalledProcessError when creating a datastore fails.
    A failure to add an extent is reported as a warning instead of being
    raised.
    """
    for vmfs_datastore in vmfs_datastores.values():
        head_partition = get_partition_dev(
            disks, partitions, vmfs_datastore["devices"][0]
        )
        # VMFS6 is used on every ESXi version this script handles.
        info(
            "Creating VMFS6 datastore %s using %s as the head partition"
            % (vmfs_datastore["name"], head_partition)
        )
        check_call(
            [
                "vmkfstools",
                "-C",
                "vmfs6",
                "-S",
                vmfs_datastore["name"],
                head_partition,
            ]
        )
        for extent in vmfs_datastore["devices"][1:]:
            extent_dev = get_partition_dev(disks, partitions, extent)
            info(
                "Adding %s as an extent to VMFS6 datastore %s"
                % (extent_dev, vmfs_datastore["name"])
            )
            # vmkfstools prompts on stdin; "0" is sent as the answer.
            p = Popen(
                ["vmkfstools", "-Z", extent_dev, head_partition], stdin=PIPE
            )
            p.communicate(input=b"0\n")
            # communicate() waits for exit, so returncode is set here.
            if p.returncode != 0:
                warn(
                    "Adding %s as an extent to VMFS6 datastore %s failed "
                    "(vmkfstools exited with status %s)."
                    % (extent_dev, vmfs_datastore["name"], p.returncode)
                )


def extend_default(disks):
    """Extend the default datastore if no VMFS config is given.

    When ``esxcli storage vmfs extent list`` reports a volume named
    "datastore1", its partition is resized to the last usable sector of
    its device and the filesystem is grown. Otherwise the first disk in
    *disks* flagged "grub_device" is used: a new partition (3 on ESXi 6.7,
    8 otherwise) is added after the last listed partition and a VMFS6
    datastore named "datastore1" is created on it.

    Exits with an error message when no device can be determined or the
    datastore partition cannot be found in the partition table. Raises
    subprocess.CalledProcessError when a checked command fails.
    """
    dev_path = None
    part_num = 0
    grub_disk = None
    extend_vmfs = True

    volumes = check_output(["esxcli", "storage", "vmfs", "extent", "list"])
    for volume in volumes.decode().splitlines():
        fields = volume.split()
        # Blank or truncated rows cannot describe an extent.
        if len(fields) >= 5 and fields[0] == "datastore1":
            dev_path = "/vmfs/devices/disks/%s" % fields[3]
            part_num = fields[4]
            break

    if not dev_path:
        # For whatever reason VMware ESXi will remove defined datastores on
        # deployment on some hardware. Assume that is what is happening.
        for disk in disks.values():
            if disk.get("grub_device", False):
                grub_disk = disk
                dev_path = disk["path"]
                part_num = "3" if is_esxi67() else "8"
                extend_vmfs = False
                break

    if not dev_path:
        error("Unable to find datastore1 or a grub_device disk to create it on!")
        sys.exit(os.EX_CONFIG)

    # The exit status of fixGpt is not checked.
    p = Popen(["partedUtil", "fixGpt", dev_path], stdin=PIPE)
    p.communicate(input=b"Y\nFix\n")

    # Get the sector the partition currently starts on. When adding a new
    # partition instead, remember the third column of the last row read.
    part_start = None
    last_end = 0
    part_info = check_output(["partedUtil", "get", dev_path])
    for part in part_info.decode().splitlines():
        if extend_vmfs and part.startswith("%s " % part_num):
            part_start = part.split()[1]
            break
        fields = part.split()
        if len(fields) >= 3:
            last_end = fields[2]

    # Get the last sector of the disk to extend the datastore to.
    part_info = check_output(["partedUtil", "getUsableSectors", dev_path])
    part_end = part_info.decode().split()[1]

    vmfs_part = "%s:%s" % (dev_path, part_num)

    if extend_vmfs:
        if part_start is None:
            error("Partition %s not found on %s!" % (part_num, dev_path))
            sys.exit(os.EX_OSERR)
        check_call(
            ["partedUtil", "resize", dev_path, part_num, part_start, part_end]
        )
        check_call(["vmkfstools", "--growfs", vmfs_part, vmfs_part])
    else:
        part_start = str(int(last_end) + 1)
        check_call(
            [
                "partedUtil",
                "add",
                grub_disk["path"],
                grub_disk["ptable"],
                # partedUtil expects this as one argument
                "%s %s %s AA31E02A400F11DB9590000C2911D1B8 0"
                % (part_num, part_start, part_end),
            ]
        )
        check_call(
            ["vmkfstools", "-C", "vmfs6", "-S", "datastore1", vmfs_part]
        )


def main():
    """Parse the command line and apply the storage configuration.

    Loads the YAML file given with -c/--config, accepting either a full
    Curtin config (with a "storage" key) or just the storage section. When
    no vmfs6 datastore is defined the default datastore is extended;
    otherwise disks are wiped, partitioned and the datastores created.

    Exits with os.EX_CONFIG when the document is not a mapping, is not a
    version 1 config, or does not contain a "config" list.
    """
    parser = ArgumentParser(
        description=(
            "Apply the MAAS storage configuration to the running "
            "VMware ESXi 6+ system."
        )
    )
    parser.add_argument(
        "-c",
        "--config",
        help="Path to the storage configuration file.",
        required=True,
    )

    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    # Allows either a full Curtin config or just the storage section to be
    # passed.
    if isinstance(config, dict) and "storage" in config:
        config = config.get("storage", [])

    # An empty file loads as None and a bare list has no version; report
    # both as configuration errors instead of failing with a traceback.
    if not isinstance(config, dict):
        error("Storage configuration must be a mapping!")
        sys.exit(os.EX_CONFIG)

    if config.get("version") != 1:
        error("Only V1 config is supported!")
        sys.exit(os.EX_CONFIG)

    entries = config.get("config")
    if not isinstance(entries, list):
        error("Storage configuration has no 'config' list!")
        sys.exit(os.EX_CONFIG)

    disks, partitions, vmfs_datastores = parse_config(entries)

    if not vmfs_datastores:
        # In the lack of storage information, only one disk will be
        # used to create the default datastore1
        warn("No storage information given, extending datastore1.")
        extend_default(disks)
    else:
        process_disk_wipes(disks)
        partition_disks(disks, partitions)
        mkvmfs(disks, partitions, vmfs_datastores)

    info("Done applying storage configuration!")


# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
