// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
#define DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
#include <map>
#include <sstream>
#include <vector>

#include "../matrix.h"
#include "cross_validate_multiclass_trainer_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
    template <
        typename dec_funct_type,
        typename sample_type,
        typename label_type
        >
    const matrix<double> test_multiclass_decision_function (
        const dec_funct_type& dec_funct,
        const std::vector<sample_type>& x_test,
        const std::vector<label_type>& y_test
    )
    /*!
        requires
            - is_learning_problem(x_test, y_test) == true
        ensures
            - Tests dec_funct against the given samples and labels and returns
              a confusion matrix R where:
                - R.nr() == R.nc() == dec_funct.get_labels().size()
                - R(r,c) == the number of test samples whose true label is
                  dec_funct.get_labels()[r] but which were predicted to have
                  label dec_funct.get_labels()[c].
            - Test samples whose true label is not in dec_funct.get_labels()
              are ignored.
    !*/
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( is_learning_problem(x_test,y_test) == true,
                    "\tmatrix test_multiclass_decision_function()"
                    << "\n\t invalid inputs were given to this function"
                    << "\n\t is_learning_problem(x_test,y_test): " 
                    << is_learning_problem(x_test,y_test));

        const std::vector<label_type> all_labels = dec_funct.get_labels();

        // make a lookup table that maps from labels to their index in all_labels
        std::map<label_type,unsigned long> label_to_int;
        for (unsigned long i = 0; i < all_labels.size(); ++i)
            label_to_int[all_labels[i]] = i;

        matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> res;
        res.set_size(all_labels.size(), all_labels.size());
        res = 0;

        // now test this trained object 
        for (unsigned long i = 0; i < x_test.size(); ++i)
        {
            const typename std::map<label_type,unsigned long>::const_iterator truth_iter = label_to_int.find(y_test[i]);
            // ignore samples with labels that the decision function doesn't know about.
            if (truth_iter == label_to_int.end())
                continue;

            // Look the prediction up with find() rather than operator[] so that an
            // unexpected label coming out of dec_funct can't silently insert a new
            // map entry (which would have been miscounted as class index 0).  A
            // well-formed decision function only returns labels from get_labels(),
            // so this branch is purely defensive.
            const typename std::map<label_type,unsigned long>::const_iterator pred_iter = label_to_int.find(dec_funct(x_test[i]));
            if (pred_iter == label_to_int.end())
                continue;

            res(truth_iter->second, pred_iter->second) += 1;
        }

        return res;
    }
