from unidecode import unidecode
import re
import sys
import inflection
import numpy as np
import math
from collections import defaultdict

# cosine_similarity below is our own, faster implementation, inspired by
# https://towardsdatascience.com/calculating-string-similarity-in-python-276e18a7d33a

# Memo of text -> token list, used by cosine_similarity when cache=True.
# Reading a missing key with [] yields None (and stores that None entry).
_tokens_cache = defaultdict(lambda: None)
# Matches every non-digit character, plus a run of zeros at the very start
# of the string being cleaned.
_phone_regex = re.compile(r'[^\d]|^0+')
# Translation table for str.translate: the characters of the first argument
# become spaces (word separators); the characters of the last argument are
# deleted. The deleted set is mostly a subset of string.punctuation plus the
# acute accent, written as the escape \u00b4 so that it cannot be corrupted
# by a source-file re-encoding.
_nopunctuation = str.maketrans('()[]-&:;./-.', '            ', '\'`\u00b4!"#$%*+,<=>?@\\^_`{|}~')


def cosine_similarity(text1, text2, cache=False, stopwords=None, enders=None):
    """Return the cosine similarity between the token sets of two texts.

    Both texts are tokenized with ``get_tokens`` and compared as binary
    bags of words (a token counts once, however often it occurs), i.e. the
    result is ``|A & B| / (sqrt(|A|) * sqrt(|B|))`` for token sets A and B.

    Args:
        text1: First string to compare.
        text2: Second string to compare.
        cache: When true, token lists are memoized in the module-level
            ``_tokens_cache`` so repeated texts are tokenized only once.
        stopwords: Optional collection of words, forwarded to ``get_tokens``.
        enders: Optional collection of tokens, forwarded to ``get_tokens``.

    Returns:
        0.0 when either text yields no tokens or the texts share no token,
        1.0 when both token lists are identical, otherwise the cosine of
        the two binary token vectors as a float.
    """
    tok1 = _tokens_for(text1, cache, stopwords, enders)
    tok2 = _tokens_for(text2, cache, stopwords, enders)

    if not tok1 or not tok2:
        return 0.0
    if tok1 == tok2:
        return 1.0

    set1, set2 = set(tok1), set(tok2)
    shared = len(set1 & set2)
    if not shared:
        # No intersection: the dot product is zero.
        return 0.0
    # For 0/1 vectors the dot product is the size of the intersection and
    # each squared norm is the size of the corresponding token set, so no
    # explicit vectors are needed.
    return shared / (math.sqrt(len(set1)) * math.sqrt(len(set2)))


def _tokens_for(text, cache, stopwords, enders):
    """Tokenize *text* with get_tokens, memoizing the result when *cache* is true."""
    if not cache:
        return get_tokens(text, stopwords, enders)

    if stopwords or enders:
        # The tokens depend on the filters, so they must be part of the key.
        try:
            key = (text, frozenset(stopwords or ()), frozenset(enders or ()))
        except TypeError:
            # Filters that cannot be hashed: skip the cache for this call.
            return get_tokens(text, stopwords, enders)
    else:
        key = text

    # .get() avoids the defaultdict side effect of storing None on a miss.
    tokens = _tokens_cache.get(key)
    if tokens is None:
        tokens = get_tokens(text, stopwords, enders)
        _tokens_cache[key] = tokens
    return tokens


def get_tokens(text, stopwords=None, enders=None):
    """Split *text* into a sorted list of normalized, singularized tokens.

    The text is stripped of punctuation (via ``_nopunctuation``),
    transliterated with ``unidecode`` and lower-cased before being split on
    whitespace. Words of a single character are dropped, as are words found
    in *stopwords* (checked before singularization). If *enders* is given,
    the first singularized token found in it and every token after it are
    discarded. Duplicates are kept.
    """
    normalized = unidecode(text.translate(_nopunctuation)).lower()

    words = []
    for word in normalized.split():
        if len(word) <= 1:
            continue
        if stopwords and word in stopwords:
            continue
        words.append(inflection.singularize(word))

    if enders:
        # Position of the first ender token, if any; everything from there on is cut.
        cut = next((pos for pos, word in enumerate(words) if word in enders), None)
        if cut is not None:
            words = words[:cut]

    words.sort()
    return words


def phonenumber_equal(a, b):
    """Return True when *a* and *b* denote the same phone number.

    Both strings are reduced to their digits with leading zeros removed.
    They are considered equal only if the cleaned numbers are identical and
    at least 9 digits long; shorter numbers are never reported as equal.
    """
    # _phone_regex only drops zeros located at the very start of the raw
    # string; lstrip also covers zeros that follow a stripped prefix such
    # as "+" or "(".
    a = _phone_regex.sub('', a).lstrip('0')
    b = _phone_regex.sub('', b).lstrip('0')
    # Equality implies equal length, so one length check covers both numbers.
    # (The former "x or y and z" condition returned True for any long *a*.)
    return len(a) > 8 and a == b