// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include "tester.h"
#include <dlib/svm_threaded.h>
#include <dlib/rand.h>
// Element types shared by both feature extractors below: every left-hand and
// right-hand element is a 3-element column vector of doubles.  They live at
// global scope so the extractors can re-export them as ::lhs_element/::rhs_element.
typedef dlib::matrix<double,3,1> lhs_element;
typedef dlib::matrix<double,3,1> rhs_element;
namespace  
{
    using namespace test;
    using namespace dlib;
    using namespace std;
    // Logger for this test; test1() writes true/predicted labels and the
    // measured accuracies to it at LINFO level.
    logger dlog("test.assignment_learning");
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    struct feature_extractor_dense
    {
        typedef matrix<double,3,1> feature_vector_type;
        typedef ::lhs_element lhs_element;
        typedef ::rhs_element rhs_element;
        unsigned long num_features() const
        {
            return 3;
        }
        void get_features (
            const lhs_element& left,
            const rhs_element& right,
            feature_vector_type& feats
        ) const
        {
            feats = squared(left - right);
        }
    };
    // feature_extractor_dense holds no state, so serialization is a no-op in
    // both directions.
    void serialize (const feature_extractor_dense& /*item*/, std::ostream& /*out*/)
    {
    }
    void deserialize (feature_extractor_dense& /*item*/, std::istream& /*in*/)
    {
    }
// ----------------------------------------------------------------------------------------
    struct feature_extractor_sparse
    {
        typedef std::vector<std::pair<unsigned long,double> > feature_vector_type;
        typedef ::lhs_element lhs_element;
        typedef ::rhs_element rhs_element;
        unsigned long num_features() const
        {
            return 3;
        }
        void get_features (
            const lhs_element& left,
            const rhs_element& right,
            feature_vector_type& feats
        ) const
        {
            feats.clear();
            feats.push_back(make_pair(0,squared(left-right)(0)));
            feats.push_back(make_pair(1,squared(left-right)(1)));
            feats.push_back(make_pair(2,squared(left-right)(2)));
        }
    };
    // feature_extractor_sparse holds no state, so serialization is a no-op in
    // both directions.
    void serialize (const feature_extractor_sparse& /*item*/, std::ostream& /*out*/)
    {
    }
    void deserialize (feature_extractor_sparse& /*item*/, std::istream& /*in*/)
    {
    }
// ----------------------------------------------------------------------------------------
    // One assignment problem: the left-hand elements (first) and the right-hand
    // candidates (second) they may be matched against.
    typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
    // One entry per left-hand element giving the index of its assigned
    // right-hand element; the test data below uses -1 for left-hand elements
    // that have no match.
    typedef std::vector<long> label_type;
// ----------------------------------------------------------------------------------------
    void make_data (
        std::vector<sample_type>& samples,
        std::vector<label_type>& labels
    )
    {
        lhs_element a, b, c, d;
        a = 1,0,0;
        b = 0,1,0;
        c = 0,0,1;
        d = 0,1,1;
        std::vector<lhs_element> lhs;
        std::vector<rhs_element> rhs;
        label_type label;
        lhs.push_back(a);
        lhs.push_back(b);
        lhs.push_back(c);
        rhs.push_back(b);
        rhs.push_back(a);
        rhs.push_back(c);
        label.push_back(1);
        label.push_back(0);
        label.push_back(2);
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
        lhs.clear();
        rhs.clear();
        label.clear();
        lhs.push_back(a);
        lhs.push_back(b);
        lhs.push_back(c);
        rhs.push_back(c);
        rhs.push_back(b);
        rhs.push_back(a);
        rhs.push_back(d);
        label.push_back(2);
        label.push_back(1);
        label.push_back(0);
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
        lhs.clear();
        rhs.clear();
        label.clear();
        lhs.push_back(a);
        lhs.push_back(b);
        lhs.push_back(c);
        rhs.push_back(c);
        rhs.push_back(a);
        rhs.push_back(d);
        label.push_back(1);
        label.push_back(-1);
        label.push_back(0);
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
        lhs.clear();
        rhs.clear();
        label.clear();
        lhs.push_back(d);
        lhs.push_back(b);
        lhs.push_back(c);
        label.push_back(-1);
        label.push_back(-1);
        label.push_back(-1);
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
        lhs.clear();
        rhs.clear();
        label.clear();
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
    }
// ----------------------------------------------------------------------------------------
    void make_data_force (
        std::vector<sample_type>& samples,
        std::vector<label_type>& labels
    )
    {
        lhs_element a, b, c, d;
        a = 1,0,0;
        b = 0,1,0;
        c = 0,0,1;
        d = 0,1,1;
        std::vector<lhs_element> lhs;
        std::vector<rhs_element> rhs;
        label_type label;
        lhs.push_back(a);
        lhs.push_back(b);
        lhs.push_back(c);
        rhs.push_back(b);
        rhs.push_back(a);
        rhs.push_back(c);
        label.push_back(1);
        label.push_back(0);
        label.push_back(2);
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
        lhs.clear();
        rhs.clear();
        label.clear();
        lhs.push_back(a);
        lhs.push_back(b);
        lhs.push_back(c);
        rhs.push_back(c);
        rhs.push_back(b);
        rhs.push_back(a);
        rhs.push_back(d);
        label.push_back(2);
        label.push_back(1);
        label.push_back(0);
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
        lhs.clear();
        rhs.clear();
        label.clear();
        lhs.push_back(a);
        lhs.push_back(c);
        rhs.push_back(c);
        rhs.push_back(a);
        label.push_back(1);
        label.push_back(0);
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
        lhs.clear();
        rhs.clear();
        label.clear();
        samples.push_back(make_pair(lhs,rhs));
        labels.push_back(label);
    }
// ----------------------------------------------------------------------------------------
    template <typename fe_type, typename F>
    void test1(F make_data, bool force_assignment)
    {
        print_spinner();
        std::vector<sample_type> samples;
        std::vector<label_type> labels;
        make_data(samples, labels);
        make_data(samples, labels);
        make_data(samples, labels);
        randomize_samples(samples, labels);
        structural_assignment_trainer<fe_type> trainer;
        DLIB_TEST(trainer.forces_assignment() == false);
        DLIB_TEST(trainer.get_c() == 100);
        DLIB_TEST(trainer.get_num_threads() == 2);
        DLIB_TEST(trainer.get_max_cache_size() == 5);
        trainer.set_forces_assignment(force_assignment);
        trainer.set_num_threads(3);
        trainer.set_c(50);
        DLIB_TEST(trainer.get_c() == 50);
        DLIB_TEST(trainer.get_num_threads() == 3);
        DLIB_TEST(trainer.forces_assignment() == force_assignment);
        assignment_function<fe_type> ass = trainer.train(samples, labels);
        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            std::vector<long> out = ass(samples[i]);
            dlog << LINFO << "true labels: " << trans(mat(labels[i]));
            dlog << LINFO << "pred labels: " << trans(mat(out));
            DLIB_TEST(trans(mat(labels[i])) == trans(mat(out)));
        }
        double accuracy;
        dlog << LINFO << "samples.size(): "<< samples.size();
        accuracy = test_assignment_function(ass, samples, labels);
        dlog << LINFO << "accuracy: "<< accuracy;
        DLIB_TEST(accuracy == 1);
        accuracy = cross_validate_assignment_trainer(trainer, samples, labels, 3);
        dlog << LINFO << "cv accuracy: "<< accuracy;
        DLIB_TEST(accuracy == 1);
        ostringstream sout;
        serialize(ass, sout);
        istringstream sin(sout.str());
        assignment_function<fe_type> ass2;
        deserialize(ass2, sin);
        DLIB_TEST(ass2.forces_assignment() == ass.forces_assignment());
        DLIB_TEST(length(ass2.get_weights() - ass.get_weights()) < 1e-10);
        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            std::vector<long> out = ass2(samples[i]);
            dlog << LINFO << "true labels: " << trans(mat(labels[i]));
            dlog << LINFO << "pred labels: " << trans(mat(out));
            DLIB_TEST(trans(mat(labels[i])) == trans(mat(out)));
        }
    }
// ----------------------------------------------------------------------------------------
    class test_assignment_learning : public tester
    {
    public:
        test_assignment_learning (
        ) :
            tester ("test_assignment_learning",
                    "Runs tests on the assignment learning code.")
        {}
        void perform_test (
        )
        {
            test1<feature_extractor_dense>(make_data, false);
            test1<feature_extractor_sparse>(make_data, false);
            test1<feature_extractor_dense>(make_data_force, false);
            test1<feature_extractor_sparse>(make_data_force, false);
            test1<feature_extractor_dense>(make_data_force, true);
            test1<feature_extractor_sparse>(make_data_force, true);
        }
    } a;
// ----------------------------------------------------------------------------------------
}