import types
import warnings

import numpy as np
from pystruct.models.base import StructuredModel

class ExternalModel(StructuredModel):
    """Structured model assembled from externally supplied functions.

    Instead of subclassing, pass the model's methods as plain functions to
    the constructor. Each supplied function is bound to the instance with
    ``types.MethodType``, so it must take the model instance as its first
    argument, exactly like a method (e.g. ``def joint_feature(self, x, y)``).
    Any function left as ``None`` falls back to the default defined on this
    class. At least ``joint_feature`` and ``inference`` must be supplied;
    their defaults raise ``NotImplementedError``.
    """

    def __init__(self, size_joint_feature=None, joint_feature=None,
                 inference=None, loss=None,
                 loss_augmented_inference=None, initialize=None):
        """Initialize the model.

        Parameters
        ----------
        size_joint_feature : int or None
            Dimensionality of the joint feature vector for an instance with
            labeling (x, y); stored as ``self.size_joint_feature``.
        joint_feature, inference, loss, loss_augmented_inference, initialize :
            callable or None
            Replacement implementation for the method of the same name. Each
            one is bound to this instance, so its first parameter receives
            the model. ``None`` keeps the class default.
        """
        super(ExternalModel, self).__init__()
        self.size_joint_feature = size_joint_feature
        overrides = (
            ("joint_feature", joint_feature),
            ("inference", inference),
            ("loss", loss),
            ("loss_augmented_inference", loss_augmented_inference),
            ("initialize", initialize),
        )
        for name, func in overrides:
            if func is not None:
                # The bound instance attribute shadows the class-level default.
                setattr(self, name, types.MethodType(func, self))

    def initialize(self, X, Y):
        """Default ``initialize`` hook: does nothing."""
        pass

    def joint_feature(self, x, y):
        """Joint feature vector of (x, y).

        Must be supplied through the constructor; the default raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def inference(self, x, w, relaxed=None):
        """Compute the labeling of ``x`` under weights ``w``.

        Must be supplied through the constructor; the default raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def loss(self, y, y_hat):
        """Hamming loss: number of positions where ``y`` and ``y_hat`` differ.

        Plain Python lists are converted to arrays first. ``list != list``
        yields a single bool, which would make the loss 0 or 1 instead of a
        per-position count. Other inputs (arrays, scalars, ...) are compared
        as given.
        """
        if isinstance(y, list):
            y = np.asarray(y)
        if isinstance(y_hat, list):
            y_hat = np.asarray(y_hat)
        return np.sum(y != y_hat)

    def loss_augmented_inference(self, x, y, w, relaxed=None):
        """Fallback used when no loss-augmented inference was supplied.

        Ignores ``y`` and the loss entirely and returns plain
        ``self.inference(x, w)``. ``relaxed`` is deliberately not forwarded,
        so supplied ``inference`` functions that take only ``(self, x, w)``
        keep working.

        A warning is emitted rather than printed. Under Python's default
        warning filters it is shown once per call site, instead of flooding
        stdout on every call.
        """
        warnings.warn("FALLBACK no loss augmented inference found",
                      stacklevel=2)
        return self.inference(x, w)