// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_
#define DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_
#include "nearest_neighbor_feature_image_abstract.h"
#include <limits>
#include <vector>
#include "../algs.h"
#include "../matrix.h"
#include "../statistics.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    class nearest_neighbor_feature_image : noncopyable
    {
        /*!
            WHAT THIS IMPLEMENTATION DOES
                Wraps a feature_extractor.  After load(), every cell of the extractor's
                output is replaced by the index of the basis vector closest to that cell's
                descriptor (squared Euclidean distance).  operator() presents each cell as
                a one-element sparse vector: (that index, 1).

            INITIAL VALUE
                - nn_feats.size() == 1

            CONVENTION
                - nn_feats.size() == 1
                - feats[r][c] == index of the basis vector nearest to the descriptor the
                  wrapped extractor produced at (r,c), as computed by load() (or restored
                  by deserialize()).  With an empty basis load() stores 0.
        !*/

    public:
        typedef std::vector<std::pair<unsigned int,double> > descriptor_type;

        // ---- construction and configuration ----
        nearest_neighbor_feature_image ();
        void clear ();
        void copy_configuration (const feature_extractor& item);
        void copy_configuration (const nearest_neighbor_feature_image& item);

        template <typename vector_type>
        void set_basis (const vector_type& new_basis);

        // ---- feature extraction ----
        template <typename image_type>
        inline void load (const image_type& img);

        // ---- dimensions of the loaded feature image ----
        inline size_t size () const;
        inline long nr () const;
        inline long nc () const;

        // number of basis vectors, i.e. the length of the sparse descriptor space
        inline long get_num_dimensions () const;

        // ---- element access ----
        inline const descriptor_type& operator() (long row, long col) const;

        // ---- coordinate mapping (delegated to the wrapped feature_extractor) ----
        inline const rectangle get_block_rect (long row, long col) const;
        inline const point image_to_feat_space (const point& p) const;
        inline const rectangle image_to_feat_space (const rectangle& rect) const;
        inline const point feat_to_image_space (const point& p) const;
        inline const rectangle feat_to_image_space (const rectangle& rect) const;

        // ---- persistence ----
        template <typename T>
        friend void serialize (const nearest_neighbor_feature_image<T>& item, std::ostream& out);
        template <typename T>
        friend void deserialize (nearest_neighbor_feature_image<T>& item, std::istream& in);

    private:
        array2d<unsigned long> feats;   // nearest-basis index for every cell
        feature_extractor fe;           // the wrapped extractor
        std::vector<typename feature_extractor::descriptor_type> basis;

        // Scratch storage handed out by operator().  It lives here only so that it is
        // not reallocated on every call; it is never serialized.
        mutable descriptor_type nn_feats;
    };
// ----------------------------------------------------------------------------------------
    template <
        typename T
        >
    void serialize (
        const nearest_neighbor_feature_image<T>& item,
        std::ostream& out
    )
    {
        // Stream layout is feats, then fe, then basis.  deserialize() reads the fields
        // back in exactly this order.  The scratch buffer nn_feats is not persisted.
        // Calls stay unqualified so argument-dependent lookup can find serialize()
        // overloads for user-supplied feature extractor types.
        serialize(item.feats, out);   // nearest-neighbor index image
        serialize(item.fe,    out);   // wrapped feature extractor
        serialize(item.basis, out);   // basis vectors
    }
    template <
        typename T
        >
    void deserialize (
        nearest_neighbor_feature_image<T>& item,
        std::istream& in 
    )
    {
        // Mirror of serialize(): read feats, then fe, then basis, in that order.
        // nn_feats is untouched, so its size-1 invariant still holds afterwards.
        // Calls stay unqualified so argument-dependent lookup still applies.
        deserialize(item.feats, in);   // nearest-neighbor index image
        deserialize(item.fe,    in);   // wrapped feature extractor
        deserialize(item.basis, in);   // basis vectors
    }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