// ----------------------------------------------------------------------------------------
    /*!
        This is the exception class used to report problems that prevent
        cross_validate_multiclass_trainer() from carrying out the requested
        cross-validation (e.g. more folds than samples of some class).
    !*/
    class cross_validation_error : public dlib::error
    {
    public:
        cross_validation_error (
            const std::string& msg
        ) : dlib::error(msg) {}
    };
    template <
        typename trainer_type,
        typename sample_type,
        typename label_type 
        >
    const matrix<double> cross_validate_multiclass_trainer (
        const trainer_type& trainer,
        const std::vector<sample_type>& x,
        const std::vector<label_type>& y,
        const long folds
    )
    /*!
        requires
            - is_learning_problem(x, y) == true
            - 1 < folds <= x.size()
        ensures
            - Performs stratified k-fold cross-validation: each fold's test set
              receives floor(class_count/folds) samples of every class, the
              trainer is trained on the remaining samples, and the confusion
              matrices from test_multiclass_decision_function() are summed over
              all folds and returned.
        throws
            - cross_validation_error if some class has fewer than `folds` samples
              (so it can't contribute at least one sample to every test split).
    !*/
    {
        typedef typename trainer_type::mem_manager_type mem_manager_type;

        // make sure requires clause is not broken
        DLIB_ASSERT(is_learning_problem(x,y) == true &&
                    1 < folds && folds <= static_cast<long>(x.size()),
            "\tmatrix cross_validate_multiclass_trainer()"
            << "\n\t invalid inputs were given to this function"
            << "\n\t x.size(): " << x.size() 
            << "\n\t folds:  " << folds 
            << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
            );

        const std::vector<label_type> all_labels = select_all_distinct_labels(y);

        // count the number of times each label shows up 
        std::map<label_type,long> label_counts;
        for (unsigned long i = 0; i < y.size(); ++i)
            label_counts[y[i]] += 1;

        // figure out how many samples from each class will be in the test and train splits.
        // The split is stratified: every fold's test set gets floor(count/folds) samples
        // of each class and the train set gets the rest (test + train == count per class).
        std::map<label_type,long> num_in_test, num_in_train;
        for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i)
        {
            const long in_test = i->second/folds;
            if (in_test == 0)
            {
                // This class has fewer samples than there are folds, so we can't put
                // at least one of its samples into every test split.  Say which class.
                std::ostringstream sout;
                sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl;
                sout << "than the number of elements of one of the training classes." << std::endl;
                sout << "  folds: "<< folds << std::endl;
                sout << "  size of class " << i->first << ": "<< i->second << std::endl;
                throw cross_validation_error(sout.str());
            }
            num_in_test[i->first] = in_test; 
            num_in_train[i->first] = i->second - in_test;
        }

        std::vector<sample_type> x_test, x_train;
        std::vector<label_type> y_test, y_train;

        // Accumulated confusion matrix over all folds.
        // NOTE(review): res starts out empty and is first filled by the `res +=`
        // below; this appears to rely on dlib's matrix operator+= assigning when
        // the destination is empty — confirm against the matrix docs.
        matrix<double, 0, 0, mem_manager_type> res;

        // Per-class cursor into x/y recording where the previous fold's test-sample
        // scan stopped.  Carrying it across folds (with wraparound) is what makes
        // successive folds pick disjoint test samples for each class.
        std::map<label_type,long> next_test_idx;
        for (unsigned long i = 0; i < all_labels.size(); ++i)
            next_test_idx[all_labels[i]] = 0;

        // Reused inside the loops below; requires label_type to be default-constructible.
        label_type label;

        for (long i = 0; i < folds; ++i)
        {
            x_test.clear();
            y_test.clear();
            x_train.clear();
            y_train.clear();

            // load up the test samples: for each class, scan forward from that
            // class's cursor (wrapping modulo x.size()) collecting the next
            // num_in_test[label] samples of that class, then advance the cursor
            // past them so the next fold tests on different samples.
            for (unsigned long j = 0; j < all_labels.size(); ++j)
            {
                label = all_labels[j];
                long next = next_test_idx[label];

                long cur = 0;
                const long num_needed = num_in_test[label];
                while (cur < num_needed)
                {
                    if (y[next] == label)
                    {
                        x_test.push_back(x[next]);
                        y_test.push_back(label);
                        ++cur;
                    }
                    next = (next + 1)%x.size();
                }

                next_test_idx[label] = next;
            }

            // load up the training samples: continue scanning from just past this
            // fold's test block (again wrapping around) and collect the remaining
            // num_in_train[label] samples of each class — since test + train counts
            // cover the whole class, these are exactly the samples NOT in this
            // fold's test set.  The cursor itself is not updated here.
            for (unsigned long j = 0; j < all_labels.size(); ++j)
            {
                label = all_labels[j];
                long next = next_test_idx[label];

                long cur = 0;
                const long num_needed = num_in_train[label];
                while (cur < num_needed)
                {
                    if (y[next] == label)
                    {
                        x_train.push_back(x[next]);
                        y_train.push_back(label);
                        ++cur;
                    }
                    next = (next + 1)%x.size();
                }
            }

            try
            {
                // do the training and testing
                res += test_multiclass_decision_function(trainer.train(x_train,y_train),x_test,y_test);
            }
            catch (invalid_nu_error&)
            {
                // just ignore cases which result in an invalid nu
            }

        } // for (long i = 0; i < folds; ++i)

        return res;
    }
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_