import codecs
import re
import warnings
import xml.etree.ElementTree as Etree
from html.entities import name2codepoint
from html.parser import HTMLParser

# Names exported by "from <module> import *".
__all__ = ["HTMLement", "fromstring", "fromstringlist", "parse"]
__version__ = "2.0.0"
# html.entities.name2codepoint has no "apos" entry, so register it here so that
# ParseHTML.handle_entityref can resolve "&apos;". Note: this mutates the shared
# stdlib dict, so the change is visible to every other user of html.entities.
name2codepoint["apos"] = 0x0027

def fromstring(text, tag="", attrs=None, encoding=None):
    """Parse an HTML document held in a single str or bytes object.

    :param text: The HTML source; bytes are decoded by :class:`HTMLement`.
    :param str tag: (optional) Tag name of the section to extract.
    :param dict attrs: (optional) Attribute filter for *tag*.
    :param str encoding: (optional) Encoding used to decode bytes input.
    :return: The root element of the parsed tree.
    """
    html_parser = HTMLement(tag=tag, attrs=attrs, encoding=encoding)
    html_parser.feed(text)
    return html_parser.close()

def fromstringlist(sequence, tag="", attrs=None, encoding=None):
    """Parse an HTML document supplied as an iterable of str/bytes fragments.

    :param sequence: Iterable of HTML fragments, fed to the parser in order.
    :param str tag: (optional) Tag name of the section to extract.
    :param dict attrs: (optional) Attribute filter for *tag*.
    :param str encoding: (optional) Encoding used to decode bytes fragments.
    :return: The root element of the parsed tree.
    """
    html_parser = HTMLement(tag=tag, attrs=attrs, encoding=encoding)
    feed_fragment = html_parser.feed
    for fragment in sequence:
        feed_fragment(fragment)
    return html_parser.close()

def parse(source, tag="", attrs=None, encoding=None):
    """Load an HTML document from a file and parse it into an element tree.

    :param source: A file path, or any object with a ``read`` method. A path
        is opened here in text mode with *encoding* and closed afterwards, so
        the parser receives str chunks and no ``<meta>`` charset detection is
        performed for it. A file object is read as-is and left open.
    :param str tag: (optional) Tag name of the section to extract.
    :param dict attrs: (optional) Attribute filter for *tag*.
    :param str encoding: (optional) Encoding for opening a path and for
        decoding bytes read from a binary file object.
    :return: The root element of the parsed tree.
    """
    owns_source = not hasattr(source, "read")
    if owns_source:
        source = open(source, "r", encoding=encoding)

    try:
        html_parser = HTMLement(tag, attrs, encoding)
        # Stream the input in 64 KiB pieces; an empty read means end of file.
        chunk = source.read(65536)
        while chunk:
            html_parser.feed(chunk)
            chunk = source.read(65536)
        return html_parser.close()
    finally:
        if owns_source:
            source.close()

class HTMLement(object):
    def __init__(self, tag="", attrs=None, encoding=None):
        self._parser = ParseHTML(tag, attrs)
        self.encoding = encoding
        self._finished = False

    def feed(self, data):
        if self._finished == 1:
            return None
        if isinstance(data, bytes):
            if self.encoding:
                data = data.decode(self.encoding)
            else:
                data = self._make_unicode(data)
        try:
            self._parser.feed(data)
        except EOFError:
            self._finished = True
            self._parser.reset()

    def close(self):
        return self._parser.close()

    def _make_unicode(self, data):
        end_head_tag = data.find(b"</head>")
        if end_head_tag:
            charset_refind = b'<meta.+?charset=[\'"]*(.+?)["\'].*?>'
            charset = re.search(charset_refind, data[:end_head_tag], re.IGNORECASE)
            if charset:
                self.encoding = encoding = charset.group(1).decode()
                return data.decode(encoding)
        warn_msg = "Unable to determine encoding, defaulting to iso-8859-1"
        warnings.warn(warn_msg, UnicodeWarning, stacklevel=2)
        self.encoding = "iso-8859-1"
        return data.decode("iso-8859-1")

