import json
import logging

import requests
from google.cloud.tasks_v2 import (
    AppEngineHttpRequest,
    AppEngineRouting,
    CloudTasksClient,
    CreateTaskRequest,
    HttpMethod,
    Task,
)

from ._common import Entity, Key, Metrics, beam, bigquery, get_datastore_client

logger = logging.getLogger(__name__)

ITEMS = "items"


class TaggedOutputDoFn(beam.DoFn):
    """Base `DoFn` that defines the tagged output names shared by the DoFns in
    this module."""

    OUTPUT_TAG_SUCCESS = "success"
    OUTPUT_TAG_UPDATES = "updates"
    OUTPUT_TAG_SKIPPED = "skipped"
    OUTPUT_TAG_ERRORS = "errors"


class Counter(beam.DoFn):
    """A basic `DoFn` to count the number of items in the `PCollection` using
    Metrics.

    https://beam.apache.org/documentation/programming-guide/#metrics
    """

    default_name = ITEMS

    def __init__(self, namespace=None, name=None):
        super().__init__()
        cls = self.__class__
        self.counter = Metrics.counter(
            namespace or cls,
            name or cls.default_name,
        )

    def process(self, element):
        if element:
            self.counter.inc()
        yield element
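
# --- Illustrative sketch (not part of the original module) -------------------
# The DoFns below emit entities to the tagged outputs declared on
# `TaggedOutputDoFn`. A pipeline typically splits those streams with
# `.with_outputs()`, roughly as sketched here; the helper name, step label,
# and the `entities`/`options`/`test_condition` arguments are assumptions for
# illustration only.
def _example_split_tagged_outputs(entities, options, test_condition):
    results = entities | "SkipEntityIf" >> beam.ParDo(
        SkipEntityIf(options, test_condition)
    ).with_outputs(
        TaggedOutputDoFn.OUTPUT_TAG_SKIPPED,
        TaggedOutputDoFn.OUTPUT_TAG_ERRORS,
        main="kept",
    )
    # `results.kept` is the main output; the tagged streams are addressed by name.
    return (
        results.kept,
        results[TaggedOutputDoFn.OUTPUT_TAG_SKIPPED],
        results[TaggedOutputDoFn.OUTPUT_TAG_ERRORS],
    )
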

class DuplicateEntity(TaggedOutputDoFn):
    """A `beam.DoFn` that accepts a Google Datastore entity and duplicates it
    by copying & modifying the key via a provided `key_transformer`. The
    duplicate entity's properties may also be changed if a `props_formatter`
    is provided.

    NOTE: The `key_transformer` must add an id/name because the duplicated
    entity is written via `WriteToDatastore()`, which cannot handle partial
    keys.

    Usage:

    ```py
    # in DataflowJob.expand()
    # ...
    def change_namespace(orig_key):
        return Key(
            orig_key.kind,
            orig_key.id,
            project=orig_key.project,
            namespace="foo-bar",
        )

    def transform_data(entity):
        return {"foo": entity["bar"]}

    # ...
    | UpdateEntities(DuplicateEntity(
        change_namespace, props_formatter=transform_data
    ))
    ```

    Args:
        key_transformer (callable): Callable that accepts the original/source
            key (`google.cloud.datastore.key.Key`) and returns a new key with
            project, namespace, kind, and id/name set accordingly.
        props_formatter (callable, optional): Callable that accepts an
            entity's `properties` and returns a `dict`. Defaults to None,
            which means we use the `properties` attribute as-is.

    Returns:
        beam.pvalue.TaggedOutput(apache_beam.io.gcp.datastore.v1new.types.Entity)
    """

    def __init__(self, key_transformer, props_formatter=None):
        if not key_transformer:
            raise RuntimeError("A key_transformer must be supplied!")

        super().__init__()
        self.key_transformer = key_transformer
        self.props_formatter = props_formatter

    # https://beam.apache.org/releases/pydoc/current/apache_beam.io.gcp.datastore.v1new.types.html#apache_beam.io.gcp.datastore.v1new.types.Entity
    def process(self, src_entity):
        src_ds_key = src_entity.key.to_client_key()
        src_hr_key = to_hr_key(src_ds_key, abs_path=True)
        logger.debug(
            "Processing %s with properties %s",
            src_hr_key,
            sorted(src_entity.properties.keys()),
        )

        source_hr_key = to_hr_key(src_ds_key)

        try:
            dest_ds_key = self.key_transformer(src_ds_key)
        except Exception as err:
            logger.error(
                "Failed transforming key for %s (%s)",
                source_hr_key,
                format_error(err),
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_ERRORS, src_entity)]

        dest_key = Key.from_client_key(dest_ds_key)
        dest_entity = Entity(key=dest_key)

        try:
            dest_entity_dict = entity_to_dict(src_entity, self.props_formatter)
        except Exception as err:
            logger.error(
                "Failed formatting properties for %s (%s)",
                source_hr_key,
                format_error(err),
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_ERRORS, src_entity)]

        dest_props_dict = dest_entity_dict["properties"]
        logger.debug("Destination props/values %s", dest_props_dict)
        dest_entity.set_properties(dest_props_dict)

        abs_path = (src_ds_key.namespace != dest_ds_key.namespace) or (
            src_ds_key.project != dest_ds_key.project
        )
        logger.info(
            "Successfully duplicated entity %s to %s",
            to_hr_key(src_ds_key, abs_path),
            to_hr_key(dest_ds_key, abs_path),
        )
        return [dest_entity]


class SkipEntityIf(TaggedOutputDoFn):
    """A `beam.DoFn` that accepts a Google Datastore entity and skips further
    processing (sent to the 'skipped' tagged output stream) if the provided
    `test_condition` (callable) returns `True`. If an error occurs in
    `test_condition`, the entity is sent to the 'errors' tagged output stream.

    Usage:

    ```py
    # in DataflowJob.expand()
    # ...
    def should_skip(entity):
        return bool(entity["foo"] == "bar")

    # ...
    | beam.ParDo(SkipEntityIf(options, should_skip))
    ```

    Args:
        options (PipelineOptions): This pipeline's options.
        test_condition (callable): Callable that accepts a
            `google.cloud.datastore.entity.Entity` and returns `True` if the
            entity should be skipped.

    Returns:
        beam.pvalue.TaggedOutput(apache_beam.io.gcp.datastore.v1new.types.Entity)
    """

    def __init__(self, options, test_condition):
        super().__init__()
        self.options = options
        self.test_condition = test_condition

    def process(self, entity, **kwargs):
        entity_ds_key = entity.key.to_client_key()
        client_entity = entity.to_client_entity()
        entity_hr_key = to_hr_key(entity_ds_key)

        try:
            skip = self.test_condition(client_entity)
        except Exception as err:
            logger.error(
                "Failed executing %s for %s (%s); Discarding",
                self.test_condition.__name__,
                entity_hr_key,
                format_error(err),
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_ERRORS, entity)]

        if skip:
            logger.info(
                "Skipping %s per %s", entity_hr_key, self.test_condition.__name__
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_SKIPPED, entity)]

        return [entity]

class CallApiWithEntity(TaggedOutputDoFn):
    """A `beam.DoFn` that uses `requests` to make an API call for the passed
    entity. The API call is a POST, but it can be customized by overriding the
    `call_api()` method.

    Exceptions are handled by yielding to two additional tagged outputs
    (skipped and errors) for compatibility with the `ExecuteFnForEntities`
    transform.

    Usage:

    ```py
    # in DataflowJob.expand()
    # ...
    def compute_url(_entity):
        return "https://baz/api/migrate-foo"

    def prepare_data(entity):
        return {
            "foo": entity["foo"],
        }

    # ...
    | beam.ParDo(CallApiWithEntity(options, compute_url, prepare_data))
    ```

    Args:
        options (PipelineOptions): This pipeline's options.
        compute_url (callable): Callable that accepts a
            `google.cloud.datastore.entity.Entity` and returns a URL as a `str`.
        prepare_data (callable): Callable that accepts a
            `google.cloud.datastore.entity.Entity` and returns a suitable
            value for the data used in the API call.
        headers (dict, optional): HTTP headers to send with the API call.
            Defaults to None.
        dry_run (bool, optional): Skips the API call (and pretends it
            succeeded) if True. Defaults to False.

    Returns:
        beam.pvalue.TaggedOutput(apache_beam.io.gcp.datastore.v1new.types.Entity)
    """

    def __init__(self, options, compute_url, prepare_data, headers=None):
        super().__init__()
        self.options = options
        self.compute_url = compute_url
        self.prepare_data = prepare_data
        self.headers = headers or {}

    def process(self, entity, dry_run=False, **kwargs):
        entity_ds_key = entity.key.to_client_key()
        client_entity = entity.to_client_entity()
        entity_hr_key = to_hr_key(entity_ds_key)
        logger.debug(
            "Processing %s with properties %s",
            entity_hr_key,
            sorted(entity.properties.keys()),
        )

        try:
            url = self.compute_url(client_entity)
            logger.debug("Computed URL %s", url)
        except Exception as err:
            logger.error(
                "Failed computing URL for %s (%s); "
                "Skipping API call and further processing for entity.",
                entity_hr_key,
                format_error(err),
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_SKIPPED, entity)]

        try:
            data = self.prepare_data(client_entity)
            logger.debug("Prepared data:\n%s", data)
        except Exception as err:
            logger.error(
                "Failed preparing data for %s (%s); "
                "Skipping API call and further processing for entity.",
                entity_hr_key,
                format_error(err),
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_SKIPPED, entity)]

        if dry_run:
            logger.warning(
                "This is a DRY RUN; "
                "Skipping API call and pretending it succeeded by returning entity."
            )
            return [entity]

        # in case call_api() fails entirely
        resp_data = None

        try:
            resp = self.call_api(url, data)

            try:
                resp_data = resp.json()
            except requests.JSONDecodeError:
                resp_data = resp.text

            logger.debug(
                "HTTP %s for %s API response:\n%s",
                resp.status_code,
                entity_hr_key,
                resp_data,
            )

            if resp.status_code >= 500:
                # raises HTTPError (a RequestException), handled by outer `try`
                resp.raise_for_status()

            if resp.status_code >= 400:
                logger.error(
                    "API call failed for %s due to client error (HTTP %s); "
                    "Skipping entity.\n%s",
                    entity_hr_key,
                    resp.status_code,
                    resp_data,
                )
                return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_SKIPPED, entity)]
        except requests.RequestException as err:
            # base class of ConnectionError, Timeout, TooManyRedirects,
            # HTTPError, etc.
            logger.error(
                "API call failed for %s (%s); Skipping entity.\n%s",
                entity_hr_key,
                format_error(err),
                resp_data,
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_ERRORS, entity)]

        logger.info("Successfully called API for %s", entity_hr_key)
        return [entity]

    def call_api(self, url, data):
        # override with whatever is required for your API call.
        return requests.post(url, json=data, headers=self.headers)
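
# --- Illustrative sketch (assumption, not part of the original module) -------
# The docstring above notes that the HTTP call can be customized by overriding
# `call_api()`. A minimal example of such a subclass is sketched here; the
# class name and the choice of PUT are illustrative only.
class _ExamplePutApiWithEntity(CallApiWithEntity):
    def call_api(self, url, data):
        # Same contract as the base implementation: return a `requests.Response`.
        return requests.put(url, json=data, headers=self.headers)
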

class CreateTaskWithEntity(TaggedOutputDoFn):
    """A `beam.DoFn` that enqueues a Google Cloud Task
    (`google.cloud.tasks_v2.Task`) of type `AppEngineHttpRequest` for the
    passed entity.

    The `gct_project` and `gct_region` options are honored.

    Usage:

    ```py
    # in DataflowJob.expand()
    # ...
    def prepare_data(entity):
        return {"foo": entity["foo"]}

    # ...
    | ExecuteFnForEntities(
        options,
        entity_handler=CreateTaskWithEntity(
            options,
            queue_name="bar-service-migrate-foo",
            service="bar-service",
            version="migrate-foo",
            path="/api/foo",
            prepare_data=prepare_data,
            headers={
                "x-baz": "quux",
            },
        ),
    )
    ```

    Args:
        options (PipelineOptions): This pipeline's options.
        queue_name (str): Name of Google Cloud Task queue onto which `Task`
            will be enqueued.
        service (str): Name of App Engine service that `Task` will use.
        version (str): Version name of App Engine service that `Task` will use.
        path (str): Root-relative path (target) that `Task` will use.
        prepare_data (callable): Callable that accepts a
            `google.cloud.datastore.entity.Entity` and returns a
            JSON-serializable `Task` payload.
        headers (dict, optional): HTTP headers used by `Task` when calling the
            service.
        dry_run (bool, optional): Skips enqueuing the task (and just logs the
            would-be payload) if True. Defaults to False.

    Returns:
        beam.pvalue.TaggedOutput(apache_beam.io.gcp.datastore.v1new.types.Entity)
    """

    def __init__(
        self,
        options,
        queue_name,
        service,
        version,
        path,
        prepare_data,
        headers=None,
        dry_run=False,
    ):
        super().__init__()
        self.options = options
        self.queue_name = queue_name
        self.path = path
        self.service = service
        self.version = version
        self.prepare_data = prepare_data
        self.headers = headers or {}
        self.dry_run = dry_run

    def process(self, entity, *args, **kwargs):
        entity_ds_key = entity.key.to_client_key()
        client_entity = entity.to_client_entity()
        entity_hr_key = to_hr_key(entity_ds_key)

        try:
            payload = self.prepare_data(client_entity)
            logger.debug("Prepared data:\n%s", payload)
        except Exception as err:
            logger.error(
                "Failed preparing data for %s (%s); Entity processing aborted.",
                entity_hr_key,
                format_error(err),
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_SKIPPED, entity)]

        if self.dry_run:
            logger.info(
                "[DRY RUN] Enqueued task to %s in project %s for: %s\n%s",
                self.queue_name,
                self.options.gct_project,
                entity_hr_key,
                to_json(payload, pretty=True),
            )
            return [entity]

        queue_path = CloudTasksClient.queue_path(
            project=self.options.gct_project,
            location=self.options.gct_region,
            queue=self.queue_name,
        )
        logger.debug("Using queue_path: %s", queue_path)

        try:
            req = CreateTaskRequest(parent=queue_path, task=self._create_task(payload))
            task = CloudTasksClient().create_task(req)
        except Exception as err:
            logger.error(
                "Failed adding task for %s (%s); Entity processing aborted.",
                entity_hr_key,
                format_error(err),
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_ERRORS, entity)]

        logger.info(
            "Enqueued task:%s to %s for %s\n%s",
            task.name,
            queue_path,
            entity_hr_key,
            to_json(payload, pretty=True),
        )
        return [entity]

    def _create_task(self, payload):
        return Task(
            app_engine_http_request=AppEngineHttpRequest(
                app_engine_routing=AppEngineRouting(
                    service=self.service,
                    version=self.version,
                ),
                http_method=HttpMethod.POST,
                headers=self.headers,
                relative_uri=self.path,
                body=to_json(payload).encode(),
            ),
        )
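
# --- Illustrative sketch (assumption, not part of the original module) -------
# `CreateTaskWithEntity` reads `gct_project`/`gct_region` off the pipeline
# options, and `GetEntityFromRow` (below) reads `datastore_project`/
# `datastore_namespace`. Options along these lines are assumed to be defined
# elsewhere in the project; the class name and help text here are illustrative
# only.
from apache_beam.options.pipeline_options import PipelineOptions as _PipelineOptions


class _ExampleJobOptions(_PipelineOptions):
    @classmethod
    def _add_argparse_args(cls, parser):
        parser.add_argument("--gct_project", help="Project hosting the Cloud Tasks queue")
        parser.add_argument("--gct_region", help="Region of the Cloud Tasks queue")
        parser.add_argument("--datastore_project", help="Project to read Datastore entities from")
        parser.add_argument("--datastore_namespace", help="Datastore namespace to read from")
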

class GetEntityFromRow(TaggedOutputDoFn):
    """A `beam.DoFn` that constructs a Datastore key from a column/key in the
    provided BQ `row` (or any `dict`-like data), retrieves the `Entity` from
    Datastore, and yields it to the next step.

    The `__key__` column/key is used by default and we expect its value to
    have `kind`, `id`, and `name` keys. If the column/key only contains an
    `id` or `name` key, we can instantiate this function with a `kind` arg.

    If the column/key contains any other kind of value, we must supply our own
    `key_from_row` callable to parse the value and return a
    `google.cloud.datastore.key.Key`. Any errors raised by this callable will
    send the row to the 'errors' tagged output.

    If the entity referenced by the generated key does not exist, the row will
    be sent to the 'skipped' tagged output.

    Usage:

    * `foo` column contains key structure

      ```py
      # ...
      # in dataflow job
      # ...
      GetEntityFromRow(options, key_col="foo")
      ```

    * `bar` column contains ndb id/name of `Bar` entities

      ```py
      def key_from_row(row, ds_client, col, kind):
          return ds_client.key(kind, row[col])

      # ...
      # in dataflow job
      GetEntityFromRow(
          options, key_col="bar", key_from_row=key_from_row, kind="Bar"
      )
      ```

    Args:
        options (PipelineOptions): This pipeline's options.
        key_col (str): Name of column/key containing Datastore key in provided
            `row` (or dict-like structure).
        key_from_row (callable, optional): Callable that accepts `row`,
            `ds_client`, `col`, and `kind` args and returns a key
            (`google.cloud.datastore.key.Key`).
        kind (str, optional): Datastore kind to use when the key column only
            contains an id/name (required in that case). Defaults to None.
        pair_with_row (bool, optional): If True, return the input row as the
            second field of a tuple. Defaults to False.

    Returns:
        beam.pvalue.TaggedOutput(apache_beam.io.gcp.datastore.v1new.types.Entity)
    """

    DEFAULT_KEY_COL = "__key__"

    def __init__(
        self,
        options,
        key_col=None,
        key_from_row=None,
        kind=None,
        pair_with_row=False,
    ):
        super().__init__()
        self.options = options
        self.project = options.datastore_project
        self.namespace = options.datastore_namespace
        self.key_col = key_col or self.DEFAULT_KEY_COL
        self.key_from_row = key_from_row
        self.kind = kind
        self.pair_with_row = pair_with_row

    def process(self, row):
        ds_client = get_datastore_client(project=self.project, namespace=self.namespace)
        key_from_row = self.key_from_row or self._key_from_row

        try:
            entity_ds_key = key_from_row(
                row,
                ds_client,
                col=self.key_col,
                kind=self.kind,
            )
        except Exception as err:
            logger.error(
                "Failed extracting key from column %s (%s)",
                self.key_col,
                format_error(err),
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_ERRORS, row)]

        entity_hr_key = to_hr_key(entity_ds_key)

        ds_entity = ds_client.get(entity_ds_key)
        if not ds_entity:
            logger.warning(
                "Skipping %s; Entity does not exist in Datastore!",
                entity_hr_key,
            )
            return [beam.pvalue.TaggedOutput(self.OUTPUT_TAG_SKIPPED, row)]

        logger.debug(
            "Retrieved entity %s with properties %s",
            entity_hr_key,
            sorted(ds_entity.keys()),
        )

        entity = Entity.from_client_entity(ds_entity)
        if self.pair_with_row:
            return [(entity, row)]

        return [entity]

    @staticmethod
    def _key_from_row(row, ds_client, col, kind):
        key_struct = row[col]

        if isinstance(key_struct, dict):
            # BQ `__key__`-style struct; fall back to the `kind` arg if the
            # struct only carries an id/name.
            key_kind = key_struct.get("kind") or kind
            key_id_or_name = key_struct.get("id") or key_struct.get("name")
        else:
            # simple id/name value; `kind` must be provided
            key_kind = kind
            key_id_or_name = key_struct

        if not key_kind:
            raise ValueError(f"Must specify `kind` if using simple value in {col}")

        logger.debug(
            "Extracted key_kind=%s, key_id_or_name=%s (%s) from col=%s",
            key_kind,
            key_id_or_name,
            key_id_or_name.__class__.__name__,
            col,
        )

        return ds_client.key(key_kind, key_id_or_name)


# ---


def build_bq_schema(fields):
    # https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/cookbook/bigquery_schema.py
    schema = bigquery.TableSchema()

    for f in fields:
        s = bigquery.TableFieldSchema()
        s.name = f["name"]
        s.type = f["type"]
        s.mode = f["mode"]
        schema.fields.append(s)

    return schema
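
# --- Illustrative sketch (assumption, not part of the original module) -------
# `build_bq_schema()` expects a list of dicts with `name`, `type`, and `mode`
# keys, mirroring the TableFieldSchema attributes it sets above. The column
# names and types below are illustrative only, e.g.
# `build_bq_schema(_EXAMPLE_BQ_FIELDS)`.
_EXAMPLE_BQ_FIELDS = [
    {"name": "key", "type": "STRING", "mode": "REQUIRED"},
    {"name": "properties", "type": "STRING", "mode": "NULLABLE"},
]
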

def key_to_dict(key):
    """Returns a dict for the provided entity Key.

    Args:
        key (apache_beam.io.gcp.datastore.v1new.types.Key): Datastore key in
            Dataflow step.

    Returns:
        dict: keys/values of Key
    """
    return {
        "project": key.project,
        "namespace": key.namespace,
        "parent": key.parent,
        "path_elements": key.path_elements,
    }


def entity_to_dict(ent, props_formatter=None):
    """Returns a `dict` that represents the provided Google Datastore entity.

    Args:
        ent (apache_beam.io.gcp.datastore.v1new.types.Entity): Datastore
            entity in Dataflow step.
        props_formatter (callable, optional): Callable to format the entity's
            properties. Defaults to None, which means we use the `properties`
            attribute as-is.

    Returns:
        dict: A dict with `key` (datastore key as `dict`) and `properties`
            (entity's properties, also a `dict`).
    """
    return {
        "key": key_to_dict(ent.key),
        "properties": (
            props_formatter(ent.properties) if props_formatter else ent.properties
        ),
    }


def serialize_to_json(obj, default=str, **kwargs):
    return json.dumps(obj, default=default, **kwargs)


# from genservice.util; duplicated here to avoid importing genservice because
# doing so has side effects intended only for app engine.
def to_json(data, pretty=False, **kwargs):
    if pretty:
        kwargs["indent"] = 2
    return serialize_to_json(data, **kwargs)


def serialize_dict_values(d):
    """Returns a `dict` with the same keys, but with its values serialized to
    JSON.

    Used for making a BQ column out of a Google Datastore entity dict, with
    its properties as JSON.

    Args:
        d (dict): The dict to serialize

    Returns:
        dict: keys with each value serialized to JSON
    """
    return {k: serialize_to_json(v) for k, v in d.items()}


def to_hr_key(key, abs_path=False):
    """Returns a so-called 'human-readable' key: `kind:id_or_name`.

    The key's project and namespace can be included as a prefix by passing
    `abs_path=True`.

    Args:
        key (google.cloud.datastore.key.Key): The key to format. (Also accepts
            an apache_beam.io.gcp.datastore.v1new.types.Key for convenience.)
        abs_path (bool): Includes the project and namespace as a prefix if
            True. Defaults to False.

    Returns:
        string: The human-readable version of the passed key.
    """
    # also handle apache_beam.io.gcp.datastore.v1new.types.Key because keeping
    # the types straight is painful.
    if isinstance(key, Key):
        key = key.to_client_key()

    hr_key = f"{key.kind}:{key.id_or_name}"
    if abs_path:
        hr_key = f"{key.project}:{key.namespace}:{hr_key}"

    return hr_key


def format_error(error):
    return f"{error.__class__.__name__}: {str(error)}"
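

# --- Illustrative sketch (assumption, not part of the original module) -------
# `serialize_dict_values()` is described above as being used to turn a
# Datastore entity dict into BigQuery-friendly columns. A minimal composition
# of the helpers is sketched here; the function name is illustrative only.
def _example_entity_to_bq_row(entity):
    # `entity` is an apache_beam.io.gcp.datastore.v1new.types.Entity, as used
    # throughout this module; every value ends up JSON-serialized.
    return serialize_dict_values(entity_to_dict(entity))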