from astropy import units as u
from astropy.coordinates import (
    CartesianRepresentation,
    get_body_barycentric,
)
from astropy.tests.helper import assert_quantity_allclose
from astropy.time import Time
import numpy as np
import pytest

from boinor.bodies import (
    Earth,
    Jupiter,
    Mars,
    Mercury,
    Moon,
    Neptune,
    Saturn,
    Sun,
    Uranus,
    Venus,
)
from boinor.constants import J2000
from boinor.frames.ecliptic import GeocentricSolarEcliptic
from boinor.frames.enums import Planes
from boinor.frames.equatorial import (
    GCRS,
    HCRS,
    ICRS,
    JupiterICRS,
    MarsICRS,
    MercuryICRS,
    NeptuneICRS,
    SaturnICRS,
    UranusICRS,
    VenusICRS,
    _PlanetaryICRS,
)
from boinor.frames.fixed import (
    ITRS,
    JupiterFixed,
    MarsFixed,
    MercuryFixed,
    MoonFixed,
    NeptuneFixed,
    SaturnFixed,
    SunFixed,
    UranusFixed,
    VenusFixed,
    _PlanetaryFixed,
)
from boinor.frames.util import get_frame


@pytest.mark.parametrize(
    "body, frame",
    [
        (Mercury, MercuryICRS),
        (Venus, VenusICRS),
        (Mars, MarsICRS),
        (Jupiter, JupiterICRS),
        (Saturn, SaturnICRS),
        (Uranus, UranusICRS),
        (Neptune, NeptuneICRS),
    ],
)
def test_planetary_frames_have_proper_string_representations(body, frame):
    """The repr of a data-less planetary ICRS frame contains the body name."""
    frame_repr = repr(frame())
    assert body.name in frame_repr


@pytest.mark.parametrize(
    "body, frame",
    [
        (Sun, HCRS),
        (Mercury, MercuryICRS),
        (Venus, VenusICRS),
        (Earth, GCRS),
        (Mars, MarsICRS),
        (Jupiter, JupiterICRS),
        (Saturn, SaturnICRS),
        (Uranus, UranusICRS),
        (Neptune, NeptuneICRS),
    ],
)
def test_planetary_icrs_frame_is_just_translation(body, frame):
    """Body-centred ICRS frames differ from ICRS by a pure translation.

    A fixed offset expressed in the body's frame must land at the body's
    barycentric position plus that same offset once converted to ICRS.
    """
    offset = CartesianRepresentation(x=100 * u.km, y=100 * u.km, z=100 * u.km)

    in_body_frame = frame(offset, obstime=J2000)
    in_icrs = in_body_frame.transform_to(ICRS())
    actual = in_icrs.represent_as(CartesianRepresentation)

    body_position = get_body_barycentric(body.name, J2000)
    expected = body_position + offset

    assert_quantity_allclose(actual.xyz, expected.xyz)


@pytest.mark.parametrize(
    "body, frame",
    [
        (Sun, HCRS),
        (Mercury, MercuryICRS),
        (Venus, VenusICRS),
        (Earth, GCRS),
        (Mars, MarsICRS),
        (Jupiter, JupiterICRS),
        (Saturn, SaturnICRS),
        (Uranus, UranusICRS),
        (Neptune, NeptuneICRS),
    ],
)
def test_icrs_body_position_to_planetary_frame_yields_zeros(body, frame):
    """A body's own barycentric ICRS position maps to the origin of its frame."""
    body_position = get_body_barycentric(body.name, J2000)

    target_frame = frame(obstime=J2000)
    in_body_frame = ICRS(body_position).transform_to(target_frame)
    cartesian = in_body_frame.represent_as(CartesianRepresentation)

    origin = [0, 0, 0] * u.km
    assert_quantity_allclose(cartesian.xyz, origin, atol=1e-7 * u.km)


@pytest.mark.parametrize(
    "body, fixed_frame, inertial_frame",
    [
        (Sun, SunFixed, HCRS),
        (Mercury, MercuryFixed, MercuryICRS),
        (Venus, VenusFixed, VenusICRS),
        (Earth, ITRS, GCRS),
        (Mars, MarsFixed, MarsICRS),
        (Jupiter, JupiterFixed, JupiterICRS),
        (Saturn, SaturnFixed, SaturnICRS),
        (Uranus, UranusFixed, UranusICRS),
        (Neptune, NeptuneFixed, NeptuneICRS),
    ],
)
def test_planetary_fixed_inertial_conversion(
    body, fixed_frame, inertial_frame
):
    """Fixed -> inertial conversion preserves the radial distance ``body.R``.

    The point is built in the body-fixed frame at lon = lat = 0 and distance
    ``body.R``; both it and its inertial counterpart must keep that distance.
    """
    tolerance = 1e-7 * u.km

    start = fixed_frame(
        0 * u.deg,
        0 * u.deg,
        body.R,
        obstime=J2000,
        representation_type="spherical",
    )
    converted = start.transform_to(inertial_frame(obstime=J2000))

    for position in (start, converted):
        assert_quantity_allclose(
            position.spherical.distance, body.R, atol=tolerance
        )


@pytest.mark.parametrize(
    "body, fixed_frame, inertial_frame",
    [
        (Sun, SunFixed, HCRS),
        (Mercury, MercuryFixed, MercuryICRS),
        (Venus, VenusFixed, VenusICRS),
        (Earth, ITRS, GCRS),
        (Mars, MarsFixed, MarsICRS),
        (Jupiter, JupiterFixed, JupiterICRS),
        (Saturn, SaturnFixed, SaturnICRS),
        (Uranus, UranusFixed, UranusICRS),
        (Neptune, NeptuneFixed, NeptuneICRS),
    ],
)
def test_planetary_inertial_fixed_conversion(
    body, fixed_frame, inertial_frame
):
    """Inertial -> fixed conversion preserves the radial distance ``body.R``.

    Mirror of the fixed -> inertial test: the point starts in the inertial
    frame and is converted to the body-fixed frame.
    """
    start = inertial_frame(
        0 * u.deg,
        0 * u.deg,
        body.R,
        obstime=J2000,
        representation_type="spherical",
    )
    converted = start.transform_to(fixed_frame(obstime=J2000))

    tolerance = 1e-7 * u.km
    # Check the converted (fixed) point first, then the inertial input.
    for position in (converted, start):
        assert_quantity_allclose(
            position.spherical.distance, body.R, atol=tolerance
        )


@pytest.mark.parametrize(
    "body, fixed_frame, inertial_frame",
    [
        (Sun, SunFixed, HCRS),
        (Mercury, MercuryFixed, MercuryICRS),
        (Venus, VenusFixed, VenusICRS),
        (Earth, ITRS, GCRS),
        (Mars, MarsFixed, MarsICRS),
        (Jupiter, JupiterFixed, JupiterICRS),
        (Saturn, SaturnFixed, SaturnICRS),
        (Uranus, UranusFixed, UranusICRS),
        (Neptune, NeptuneFixed, NeptuneICRS),
    ],
)
def test_planetary_inertial_roundtrip_vector(
    body, fixed_frame, inertial_frame
):
    """Fixed -> inertial -> fixed round trip over a vector of epochs.

    One point per epoch (1000 epochs, 10 s apart starting at J2000) must
    come back to its original Cartesian position.
    """
    n_samples = 1000
    sampling_time = 10 * u.s
    obstimes = J2000 + np.arange(n_samples) * sampling_time

    def _per_epoch(quantity):
        # Broadcast a scalar Quantity to one entry per epoch, keeping units.
        return np.broadcast_to(quantity, (n_samples,), subok=True)

    start = fixed_frame(
        _per_epoch(0 * u.deg),
        _per_epoch(0 * u.deg),
        _per_epoch(body.R),
        representation_type="spherical",
        obstime=obstimes,
    )
    via_inertial = start.transform_to(inertial_frame(obstime=obstimes))
    back = via_inertial.transform_to(fixed_frame(obstime=obstimes))

    assert_quantity_allclose(
        start.cartesian.xyz,
        back.cartesian.xyz,
        atol=1e-7 * u.km,
    )


def test_round_trip_from_GeocentricSolarEcliptic_gives_same_results():
    """GCRS -> GeocentricSolarEcliptic -> GCRS returns the original RA/Dec."""
    obstime = Time("J2000")
    original = GCRS(ra="02h31m49.09s", dec="+89d15m50.8s", distance=200 * u.km)

    as_gse = original.transform_to(GeocentricSolarEcliptic(obstime=obstime))
    back = as_gse.transform_to(GCRS(obstime=obstime))

    assert_quantity_allclose(back.dec.value, original.dec.value, atol=1e-7)
    assert_quantity_allclose(back.ra.value, original.ra.value, atol=1e-7)


def test_GeocentricSolarEcliptic_against_data():
    """GSE lon/lat of a fixed GCRS direction at J2000 match reference values."""
    expected_lon = 233.11691362602866
    expected_lat = 48.64606410986667

    source = GCRS(ra="02h31m49.09s", dec="+89d15m50.8s", distance=200 * u.km)
    gse = source.transform_to(GeocentricSolarEcliptic(obstime=J2000))

    assert_quantity_allclose(gse.lat.value, expected_lat, atol=1e-7)
    assert_quantity_allclose(gse.lon.value, expected_lon, atol=1e-7)


def test_GeocentricSolarEcliptic_raises_error_nonscalar_obstime():
    """Transforming to GeocentricSolarEcliptic with an array obstime fails.

    The transformation must raise ``ValueError`` with a message saying that
    obstime has to be a scalar.
    """
    # Built outside the raises-block: only the transformation step is
    # allowed to produce the expected ValueError.
    gcrs = GCRS(ra="02h31m49.09s", dec="+89d15m50.8s", distance=200 * u.km)

    with pytest.raises(ValueError) as excinfo:
        gcrs.transform_to(
            GeocentricSolarEcliptic(obstime=Time(["J3200", "J2000"]))
        )

    assert (
        "To perform this transformation the "
        "obstime Attribute must be a scalar." in str(excinfo.value)
    )


@pytest.mark.parametrize(
    "body, fixed_frame, radecW",
    [
        (Sun, SunFixed, (286.13 * u.deg, 63.87 * u.deg, 84.176 * u.deg)),
        (
            Mercury,
            MercuryFixed,
            (281.0103 * u.deg, 61.45 * u.deg, 329.5999488 * u.deg),
        ),
        (Venus, VenusFixed, (272.76 * u.deg, 67.16 * u.deg, 160.2 * u.deg)),
        (
            Mars,
            MarsFixed,
            (317.68085441 * u.deg, 52.88643928 * u.deg, 176.63205973 * u.deg),
        ),
        (
            Jupiter,
            JupiterFixed,
            (268.05720404 * u.deg, 64.49580995 * u.deg, 284.95 * u.deg),
        ),
        (Saturn, SaturnFixed, (40.589 * u.deg, 83.537 * u.deg, 38.9 * u.deg)),
        (
            Uranus,
            UranusFixed,
            (257.311 * u.deg, -15.175 * u.deg, 203.81 * u.deg),
        ),
        (
            Neptune,
            NeptuneFixed,
            (299.33373896 * u.deg, 42.95035902 * u.deg, 249.99600757 * u.deg),
        ),
        (
            Moon,
            MoonFixed,
            (
                266.85773344495135 * u.deg,
                65.64110274784535 * u.deg,
                41.1952639807452 * u.deg,
            ),
        ),
    ],
)
def test_fixed_frame_calculation_gives_expected_result(
    body, fixed_frame, radecW
):
    """``rot_elements_at_epoch()`` of each fixed frame at J2000 matches the
    expected ``radecW`` triple to within 1e-7 deg."""
    frame_instance = fixed_frame(
        0 * u.deg,
        0 * u.deg,
        body.R,
        obstime=J2000,
        representation_type="spherical",
    )
    computed = frame_instance.rot_elements_at_epoch()

    assert_quantity_allclose(computed, radecW, atol=1e-7 * u.deg)


# get_frame raises NotImplementedError only for the Sun; the parametrization
# is kept so that other bodies/frames can be checked here later.
@pytest.mark.parametrize(
    "body, frame",
    [
        (Sun, HCRS),
    ],
)
def test_get_frame(body, frame):
    """Requesting a body-fixed plane from ``get_frame`` raises
    ``NotImplementedError`` for the parametrized body.

    ``frame`` is not used yet; it is part of the parametrization so further
    cases can be added later.
    """
    with pytest.raises(NotImplementedError) as excinfo:
        get_frame(body, Planes.BODY_FIXED)

    message = excinfo.exconly()
    assert "NotImplementedError: A frame with plane" in message


def test_planetary_fixed():
    """Exercise ``_PlanetaryFixed`` error paths and its equatorial round trip.

    Checks that:

    * calling ``_rot_elements_at_epoch`` on the base class raises
      ``NotImplementedError``;
    * ``from_equatorial`` / ``to_equatorial`` raise ``ValueError`` when the
      fixed and equatorial bodies differ, or when a Sun-fixed side is paired
      with something that is not ``HCRS``;
    * ``from_equatorial`` followed by ``to_equatorial`` (both with the
      Mercury-fixed frame) reproduces ``ra``, ``dec`` and ``distance``.
    """
    epoch = J2000
    # The base-class hook itself must not return rotation elements.
    with pytest.raises(NotImplementedError) as excinfo:
        _PlanetaryFixed._rot_elements_at_epoch(epoch, 123)
    assert "NotImplementedError" in excinfo.exconly()

    mf = MercuryFixed(obstime=epoch)
    vf = VenusFixed(obstime=epoch)
    sf = SunFixed(obstime=epoch)

    # A GCRS point at distance Mercury.R, converted into the Mercury-fixed and
    # Sun-fixed frames for use as converter inputs below.
    inertial_position = GCRS(
        0 * u.deg,
        0 * u.deg,
        Mercury.R,
        obstime=epoch,
        representation_type="spherical",
    )
    fixed_position = inertial_position.transform_to(mf)
    sf_position = inertial_position.transform_to(sf)

    # The argument combinations below are deliberately unusual: they exist to
    # reach the converters' validation branches and are not meant to model a
    # physically meaningful conversion.
    test_from_equ = _PlanetaryFixed.from_equatorial(
        fixed_position, mf
    )  # same body (Mercury) on both sides: must succeed
    # Mercury coordinates with a Venus-fixed frame -> body mismatch.
    with pytest.raises(
        ValueError,
        match="Fixed and equatorial coordinates must have the same body if the fixed frame body is not Sun",
    ):
        _PlanetaryFixed.from_equatorial(fixed_position, vf)
    # Sun-fixed frame with equatorial coordinates that are not HCRS.
    with pytest.raises(
        ValueError, match="Equatorial coordinates must be of type `HCRS`, got"
    ):
        _PlanetaryFixed.from_equatorial(fixed_position, sf)

    # Map back with the same Mercury frame; compared with fixed_position below.
    test_to_equ = _PlanetaryFixed.to_equatorial(test_from_equ, mf)
    # Mercury coordinates with a Venus-fixed frame -> body mismatch.
    with pytest.raises(
        ValueError,
        match="Fixed and equatorial coordinates must have the same body if the fixed frame body is not Sun",
    ):
        _PlanetaryFixed.to_equatorial(fixed_position, vf)
    # Sun-fixed coordinates with an equatorial frame that is not HCRS.
    with pytest.raises(
        ValueError, match="Equatorial coordinates must be of type `HCRS`, got"
    ):
        _PlanetaryFixed.to_equatorial(sf_position, mf)

    # The from/to round trip must leave the coordinates unchanged.
    assert_quantity_allclose(fixed_position.ra, test_to_equ.ra)
    assert_quantity_allclose(fixed_position.dec, test_to_equ.dec)
    assert_quantity_allclose(fixed_position.distance, test_to_equ.distance)


def test_planetary_icrs_class():
    """Smoke-test ``_PlanetaryICRS.from_icrs`` with matching and mismatched
    frames.

    The test makes no assertions. It passes as long as none of the calls
    below raise.
    """
    epoch = J2000

    mf = MercuryICRS(obstime=epoch)
    vf = VenusICRS(obstime=epoch)
    sf = SunFixed(obstime=epoch)

    inertial_position = GCRS(
        0 * u.deg,
        0 * u.deg,
        Mercury.R,
        obstime=epoch,
        representation_type="spherical",
    )
    icrs_position = inertial_position.transform_to(mf)
    # Result discarded: only checks that GCRS -> SunFixed does not raise.
    inertial_position.transform_to(sf)

    # The argument combinations below are deliberately unusual and not
    # physically meaningful; they only exercise from_icrs.
    _PlanetaryICRS.from_icrs(
        icrs_position, mf
    )  # same body (Mercury) on both sides: must succeed
    # NOTE(review): a Mercury position with the Venus frame is currently
    # expected NOT to raise (the pytest.raises guard below is disabled);
    # confirm whether from_icrs should validate the body the way
    # _PlanetaryFixed.from_equatorial does.
    #    with pytest.raises(ValueError, match="Fixed and equatorial coordinates must have the same body if the fixed frame body is not Sun"):
    _PlanetaryICRS.from_icrs(icrs_position, vf)


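# TODO: extend test_planetary_icrs_class with a to_icrs round trip and error
# checks along the lines of the sketch below. The sketch is not runnable as
# is: names such as test_from_icrs, fixed_position and test_to_equ would need
# to be defined/adapted first.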
#    test_to_icrs=_PlanetaryICRS.to_icrs(test_from_icrs, mf)
#    with pytest.raises(ValueError, match="Fixed and equatorial coordinates must have the same body if the fixed frame body is not Sun"):
#    to_icrs=_PlanetaryICRS.to_icrs(icrs_position, vf)
#
#    assert_quantity_allclose(fixed_position.ra, test_to_equ.ra)
#    assert_quantity_allclose(fixed_position.dec, test_to_equ.dec)
#    assert_quantity_allclose(fixed_position.distance, test_to_equ.distance)
