from datetime import datetime
from unittest.mock import call, patch
from brian_polling_manager import influx
from brian_polling_manager.interface_stats.vendors import common, nokia
import pytest
from lxml import etree


@pytest.fixture
def get_netconf(data_dir):
    """Return a callable that loads a router's NETCONF interface info.

    The returned function takes a router name and delegates to
    ``nokia.get_netconf_interface_info_from_source_dir`` with ``source_dir``
    bound to the ``data_dir`` fixture.
    """
    source_dir = data_dir

    def _load_from_data_dir(router_name):
        return nokia.get_netconf_interface_info_from_source_dir(
            router_name, source_dir=source_dir
        )

    return _load_from_data_dir


# Synthetic SR OS ``<rpc-reply>`` for a single physical port, populating every
# statistic the port parser reads (interface statistics plus the <ethernet>
# section). Each leaf holds a distinct small integer so test_nokia_counters can
# tell which XML element fed which output field (e.g. in-errors=5 ->
# ingressErrors / input_total_errors, fcs=22 -> fcs_errors).
# The backslash after the opening quotes keeps the document starting directly
# at <rpc-reply>.
NOKIA_PORT_XML = """\
<rpc-reply>
  <data>
    <state>
      <port>
        <port-id>1/1/c1</port-id>
        <oper-state>up</oper-state>
        <statistics>
          <in-octets>1</in-octets>
          <in-packets>2</in-packets>
          <out-octets>3</out-octets>
          <out-packets>4</out-packets>
          <in-errors>5</in-errors>
          <in-discards>6</in-discards>
          <out-errors>7</out-errors>
          <out-discards>8</out-discards>
        </statistics>
        <ethernet>
          <oper-state-change-count>20</oper-state-change-count>
          <statistics>
            <crc-align-errors>21</crc-align-errors>
            <ethernet-like-medium>
              <error>
                <fcs>22</fcs>
              </error>
            </ethernet-like-medium>
          </statistics>
        </ethernet>
      </port>
    </state>
  </data>
</rpc-reply>"""

# Synthetic SR OS ``<rpc-reply>`` for a single LAG (lag-1) with the same
# interface-level statistic values as NOKIA_PORT_XML. There is no <ethernet>
# section, so test_nokia_counters expects no oper-state-change, CRC or FCS
# error fields for this LAG.
NOKIA_LAG_XML = """\
<rpc-reply>
  <data>
    <state>
      <lag>
        <lag-name>lag-1</lag-name>
        <oper-state>up</oper-state>
        <statistics>
          <in-octets>1</in-octets>
          <in-packets>2</in-packets>
          <out-octets>3</out-octets>
          <out-packets>4</out-packets>
          <in-errors>5</in-errors>
          <in-discards>6</in-discards>
          <out-errors>7</out-errors>
          <out-discards>8</out-discards>
        </statistics>
      </lag>
    </state>
  </data>
</rpc-reply>"""

# Synthetic SR OS ``<rpc-reply>`` for one router interface (lag-1.0). It has
# two counter groups:
#   - <statistics>/<ip>, which test_nokia_counters expects in the plain
#     ingress/egress fields;
#   - <ipv6>/<statistics>, which it expects in the ``*v6`` fields.
# out-discard-packets is the only value that feeds the "errors" dict.
# NOTE: unlike the constants above, this literal begins with a newline (no
# backslash after the opening quotes). That whitespace precedes the root
# element and there is no XML declaration, so it still parses.
NOKIA_ROUTER_INTERFACE_XML = """
<rpc-reply>
  <data>
    <state>
      <router>
        <interface>
          <interface-name>lag-1.0</interface-name>
          <statistics>
            <ip>
              <in-octets>11</in-octets>
              <in-packets>12</in-packets>
              <out-octets>13</out-octets>
              <out-packets>14</out-packets>
              <out-discard-packets>15</out-discard-packets>
            </ip>
          </statistics>
          <ipv6>
            <statistics>
              <in-octets>16</in-octets>
              <in-packets>17</in-packets>
              <out-octets>18</out-octets>
              <out-packets>19</out-packets>
            </statistics>
          </ipv6>
        </interface>
      </router>
    </state>
  </data>
</rpc-reply>
"""


