00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #ifndef OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
00038 #define OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
00039
00040 #include "ompl/datastructures/NearestNeighbors.h"
00041 #include "ompl/datastructures/GreedyKCenters.h"
00042 #include "ompl/util/Exception.h"
00043 #include <queue>
00044 #include <algorithm>
00045
00046 namespace ompl
00047 {
00048
00057 template<typename _T>
00058 class NearestNeighborsGNAT : public NearestNeighbors<_T>
00059 {
00060 protected:
00061
00062
00063 typedef std::pair<const _T*,double> DataDist;
00064 struct DataDistCompare
00065 {
00066 bool operator()(const DataDist& d0, const DataDist& d1)
00067 {
00068 return d0.second < d1.second;
00069 }
00070 };
00071 typedef std::priority_queue<DataDist, std::vector<DataDist>, DataDistCompare> NearQueue;
00072
00073
00074
00075 class Node;
00076 typedef std::pair<Node*,double> NodeDist;
00077 struct NodeDistCompare
00078 {
00079 bool operator()(const NodeDist& n0, const NodeDist& n1) const
00080 {
00081 return (n0.second - n0.first->maxRadius_) > (n1.second - n1.first->maxRadius_);
00082 }
00083 };
00084 typedef std::priority_queue<NodeDist, std::vector<NodeDist>, NodeDistCompare> NodeQueue;
00085
00086
00087 public:
00088 NearestNeighborsGNAT(unsigned int degree = 4, unsigned int minDegree = 2,
00089 unsigned int maxDegree = 6, unsigned int maxNumPtsPerLeaf = 50,
00090 unsigned int removedCacheSize = 50)
00091 : NearestNeighbors<_T>(), tree_(NULL), degree_(degree),
00092 minDegree_(std::min(degree,minDegree)), maxDegree_(std::max(maxDegree,degree)),
00093 maxNumPtsPerLeaf_(maxNumPtsPerLeaf), size_(0)
00094 {
00095 removed_.reserve(removedCacheSize);
00096 }
00097
00098 virtual ~NearestNeighborsGNAT(void)
00099 {
00100 delete tree_;
00101 }
00102
00104 virtual void setDistanceFunction(const typename NearestNeighbors<_T>::DistanceFunction &distFun)
00105 {
00106 NearestNeighbors<_T>::setDistanceFunction(distFun);
00107 pivotSelector_.setDistanceFunction(distFun);
00108 }
00109
00110 virtual void clear(void)
00111 {
00112 if (tree_) delete tree_;
00113 tree_ = NULL;
00114 }
00115
00116 virtual void add(const _T &data)
00117 {
00118 if (tree_)
00119 tree_->add(*this, data);
00120 else
00121 tree_ = new Node(NULL, degree_, data);
00122 size_++;
00123 }
00124
00126 virtual void add(const std::vector<_T> &data)
00127 {
00128 if (tree_)
00129 NearestNeighbors<_T>::add(data);
00130 else if (data.size()>0)
00131 {
00132 tree_ = new Node(NULL, degree_, data[0]);
00133 for (unsigned int i=1; i<data.size(); ++i)
00134 tree_->data_.push_back(data[i]);
00135 if (tree_->needToSplit(*this))
00136 tree_->split(*this);
00137 }
00138 size_ += data.size();
00139 }
00141 void rebuildDataStructure()
00142 {
00143 typename std::vector<_T>::iterator it;
00144 std::vector<_T> lst;
00145
00146 list(lst);
00147 while (removed_.size() > 0)
00148 {
00149 for (it = lst.begin(); it != lst.end(); it++)
00150 if (*it == *removed_.back())
00151 break;
00152 assert(it != lst.end());
00153 lst.erase(it);
00154 removed_.pop_back();
00155 }
00156 delete tree_;
00157 tree_ = NULL;
00158 size_ = 0;
00159 add(lst);
00160 }
00161 virtual bool remove(const _T &data)
00162 {
00163 if (!tree_) return false;
00164
00165 NearQueue nbhQueue;
00166
00167 nearestRInternal(data, std::numeric_limits<double>::epsilon(), nbhQueue);
00168 if (nbhQueue.size()==0)
00169 return false;
00170 while (nbhQueue.size()>0)
00171 {
00172 unsigned int i;
00173 const _T* elt = nbhQueue.top().first;
00174 for (i=0; i<removed_.size(); ++i)
00175 if (removed_[i] == elt)
00176 break;
00177 if (i==removed_.size())
00178 {
00179
00180 removed_.push_back(elt);
00181 break;
00182 }
00183 nbhQueue.pop();
00184 }
00185
00186
00187 if (removed_.size()==removed_.capacity())
00188 rebuildDataStructure();
00189 else
00190 size_--;
00191 return true;
00192 }
00193 virtual _T nearest(const _T &data) const
00194 {
00195 if (tree_)
00196 {
00197 std::vector<_T> nbh;
00198 nearestK(data, 1, nbh);
00199 if (!nbh.empty()) return nbh[0];
00200 }
00201 throw Exception("No elements found");
00202 }
00203
00204 virtual void nearestK(const _T &data, std::size_t k, std::vector<_T> &nbh) const
00205 {
00206 nbh.clear();
00207 if (k == 0) return;
00208 if (tree_)
00209 {
00210 NearQueue nbhQueue;
00211 nearestKInternal(data, k, nbhQueue);
00212 postprocessNearest(nbhQueue, nbh, k);
00213 }
00214 }
00215
00216 virtual void nearestR(const _T &data, double radius, std::vector<_T> &nbh) const
00217 {
00218 nbh.clear();
00219 if (tree_)
00220 {
00221 NearQueue nbhQueue;
00222 nearestRInternal(data, radius, nbhQueue);
00223 postprocessNearest(nbhQueue, nbh);
00224 }
00225 }
00226
00227 virtual std::size_t size(void) const
00228 {
00229 return size_;
00230 }
00231
00232 virtual void list(std::vector<_T> &data) const
00233 {
00234 data.clear();
00235 data.reserve(size());
00236 if (tree_)
00237 tree_->list(data);
00238 }
00239
00240 friend std::ostream& operator<<(std::ostream& out, const NearestNeighborsGNAT<_T>& gnat)
00241 {
00242 return gnat.tree_ ? (out << *gnat.tree_) : out;
00243 }
00244
00245 protected:
00246 typedef NearestNeighborsGNAT<_T> GNAT;
00247
00248 void nearestKInternal(const _T &data, std::size_t k, NearQueue& nbhQueue) const
00249 {
00250 double dist;
00251 NodeDist nodeDist;
00252 NodeQueue nodeQueue;
00253
00254 tree_->insertNeighborK(nbhQueue, k, tree_->pivot_,
00255 NearestNeighbors<_T>::distFun_(data, tree_->pivot_));
00256 tree_->nearestK(*this, data, k + removed_.size(), nbhQueue, nodeQueue);
00257 while (nodeQueue.size() > 0)
00258 {
00259 dist = nbhQueue.top().second;
00260 nodeDist = nodeQueue.top();
00261 nodeQueue.pop();
00262 if (nbhQueue.size() == k &&
00263 (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
00264 nodeDist.second < nodeDist.first->minRadius_ - dist))
00265 break;
00266 nodeDist.first->nearestK(*this, data, k + removed_.size(), nbhQueue, nodeQueue);
00267 }
00268 }
00269 void nearestRInternal(const _T &data, double radius, NearQueue& nbhQueue) const
00270 {
00271 double dist = radius;
00272 NodeQueue nodeQueue;
00273 NodeDist nodeDist;
00274
00275 tree_->insertNeighborR(nbhQueue, radius, tree_->pivot_,
00276 NearestNeighbors<_T>::distFun_(data, tree_->pivot_));
00277 tree_->nearestR(*this, data, radius, nbhQueue, nodeQueue);
00278 while (nodeQueue.size() > 0)
00279 {
00280 nodeDist = nodeQueue.top();
00281 nodeQueue.pop();
00282 if (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
00283 nodeDist.second < nodeDist.first->minRadius_ - dist)
00284 break;
00285 nodeDist.first->nearestR(*this, data, radius, nbhQueue, nodeQueue);
00286 }
00287 }
00288 void postprocessNearest(NearQueue& nbhQueue, std::vector<_T> &nbh,
00289 unsigned int k=std::numeric_limits<unsigned int>::max()) const
00290 {
00291 if (removed_.size()>0)
00292 {
00293 while (nbhQueue.size()>0 && nbh.size()<k)
00294 {
00295 unsigned int i;
00296 const _T* elt = nbhQueue.top().first;
00297 for (i=0; i<removed_.size(); ++i)
00298 if (removed_[i] == elt)
00299 break;
00300 if (i==removed_.size())
00301 nbh.push_back(*elt);
00302 nbhQueue.pop();
00303 }
00304 }
00305 else
00306 {
00307 typename std::vector<_T>::reverse_iterator it;
00308 nbh.resize(nbhQueue.size());
00309 for (it=nbh.rbegin(); it!=nbh.rend(); it++, nbhQueue.pop())
00310 *it = *nbhQueue.top().first;
00311 }
00312 }
00313
00314 class Node
00315 {
00316 public:
00317 Node(const Node* parent, int degree, const _T& pivot)
00318 : degree_(degree), pivot_(pivot),
00319 minRadius_(std::numeric_limits<double>::infinity()),
00320 maxRadius_(-minRadius_), minRange_(degree, minRadius_),
00321 maxRange_(degree, maxRadius_)
00322 {
00323 }
00324
00325 ~Node()
00326 {
00327 for (unsigned int i=0; i<children_.size(); ++i)
00328 delete children_[i];
00329 }
00330
00331 void add(GNAT& gnat, const _T& data)
00332 {
00333 if (children_.size()==0)
00334 {
00335 data_.push_back(data);
00336 if (needToSplit(gnat))
00337 if (gnat.removed_.size() > 0)
00338 gnat.rebuildDataStructure();
00339 else
00340 split(gnat);
00341 }
00342 else
00343 {
00344 double minDist = std::numeric_limits<double>::infinity();
00345 int minInd = -1;
00346
00347 for (unsigned int i=0; i<children_.size(); ++i)
00348 {
00349 double dist;
00350
00351 if ((dist = gnat.distFun_(data, children_[i]->pivot_)) < minDist)
00352 {
00353 minDist = dist;
00354 minInd = i;
00355 }
00356 if (children_[i]->minRange_[minInd] > dist)
00357 children_[i]->minRange_[minInd] = dist;
00358 if (children_[i]->maxRange_[minInd] < dist)
00359 children_[i]->maxRange_[minInd] = dist;
00360 }
00361 if (minDist < children_[minInd]->minRadius_)
00362 children_[minInd]->minRadius_ = minDist;
00363 if (minDist > children_[minInd]->maxRadius_)
00364 children_[minInd]->maxRadius_ = minDist;
00365
00366 children_[minInd]->add(gnat, data);
00367 }
00368 }
00369
00370 bool needToSplit(const GNAT& gnat) const
00371 {
00372 unsigned int sz = data_.size();
00373 return sz > gnat.maxNumPtsPerLeaf_ && sz > degree_;
00374 }
00375 void split(GNAT& gnat)
00376 {
00377 std::vector<std::vector<double> > dists;
00378 std::vector<unsigned int> pivots;
00379
00380 children_.reserve(degree_);
00381 gnat.pivotSelector_.kcenters(data_, degree_, pivots, dists);
00382 for(unsigned int i=0; i<pivots.size(); i++)
00383 children_.push_back(new Node(this, degree_, data_[pivots[i]]));
00384 degree_ = pivots.size();
00385 for (unsigned int j=0; j<data_.size(); ++j)
00386 {
00387 unsigned int k = 0;
00388 for (unsigned int i=1; i<degree_; ++i)
00389 if (dists[j][i] < dists[j][k])
00390 k = i;
00391 Node* child = children_[k];
00392 if (j != pivots[k])
00393 {
00394 child->data_.push_back(data_[j]);
00395 if (dists[j][k] > child->maxRadius_)
00396 child->maxRadius_ = dists[j][k];
00397 if (dists[j][k] < child->minRadius_)
00398 child->minRadius_ = dists[j][k];
00399 }
00400 for (unsigned int i=0; i<degree_; ++i)
00401 {
00402 if (children_[i]->minRange_[k] > dists[j][i])
00403 children_[i]->minRange_[k] = dists[j][i];
00404 if (children_[i]->maxRange_[k] < dists[j][i])
00405 children_[i]->maxRange_[k] = dists[j][i];
00406 }
00407 }
00408
00409 for (unsigned int i=0; i<degree_; ++i)
00410 {
00411
00412 children_[i]->degree_ = std::min(std::max(
00413 degree_ * (unsigned int)(children_[i]->data_.size() / data_.size()),
00414 gnat.minDegree_), gnat.maxDegree_);
00415
00416 if (children_[i]->minRadius_ == std::numeric_limits<double>::infinity())
00417 children_[i]->minRadius_ = children_[i]->maxRadius_ = 0.;
00418 }
00419 data_.clear();
00420
00421 for (unsigned int i=0; i<degree_; ++i)
00422 if (children_[i]->needToSplit(gnat))
00423 children_[i]->split(gnat);
00424 }
00425
00426 void insertNeighborK(NearQueue& nbh, std::size_t k, const _T& data, double dist) const
00427 {
00428 if (nbh.size() < k)
00429 nbh.push(std::make_pair(&data, dist));
00430 else if (dist < nbh.top().second)
00431 {
00432 nbh.pop();
00433 nbh.push(std::make_pair(&data, dist));
00434 }
00435 }
00436
00437 void nearestK(const GNAT& gnat, const _T &data, std::size_t k, NearQueue& nbh, NodeQueue& nodeQueue) const
00438 {
00439 for (unsigned int i=0; i<data_.size(); ++i)
00440 insertNeighborK(nbh, k, data_[i], gnat.distFun_(data, data_[i]));
00441 if (children_.size() > 0)
00442 {
00443 double dist;
00444 Node* child;
00445 std::vector<double> distToPivot(children_.size());
00446 std::vector<int> permutation(children_.size());
00447
00448 for (unsigned int i=0; i<permutation.size(); ++i)
00449 permutation[i] = i;
00450 std::random_shuffle(permutation.begin(), permutation.end());
00451
00452 for (unsigned int i=0; i<children_.size(); ++i)
00453 if (permutation[i] >= 0)
00454 {
00455 child = children_[permutation[i]];
00456 distToPivot[permutation[i]] = gnat.distFun_(data, child->pivot_);
00457 insertNeighborK(nbh, k, child->pivot_, distToPivot[permutation[i]]);
00458 if (nbh.size()==k)
00459 {
00460 dist = nbh.top().second;
00461 for (unsigned int j=0; j<children_.size(); ++j)
00462 if (permutation[j] >=0 && i != j &&
00463 (distToPivot[permutation[i]] - dist > child->maxRange_[permutation[j]] ||
00464 distToPivot[permutation[i]] + dist < child->minRange_[permutation[j]]))
00465 permutation[j] = -1;
00466 }
00467 }
00468
00469 dist = nbh.top().second;
00470 for (unsigned int i=0; i<children_.size(); ++i)
00471 if (permutation[i] >= 0)
00472 {
00473 child = children_[permutation[i]];
00474 if (nbh.size()<k ||
00475 (distToPivot[permutation[i]] <= (child->maxRadius_ + dist) &&
00476 distToPivot[permutation[i]] >= (child->minRadius_ - dist)))
00477 nodeQueue.push(std::make_pair(child, distToPivot[permutation[i]]));
00478 }
00479 }
00480 }
00481
00482 void insertNeighborR(NearQueue& nbh, double r, const _T& data, double dist) const
00483 {
00484 if (dist < r)
00485 nbh.push(std::make_pair(&data, dist));
00486 }
00487
00488 void nearestR(const GNAT& gnat, const _T &data, double r, NearQueue& nbh, NodeQueue& nodeQueue) const
00489 {
00490 double dist = r;
00491
00492 for (unsigned int i=0; i<data_.size(); ++i)
00493 insertNeighborR(nbh, r, data_[i], gnat.distFun_(data, data_[i]));
00494 if (children_.size() > 0)
00495 {
00496 Node* child;
00497 std::vector<double> distToPivot(children_.size());
00498 std::vector<int> permutation(children_.size());
00499
00500 for (unsigned int i=0; i<permutation.size(); ++i)
00501 permutation[i] = i;
00502 std::random_shuffle(permutation.begin(), permutation.end());
00503
00504 for (unsigned int i=0; i<children_.size(); ++i)
00505 if (permutation[i] >= 0)
00506 {
00507 child = children_[permutation[i]];
00508 distToPivot[i] = gnat.distFun_(data, child->pivot_);
00509 insertNeighborR(nbh, r, child->pivot_, distToPivot[i]);
00510 for (unsigned int j=0; j<children_.size(); ++j)
00511 if (permutation[j] >=0 && i != j &&
00512 (distToPivot[i] - dist > child->maxRange_[permutation[j]] ||
00513 distToPivot[i] + dist < child->minRange_[permutation[j]]))
00514 permutation[j] = -1;
00515 }
00516
00517 for (unsigned int i=0; i<children_.size(); ++i)
00518 if (permutation[i] >= 0)
00519 {
00520 child = children_[permutation[i]];
00521 if ((distToPivot[i] <= (child->maxRadius_ + dist) &&
00522 distToPivot[i] >= (child->minRadius_ - dist)))
00523 nodeQueue.push(std::make_pair(child, distToPivot[i]));
00524 }
00525 }
00526 }
00527
00528 void list(std::vector<_T> &data) const
00529 {
00530 data.push_back(pivot_);
00531 for (unsigned int i=0; i<data_.size(); ++i)
00532 data.push_back(data_[i]);
00533 for (unsigned int i=0; i<children_.size(); ++i)
00534 children_[i]->list(data);
00535 }
00536
00537 friend std::ostream& operator<<(std::ostream& out, const Node& node)
00538 {
00539 out << "\ndegree:\t" << node.degree_;
00540 out << "\nminRadius:\t" << node.minRadius_;
00541 out << "\nmaxRadius:\t" << node.maxRadius_;
00542 out << "\nminRange:\t";
00543 for (unsigned int i=0; i<node.minRange_.size(); ++i)
00544 out << node.minRange_[i] << '\t';
00545 out << "\nmaxRange: ";
00546 for (unsigned int i=0; i<node.maxRange_.size(); ++i)
00547 out << node.maxRange_[i] << '\t';
00548 out << "\npivot:\t" << node.pivot_;
00549 out << "\ndata: ";
00550 for (unsigned int i=0; i<node.data_.size(); ++i)
00551 out << node.data_[i] << '\t';
00552 out << "\nthis:\t" << &node;
00553 out << "\nchildren:\n";
00554 for (unsigned int i=0; i<node.children_.size(); ++i)
00555 out << node.children_[i] << '\t';
00556 out << '\n';
00557 for (unsigned int i=0; i<node.children_.size(); ++i)
00558 out << *node.children_[i] << '\n';
00559 return out;
00560 }
00561
00562 unsigned int degree_;
00563 const _T pivot_;
00564 double minRadius_;
00565 double maxRadius_;
00566 std::vector<double> minRange_;
00567 std::vector<double> maxRange_;
00568 std::vector<_T> data_;
00569 std::vector<Node*> children_;
00570 };
00571
00572
00574 Node* tree_;
00575
00576 unsigned int degree_;
00577 unsigned int minDegree_;
00578 unsigned int maxDegree_;
00579 unsigned int maxNumPtsPerLeaf_;
00580 std::size_t size_;
00581
00583 GreedyKCenters<_T> pivotSelector_;
00584
00586 std::vector<const _T*> removed_;
00587
00588 };
00589
00590 }
00591
00592 #endif