#!/usr/bin/env python3
#
# maas-md-get - Retrieve data from the MAAS metadata-service using OAUTH creds
#
# Author: Lee Trager <lee.trager@canonical.com>
#
# Copyright (C) 2019 Canonical
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import http.client
import socket
import sys
import time
import urllib.error
import urllib.parse
import urllib.request
from argparse import ArgumentParser
from email.utils import parsedate

import oauthlib.oauth1 as oauth
import yaml


def authenticate_headers(url, headers, creds, clockskew=0):
    """Sign a request for ``url`` and merge the result into ``headers``.

    ``headers`` is modified in place; nothing is returned.  When ``creds``
    has no ``consumer_key`` (missing or ``None``) the headers are left
    untouched.

    :param url: The URL that will be requested; passed to the OAuth signer.
    :param headers: Dict of request headers to update with signed headers.
    :param creds: Dict holding ``consumer_key``, ``token_key``,
        ``token_secret`` and optionally ``consumer_secret``.
    :param clockskew: Seconds added to the local time when building the
        OAuth timestamp.
    """
    consumer_key = creds.get("consumer_key")
    if consumer_key is None:
        # No OAuth credentials supplied: nothing to sign.
        return

    skewed_now = int(time.time()) + clockskew
    signer = oauth.Client(
        client_key=consumer_key,
        client_secret=creds.get("consumer_secret", ""),
        resource_owner_key=creds["token_key"],
        resource_owner_secret=creds["token_secret"],
        signature_method=oauth.SIGNATURE_PLAINTEXT,
        timestamp=str(skewed_now),
    )
    # sign() yields (uri, headers, body); only the headers are needed here.
    _uri, oauth_headers, _body = signer.sign(url)
    headers.update(oauth_headers)


def geturl(url, creds, headers=None, data=None):
    """Fetch ``url`` with OAuth-signed headers, retrying on failure.

    The request is attempted up to seven times, sleeping 1, 1, 2, 4, 8, 16
    and 32 seconds after each failed attempt.  On a 401/403 response that
    carries a ``Date`` header, the difference between the server's clock and
    the local clock is computed and applied to the OAuth timestamp of the
    next attempt.

    :param url: The URL to request.
    :param creds: Dict of credentials passed through to
        ``authenticate_headers`` (``consumer_key``, ``token_key``, ...).
    :param headers: Optional dict of extra request headers; it is copied,
        so the caller's dict is never modified.
    :param data: Optional request body passed to ``urllib.request.Request``.
    :return: The decoded response body, or ``""`` when the server answers
        404 (nothing to download).
    :raises Exception: The error from the last attempt once every attempt
        has failed.
    """
    # Local import so this function brings its own dependency into scope.
    import calendar

    headers = {} if headers is None else dict(headers)

    clockskew = 0

    error = Exception("Unexpected Error")
    for naptime in (1, 1, 2, 4, 8, 16, 32):
        # Re-sign on every attempt: the timestamp (and skew) may have changed.
        authenticate_headers(url, headers, creds, clockskew)
        try:
            req = urllib.request.Request(url=url, data=data, headers=headers)
            # Context manager ensures the response is closed after reading.
            with urllib.request.urlopen(req) as ret:
                return ret.read().decode()
        except urllib.error.HTTPError as exc:
            error = exc
            if exc.code == http.client.NOT_FOUND:
                # Nothing to download.  Checked first so the result does not
                # depend on whether the response carries a Date header.
                return ""
            if "date" not in exc.headers:
                print("date field not in %d headers" % exc.code)
            elif exc.code in (http.client.UNAUTHORIZED, http.client.FORBIDDEN):
                date = exc.headers["date"]
                try:
                    # HTTP dates are expressed in GMT, so convert with
                    # timegm(); mktime() would interpret them as local time
                    # and skew the result by the host's UTC offset.
                    ret_time = calendar.timegm(parsedate(date))
                except (TypeError, ValueError, OverflowError):
                    # parsedate() returns None for an unparseable date;
                    # keep the previous skew and carry on retrying.
                    print("failed to convert date '%s'" % date)
                else:
                    clockskew = int(ret_time - time.time())
                    print("updated clock skew to %d" % clockskew)
        except Exception as exc:
            error = exc

        print("request to %s failed. sleeping %d.: %s" % (url, naptime, error))
        time.sleep(naptime)

    raise error


def main():
    """Command-line entry point: fetch one metadata entry and print it.

    Loads the YAML credentials file given by ``-c/--config``.  If the file
    has a top-level ``reporting`` key, the settings under
    ``reporting -> maas`` are used instead.  The request URL is the
    configured ``endpoint`` cut off at its first ``status/`` component,
    followed by the ``entry`` argument.

    Exits with status 1 after printing a message when the request fails
    with an HTTP error, a URL error or a socket timeout.
    """
    parser = ArgumentParser(
        description="Get data from the MAAS metadata server"
    )
    parser.add_argument(
        "-c",
        "--config",
        help="Path to the MAAS metadata credentials file",
        required=True,
    )
    parser.add_argument("entry", help="The metadata path to get")

    args = parser.parse_args()

    with open(args.config, "r") as f:
        cfg = yaml.safe_load(f)

    if "reporting" in cfg:
        cfg = cfg["reporting"]["maas"]

    # partition() keeps everything before the first "status/" and, unlike
    # slicing with str.find(), leaves the endpoint intact when the marker is
    # absent (find() returns -1, which would chop off the last character).
    base_url = cfg["endpoint"].partition("status/")[0]
    url = "%s%s" % (base_url, args.entry)

    try:
        print(geturl(url, creds=cfg))
    except urllib.error.HTTPError as exc:
        # HTTPError is a subclass of URLError, so it must be handled first.
        print("HTTP error [%s]" % exc.code)
        sys.exit(1)
    except urllib.error.URLError as exc:
        print("URL error [%s]" % exc.reason)
        sys.exit(1)
    except socket.timeout as exc:
        print("Socket timeout [%s]" % exc)
        sys.exit(1)


# Run the command-line entry point only when executed as a script, not when
# this file is imported as a module.
if __name__ == "__main__":
    main()