class ParseHTML(HTMLParser):
    """HTMLParser subclass that builds an ElementTree from the fed markup.

    With an empty *tag*, the whole document is collected. Otherwise elements
    are ignored until a start tag named *tag* matches *attrs*, and that
    element becomes the root of the result. Once the root element is closed,
    EOFError is raised from :meth:`handle_endtag`, and therefore from
    :meth:`feed`, to signal that no more input is needed.

    :param str tag: (optional) Tag name of the section to extract.
    :param dict attrs: (optional) Attribute filter. A value of ``0`` means the
        attribute must be absent. A value equal to ``1`` means it must be
        present with any value. Any other value must match exactly. The
        caller's dict is not modified.
    """

    # Elements that never get an end tag: HTML void elements plus some
    # self-closing SVG elements. They are closed as soon as they are opened.
    _voids = frozenset(("area", "base", "br", "col", "hr", "img", "input", "link", "meta", "param",
                        "embed", "keygen", "source", "track",
                        "basefont", "frame", "isindex",
                        "rect", "circle", "ellipse", "line", "polyline", "polygon",
                        "path", "stop", "use", "image", "animatetransform"))

    def __init__(self, tag="", attrs=None):
        HTMLParser.__init__(self)
        self.convert_charrefs = True
        self._root = None  # the requested element, once found (tag search only)
        self._data = []  # text collected since the last flush
        self._factory = Etree.Element
        self.enabled = not tag  # True while elements are being collected
        self._complete = False  # set once the requested root element is closed
        self.tag = tag

        # Split the filter into wanted attributes and unwanted (value 0) ones
        # without mutating the caller's dict, which may be reused across calls.
        attrs = attrs or {}
        self._unw_attrs = [key for key, value in attrs.items() if value == 0]
        self.attrs = {key: value for key, value in attrs.items() if value != 0}

        # Synthetic container element. It protects against documents with no
        # <html> start tag or with several top-level elements. It always stays
        # at index 0 of the open-element stack.
        elem = self._factory("html")
        self._elem = [elem]
        self._last = elem
        self._tail = 0

    def handle_starttag(self, tag, attrs):
        self._handle_starttag(tag, attrs, self_closing=tag in self._voids)

    def handle_startendtag(self, tag, attrs):
        self._handle_starttag(tag, attrs, self_closing=True)

    def _handle_starttag(self, tag, attrs, self_closing=False):
        """Create an element for a start tag when collecting, or when it matches the search."""
        enabled = self.enabled
        if enabled or self._search(tag, attrs):
            attrs = {k: v or "" for k, v in attrs}
            self._flush()
            elem = self._factory(tag, attrs)
            self._elem[-1].append(elem)
            self._last = elem
            if self_closing:
                # Not pushed, so text that follows becomes this element's tail.
                self._tail = 1
            else:
                self._elem.append(elem)
                self._tail = 0
            if not enabled:
                self._root = elem
                self.enabled = True

    def handle_endtag(self, tag):
        """Close *tag*, implicitly closing any unclosed elements nested inside it."""
        if not self.enabled or tag in self._voids:
            return None
        stack = self._elem
        # stack[0] is the synthetic container and is never popped. A stray end
        # tag (e.g. "</html>" with no "<html>") therefore cannot empty the
        # stack and break later appends or close().
        if len(stack) > 1 and stack[-1].tag == tag:
            self._flush()
            self._tail = 1
            self._last = elem = stack.pop()
        elif any(item.tag == tag for item in stack[1:]):
            # An open ancestor matches, so the elements above it were never
            # closed. Close them together with it.
            self._flush()
            self._tail = 1
            while True:
                self._last = elem = stack.pop()
                if elem.tag == tag:
                    break
        else:
            # No open element matches this end tag; ignore it.
            return None
        if elem is self._root:
            self._complete = True
            raise EOFError

    def handle_data(self, data):
        # Whitespace-only runs are dropped.
        if data.strip() and self.enabled:
            self._data.append(data)

    def handle_entityref(self, name):
        # Only called when convert_charrefs is False.
        if self.enabled:
            try:
                name = chr(name2codepoint[name])
            except KeyError:
                pass
            self._data.append(name)

    def handle_charref(self, name):
        # Only called when convert_charrefs is False.
        if self.enabled:
            try:
                if name[0].lower() == "x":
                    name = chr(int(name[1:], 16))
                else:
                    name = chr(int(name))
            except (ValueError, OverflowError):
                # Out-of-range code points (chr() raises either error depending
                # on magnitude) are kept as their raw reference text.
                pass
            self._data.append(name)

    def handle_comment(self, data):
        data = data.strip()
        if data and self.enabled:
            # NOTE(review): pending text is not flushed first, so text on either
            # side of a comment is merged into the enclosing element's text/tail.
            elem = Etree.Comment(data)
            self._elem[-1].append(elem)

    def close(self):
        """Finish parsing and return the root element of the collected tree.

        :raises RuntimeError: If a tag search was requested but never matched.
        """
        if not self._complete:
            # Let HTMLParser process input it is still buffering, such as
            # trailing text that ends in "&...". Skip this once the root has
            # been closed: the buffer may still hold already-parsed input.
            try:
                HTMLParser.close(self)
            except EOFError:
                pass  # the buffered input closed the requested root element
        self._flush()
        if self.enabled == 0:
            msg = "Unable to find requested section with tag of '{}' and attributes of {}"
            raise RuntimeError(msg.format(self.tag, self.attrs))
        elif self._root is not None:
            return self._root
        else:
            # Prefer the document's own <html> element when it is a direct
            # child of the synthetic container.
            tmp_root = self._elem[0]
            proper_root = tmp_root.find("html")
            if proper_root is None:
                return tmp_root
            else:
                return proper_root

    def _flush(self):
        """Write pending text to the last element's tail (after a close) or text."""
        if self._data:
            if self._last is not None:
                text = "".join(self._data)
                if self._tail:
                    self._last.tail = text
                else:
                    self._last.text = text
            self._data = []

    def _search(self, tag, attrs):
        """Return True if a start tag matches the requested tag and attribute filter."""
        if tag != self.tag:
            return False
        if not (self.attrs or self._unw_attrs):
            return True
        wanted_attrs = self.attrs.copy()
        unwanted_attrs = self._unw_attrs
        for key, value in attrs:
            if key in unwanted_attrs:
                return False
            if key in wanted_attrs:
                c_value = wanted_attrs[key]
                if c_value == value or c_value == 1:
                    del wanted_attrs[key]
        # Matches when every wanted attribute was found. An element with no
        # attributes at all matches a filter made only of unwanted attributes.
        return not wanted_attrs