//                        nearest_neighbor_feature_image member functions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    nearest_neighbor_feature_image<feature_extractor>::
    nearest_neighbor_feature_image (
    ) :
        nn_feats(1)
    {
        // Establishes the class invariant nn_feats.size() == 1.  The single element
        // is value-initialized and later overwritten in place by operator().
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    void nearest_neighbor_feature_image<feature_extractor>::
    clear (
    )
    {
        // Empty every piece of persistent state.  nn_feats is deliberately left alone:
        // it is scratch space whose size must remain 1.
        this->feats.clear();   // drop the nearest-neighbor index image
        this->fe.clear();      // ask the wrapped extractor to clear itself
        this->basis.clear();   // drop all basis vectors
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    void nearest_neighbor_feature_image<feature_extractor>::
    copy_configuration (
        const feature_extractor& item
    )
    {
        // Adopt only the extractor's configuration.  The basis and any loaded feature
        // image are left exactly as they were.
        this->fe.copy_configuration(item);
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    void nearest_neighbor_feature_image<feature_extractor>::
    copy_configuration (
        const nearest_neighbor_feature_image& item
    )
    {
        // Configuration here means the extractor settings plus the basis.  The loaded
        // feature image (feats) of item is not copied.
        const feature_extractor& other_fe = item.fe;
        fe.copy_configuration(other_fe);

        // Plain assignment (not assign(begin,end)) keeps self-copy well defined.
        basis = item.basis;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    template <
        typename image_type
        >
    void nearest_neighbor_feature_image<feature_extractor>::
    load (
        const image_type& img
    )
    /*!
        Runs the wrapped feature extractor over img, then, for every cell (r,c) it
        produces, stores in feats[r][c] the index of the basis vector with the smallest
        squared Euclidean distance to that cell's descriptor.
        Ties resolve to the lowest index.  With an empty basis every cell is set to 0.
    !*/
    {
        fe.load(img);
        feats.set_size(fe.nr(), fe.nc());

        // basis does not change during the scan, so read its size once.
        const unsigned long num_basis = basis.size();

        for (long r = 0; r < feats.nr(); ++r)
        {
            for (long c = 0; c < feats.nc(); ++c)
            {
                const typename feature_extractor::descriptor_type& local_feat = fe(r,c);

                // Linear nearest-neighbor scan.  Only a strictly smaller distance
                // replaces the current best, which is why ties keep the lowest index.
                double best_dist = std::numeric_limits<double>::infinity();
                unsigned long best_idx = 0;
                for (unsigned long i = 0; i < num_basis; ++i)
                {
                    const double dist = length_squared(local_feat - basis[i]);
                    if (dist < best_dist)
                    {
                        best_dist = dist;
                        best_idx = i;
                    }
                }

                feats[r][c] = best_idx;
            }
        }

        // Only feats is read by operator() from here on.  unload() presumably releases
        // the extractor's per-image data (TODO confirm against feature_extractor docs).
        fe.unload();
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    size_t nearest_neighbor_feature_image<feature_extractor>::
    size (
    ) const
    {
        // Total number of cells held in the nearest-neighbor index image.
        const size_t num_cells = feats.size();
        return num_cells;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    long nearest_neighbor_feature_image<feature_extractor>::
    nr (
    ) const
    {
        // Row count of the index image; load() sizes it to match the extractor.
        const long num_rows = feats.nr();
        return num_rows;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    long nearest_neighbor_feature_image<feature_extractor>::
    nc (
    ) const
    {
        // Column count of the index image; load() sizes it to match the extractor.
        const long num_cols = feats.nc();
        return num_cols;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    long nearest_neighbor_feature_image<feature_extractor>::
    get_num_dimensions (
    ) const
    {
        // Every descriptor is a single index into basis, so the descriptor space has
        // one dimension per basis vector.
        return static_cast<long>(basis.size());
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    template <
        typename vector_type
        >
    void nearest_neighbor_feature_image<feature_extractor>::
    set_basis (
        const vector_type& new_basis
    )
    {
        // Replace the current basis with a copy of every element of new_basis.
        // vector_type only needs begin()/end() whose elements can be stored as
        // feature_extractor::descriptor_type.  Previously loaded feats are not
        // recomputed; call load() again to use the new basis.
        basis.assign(new_basis.begin(), new_basis.end());
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    const typename nearest_neighbor_feature_image<feature_extractor>::descriptor_type& 
    nearest_neighbor_feature_image<feature_extractor>::
    operator() (
        long row,
        long col
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(0 <= row && row < nr() &&
                    0 <= col && col < nc(),
            "\t descriptor_type nearest_neighbor_feature_image::operator(row,col)"
            << "\n\t Invalid inputs were given to this function"
            << "\n\t row:  " << row
            << "\n\t col:  " << col 
            << "\n\t nr(): " << nr()
            << "\n\t nc(): " << nc()
            << "\n\t this: " << this
            );

        // Overwrite the single scratch element in place: (nearest basis index, 1).
        // The returned reference stays valid only until the next call to operator().
        std::pair<unsigned int,double>& entry = nn_feats[0];
        entry.first  = static_cast<unsigned int>(feats[row][col]);
        entry.second = 1;
        return nn_feats;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    const rectangle nearest_neighbor_feature_image<feature_extractor>::
    get_block_rect (
        long row,
        long col
    ) const
    {
        // Block geometry is defined entirely by the wrapped extractor.
        const rectangle block = fe.get_block_rect(row, col);
        return block;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    const point nearest_neighbor_feature_image<feature_extractor>::
    image_to_feat_space (
        const point& p
    ) const
    {
        // The index image shares the wrapped extractor's coordinate system.
        const point mapped = fe.image_to_feat_space(p);
        return mapped;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    const rectangle nearest_neighbor_feature_image<feature_extractor>::
    image_to_feat_space (
        const rectangle& rect
    ) const
    {
        // The index image shares the wrapped extractor's coordinate system.
        const rectangle mapped = fe.image_to_feat_space(rect);
        return mapped;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    const point nearest_neighbor_feature_image<feature_extractor>::
    feat_to_image_space (
        const point& p
    ) const
    {
        // Inverse mapping, again delegated to the wrapped extractor.
        const point mapped = fe.feat_to_image_space(p);
        return mapped;
    }
// ----------------------------------------------------------------------------------------
    template <
        typename feature_extractor
        >
    const rectangle nearest_neighbor_feature_image<feature_extractor>::
    feat_to_image_space (
        const rectangle& rect
    ) const 
    {
        // Inverse mapping, again delegated to the wrapped extractor.
        const rectangle mapped = fe.feat_to_image_space(rect);
        return mapped;
    }
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_NEAREST_NEIGHBOR_FeATURE_IMAGE_Hh_