def nokia_docs_containing_every_field():
    """Parse the synthetic fixture documents into a dict of root elements.

    The keys (``"port"``, ``"lag"``, ``"router-interface"``) match those used
    by ``nokia.get_netconf_interface_info``. Each value is the element parsed
    from the corresponding ``NOKIA_*_XML`` constant.
    """
    xml_by_kind = (
        ("port", NOKIA_PORT_XML),
        ("lag", NOKIA_LAG_XML),
        ("router-interface", NOKIA_ROUTER_INTERFACE_XML),
    )
    return {kind: etree.fromstring(text) for kind, text in xml_by_kind}


def test_nokia_counters():
    """Every field in the synthetic port/LAG/router-interface docs is mapped."""
    docs = nokia_docs_containing_every_field()
    actual = list(nokia.interface_counters(docs))

    # The port and LAG fixtures carry identical interface-level statistics.
    interface_brian = {
        "ingressOctets": 1,
        "ingressPackets": 2,
        "egressOctets": 3,
        "egressPackets": 4,
        "ingressErrors": 5,
        "ingressDiscards": 6,
        "egressErrors": 7,
    }
    interface_errors = {
        "input_total_errors": 5,
        "input_discards": 6,
        "output_total_errors": 7,
        "output_discards": 8,
    }
    # Only the physical port has the <ethernet> section with these extras.
    port_errors = dict(
        interface_errors,
        oper_state_change_count=20,
        crc_align_errors=21,
        fcs_errors=22,
    )
    # Router interface: <ip> counters in plain fields, <ipv6> in *v6 fields.
    router_interface_brian = {
        "ingressOctets": 11,
        "ingressPackets": 12,
        "egressOctets": 13,
        "egressPackets": 14,
        "ingressOctetsv6": 16,
        "ingressPacketsv6": 17,
        "egressOctetsv6": 18,
        "egressPacketsv6": 19,
    }

    expected = [
        {
            "name": "1/1/c1",
            "brian": dict(interface_brian),
            "errors": port_errors,
        },
        {
            "name": "lag-1",
            "brian": dict(interface_brian),
            "errors": dict(interface_errors),
        },
        {
            "name": "lag-1.0",
            "brian": router_interface_brian,
            "errors": {"output_discards": 15},
        },
    ]
    assert actual == expected


def test_nokia_router_docs_do_not_generate_errors(
    nokia_router_fqdn, caplog, get_netconf
):
    """Parsing the router's stored docs yields counters and logs no problems."""
    counters = list(nokia.interface_counters(get_netconf(nokia_router_fqdn)))
    assert counters

    noisy_levels = ("ERROR", "WARNING")
    problems = [rec for rec in caplog.records if rec.levelname in noisy_levels]
    assert not problems


def test_validate_interface_counters_and_influx_points_for_all_nokia_routers(
    nokia_router_fqdn, get_netconf, schemavalidate
):
    """Counters and the BRIAN/error influx points built from them match schemas.

    Each interface counter dict is checked against
    ``INTERFACE_COUNTER_SCHEMA``. Then ``brian_points`` and ``error_points``
    must each produce at least one point. Every point must satisfy
    ``influx.INFLUX_POINT``, and its ``fields`` must satisfy the schema
    matching its kind.
    """
    interfaces = list(nokia.interface_counters(get_netconf(nokia_router_fqdn)))
    assert interfaces
    for counters in interfaces:
        schemavalidate(counters, common.INTERFACE_COUNTER_SCHEMA)

    point_kinds = (
        (common.brian_points, common.BRIAN_POINT_FIELDS_SCHEMA),
        (common.error_points, common.ERROR_POINT_FIELDS_SCHEMA),
    )
    for build_points, fields_schema in point_kinds:
        points = list(
            build_points(
                nokia_router_fqdn,
                interfaces,
                timestamp=datetime.now(),
                measurement_name="blah",
            )
        )
        assert points
        for point in points:
            schemavalidate(point, influx.INFLUX_POINT)
            schemavalidate(point["fields"], fields_schema)


def test_processes_specific_interfaces(get_netconf, caplog):
    """Filtering by ``interfaces`` returns exactly the requested interfaces.

    Uses the stored documents for rt0.lon2.uk.geant.net. It also checks that
    no ERROR/WARNING records are logged while processing them.
    """
    doc = get_netconf("rt0.lon2.uk.geant.net")
    interfaces = ["1/1/c1", "lag-1", "lag-1.0", "lag-20", "lag-20.1", "lag-20.111"]
    result = list(nokia.interface_counters(doc, interfaces=interfaces))
    assert len(result) == len(interfaces)
    assert all(isinstance(i, dict) for i in result)
    # A length check alone would also pass if the filter returned the wrong
    # interfaces, so verify that exactly the requested names came back.
    assert sorted(r["name"] for r in result) == sorted(interfaces)
    assert not [r for r in caplog.records if r.levelname in ("ERROR", "WARNING")]


class TestGetNokiaNetconf:
    """Tests for ``nokia.get_netconf_interface_info`` with NETCONF mocked out.

    ``nokia.netconf_connect`` is patched for every test in this class, so no
    device is contacted. Every ``get()`` on the mocked session returns an
    object whose ``tostring`` attribute holds the raw bytes of
    ``RAW_RESPONSE_FILE`` from the test data directory.
    """

    RAW_RESPONSE_FILE = "raw-response-nokia-sample.xml"

    @pytest.fixture(autouse=True)
    def mocked_connection(self, data_dir):
        """Patch ``nokia.netconf_connect`` with a canned reply and yield the mock."""
        raw_response = (data_dir / self.RAW_RESPONSE_FILE).read_bytes()
        with patch.object(nokia, "netconf_connect") as mock:
            mock().__enter__().get().tostring = raw_response
            # Configuring the return chain above records calls on the mock.
            # Clear them so tests only see calls made by the code under test.
            # The configured return values are kept.
            mock.reset_mock()
            yield mock

    def test_connect_with_params(self, mocked_connection):
        """Connection parameters are forwarded to ``netconf_connect``."""
        nokia.get_netconf_interface_info("some.router", ssh_params={"some": "param"})

        assert mocked_connection.call_args == call(
            hostname="some.router",
            ssh_params={"some": "param"},
            device_params={"name": "sros"},
            nc_params={"capabilities": ["urn:nokia.com:nc:pysros:pc"]},
            timeout=60,
        )

    def test_calls_get_with_request_filter(self, mocked_connection):
        """Five ``get()`` requests are issued, each with the expected filter tree."""
        nokia.get_netconf_interface_info("some-router", ssh_params={"some": "param"})

        # A MagicMock always returns the same return_value, so this is the
        # same session object that the code under test called get() on.
        get_calls = mocked_connection().__enter__().get.call_args_list
        assert len(get_calls) == 5

        def _filter_tags(get_call):
            # Named get_call (not call) to avoid shadowing unittest.mock.call.
            _, kwargs = get_call
            return [elem.tag for elem in kwargs["filter"].iter()]

        def _state_path(*names):
            # Expected tags in document order: the filter root, then the
            # <state> element and the given children, all in the SR OS
            # state namespace.
            ns = "{urn:nokia.com:sros:ns:yang:sr:state}"
            return ["filter"] + [f"{ns}{name}" for name in ("state", *names)]

        expected_tags = [
            _state_path("port", "statistics", "ethernet"),
            _state_path("lag", "statistics"),
            _state_path("router", "interface"),
            _state_path("service", "vprn", "interface"),
            _state_path("service", "ies", "interface"),
        ]
        for index, (get_call, expected) in enumerate(zip(get_calls, expected_tags)):
            assert _filter_tags(get_call) == expected, f"get() call #{index}"

    def test_converts_rpc_response_to_xml(self):
        """Each canned reply is parsed into an element rooted at ``rpc-reply``."""
        doc = nokia.get_netconf_interface_info(
            "some-router", ssh_params={"some": "param"}
        )

        for key in ("port", "lag", "router-interface"):
            assert doc[key].tag == "rpc-reply", key