20 #if !defined(__MITSUBA_CORE_KDTREE_H_)
21 #define __MITSUBA_CORE_KDTREE_H_
44 template <
typename _Po
intType,
typename _DataRecord>
struct SimpleKDNode {
48 typedef typename PointType::Scalar
Scalar;
55 static const bool leftBalancedLayout =
false;
64 right(0), data(), flags(0) { }
67 right(0), data(data), flags(0) { }
78 #if defined(MTS_DEBUG)
80 SLog(
EError,
"SimpleKDNode::setLeftIndex(): Internal error!");
132 typedef typename PointType::Scalar
Scalar;
139 static const bool leftBalancedLayout =
true;
149 data(data), flags(0) { }
155 #if defined(MTS_DEBUG)
156 if (value != 2*
self + 1)
157 SLog(
EError,
"LeftBalancedKDNode::setLeftIndex(): Internal error!");
165 #if defined(MTS_DEBUG)
166 if (value != 0 && value != 2*
self + 2)
167 SLog(
EError,
"LeftBalancedKDNode::setRightIndex(): Internal error!");
222 typedef typename PointType::Scalar
Scalar;
259 : distSquared(distSquared), index(index) { }
262 std::ostringstream oss;
263 oss <<
"SearchResult[distance=" << std::sqrt(distSquared)
264 <<
", index=" << index <<
"]";
276 std::binary_function<SearchResult, SearchResult, bool> {
288 inline PointKDTree(
size_t nodes = 0, EHeuristic heuristic = ESlidingMidpoint)
289 : m_nodes(nodes), m_heuristic(heuristic), m_depth(0) { }
295 inline void clear() { m_nodes.clear(); m_aabb.reset(); }
297 inline void resize(
size_t size) { m_nodes.resize(size); }
299 inline void reserve(
size_t size) { m_nodes.reserve(size); }
301 inline size_t size()
const {
return m_nodes.size(); }
303 inline size_t capacity()
const {
return m_nodes.capacity(); }
306 m_nodes.push_back(node);
323 inline void setDepth(
size_t depth) { m_depth = depth; }
326 void build(
bool recomputeAABB =
false) {
329 if (m_nodes.size() == 0) {
330 SLog(
EWarn,
"build(): kd-tree is empty!");
335 PointType::dim, m_nodes.size(),
memString(m_nodes.size() *
sizeof(
NodeType)).c_str());
339 for (
size_t i=0; i<m_nodes.size(); ++i)
340 m_aabb.expandBy(m_nodes[i].getPosition());
343 int aabbTime = timer->getMilliseconds();
350 std::vector<IndexType> indirection(m_nodes.size());
351 for (
size_t i=0; i<m_nodes.size(); ++i)
355 int constructionTime;
356 if (NodeType::leftBalancedLayout) {
357 std::vector<IndexType> permutation(m_nodes.size());
358 buildLB(0, 1, indirection.begin(), indirection.begin(),
359 indirection.end(), permutation);
360 constructionTime = timer->getMilliseconds();
364 build(1, indirection.begin(), indirection.begin(), indirection.end());
365 constructionTime = timer->getMilliseconds();
370 int permutationTime = timer->getMilliseconds();
373 SLog(
EDebug,
"Done after %i ms (breakdown: aabb: %i ms, build: %i ms, permute: %i ms). ",
374 aabbTime + constructionTime + permutationTime, aabbTime, constructionTime, permutationTime);
376 SLog(
EDebug,
"Done after %i ms (breakdown: build: %i ms, permute: %i ms). ",
377 constructionTime + permutationTime, constructionTime, permutationTime);
398 size_t k, SearchResult *results)
const {
399 if (m_nodes.size() == 0)
404 Float sqrSearchRadius = _sqrSearchRadius;
405 size_t resultCount = 0;
409 while (stackPos > 0) {
410 const NodeType &node = m_nodes[index];
417 bool searchBoth = distToPlane*distToPlane <= sqrSearchRadius;
419 if (distToPlane > 0) {
422 if (hasRightChild(index)) {
426 }
else if (searchBoth) {
429 nextIndex = stack[--stackPos];
434 if (searchBoth && hasRightChild(index))
440 nextIndex = stack[--stackPos];
446 if (pointDistSquared < sqrSearchRadius) {
449 if (resultCount < k) {
452 results[resultCount++] = SearchResult(pointDistSquared, index);
456 std::make_heap(results, results + resultCount,
457 SearchResultComparator());
460 SearchResult *end = results + resultCount + 1;
463 results[resultCount] = SearchResult(pointDistSquared, index);
464 std::push_heap(results, end, SearchResultComparator());
465 std::pop_heap(results, end, SearchResultComparator());
468 sqrSearchRadius = results[0].distSquared;
473 _sqrSearchRadius = sqrSearchRadius;
496 size_t k, SearchResult *results,
size_t &traversalSteps)
const {
499 if (m_nodes.size() == 0)
504 size_t resultCount = 0;
508 while (stackPos > 0) {
509 const NodeType &node = m_nodes[index];
517 bool searchBoth = distToPlane*distToPlane <= sqrSearchRadius;
519 if (distToPlane > 0) {
522 if (hasRightChild(index)) {
526 }
else if (searchBoth) {
529 nextIndex = stack[--stackPos];
534 if (searchBoth && hasRightChild(index))
540 nextIndex = stack[--stackPos];
546 if (pointDistSquared < sqrSearchRadius) {
549 if (resultCount < k) {
552 results[resultCount++] = SearchResult(pointDistSquared, index);
556 std::make_heap(results, results + resultCount,
557 SearchResultComparator());
562 results[resultCount] = SearchResult(pointDistSquared, index);
563 std::push_heap(results, results + resultCount + 1, SearchResultComparator());
564 std::pop_heap(results, results + resultCount + 1, SearchResultComparator());
567 sqrSearchRadius = results[0].distSquared;
589 SearchResult *results)
const {
590 Float searchRadiusSqr = std::numeric_limits<Float>::infinity();
591 return nnSearch(p, searchRadiusSqr, k, results);
607 Float searchRadius, Functor &functor) {
608 if (m_nodes.size() == 0)
612 size_t index = 0, stackPos = 1, found = 0;
613 Float distSquared = searchRadius*searchRadius;
616 while (stackPos > 0) {
625 bool searchBoth = distToPlane*distToPlane <= distSquared;
627 if (distToPlane > 0) {
630 if (hasRightChild(index)) {
634 }
else if (searchBoth) {
637 nextIndex = stack[--stackPos];
642 if (searchBoth && hasRightChild(index))
648 nextIndex = stack[--stackPos];
654 if (pointDistSquared < distSquared) {
676 Float searchRadius, Functor &functor)
const {
677 if (m_nodes.size() == 0)
681 IndexType index = 0, stackPos = 1, found = 0;
682 Float distSquared = searchRadius*searchRadius;
685 while (stackPos > 0) {
686 const NodeType &node = m_nodes[index];
694 bool searchBoth = distToPlane*distToPlane <= distSquared;
696 if (distToPlane > 0) {
699 if (hasRightChild(index)) {
703 }
else if (searchBoth) {
706 nextIndex = stack[--stackPos];
711 if (searchBoth && hasRightChild(index))
717 nextIndex = stack[--stackPos];
723 if (pointDistSquared < distSquared) {
730 return (
size_t) found;
743 if (m_nodes.size() == 0)
747 IndexType index = 0, stackPos = 1, found = 0;
748 Float distSquared = searchRadius*searchRadius;
751 while (stackPos > 0) {
752 const NodeType &node = m_nodes[index];
760 bool searchBoth = distToPlane*distToPlane <= distSquared;
762 if (distToPlane > 0) {
765 if (hasRightChild(index)) {
769 }
else if (searchBoth) {
772 nextIndex = stack[--stackPos];
777 if (searchBoth && hasRightChild(index))
783 nextIndex = stack[--stackPos];
789 if (pointDistSquared < distSquared) {
791 results.push_back(index);
796 return (
size_t) found;
807 if (NodeType::leftBalancedLayout) {
808 return 2*index+2 < m_nodes.size();
810 return m_nodes[index].getRightIndex(index) != 0;
817 : m_nodes(nodes), m_axis(axis) { }
819 return m_nodes[i1].getPosition()[m_axis] < m_nodes[i2].getPosition()[m_axis];
822 const std::vector<NodeType> &m_nodes;
829 : m_nodes(nodes), m_axis(axis), m_value(value) { }
831 return m_nodes[i].getPosition()[m_axis] <= m_value;
834 const std::vector<NodeType> &m_nodes;
866 if (2*remaining < p) {
870 p = (p >> 1) + remaining;
878 typename std::vector<IndexType>::iterator base,
879 typename std::vector<IndexType>::iterator rangeStart,
880 typename std::vector<IndexType>::iterator rangeEnd,
881 typename std::vector<IndexType> &permutation) {
882 m_depth = std::max(depth, m_depth);
889 m_nodes[*rangeStart].setLeaf(
true);
890 permutation[idx] = *rangeStart;
894 typename std::vector<IndexType>::iterator split
895 = rangeStart + leftSubtreeSize(count);
896 int axis = m_aabb.getLargestAxis();
897 std::nth_element(rangeStart, split, rangeEnd,
898 CoordinateOrdering(m_nodes, axis));
900 NodeType &splitNode = m_nodes[*split];
903 permutation[idx] = *split;
906 Scalar temp = m_aabb.max[axis],
908 m_aabb.max[axis] = splitPos;
909 buildLB(2*idx+1, depth+1, base, rangeStart, split, permutation);
910 m_aabb.max[axis] = temp;
912 if (split+1 != rangeEnd) {
913 temp = m_aabb.min[axis];
914 m_aabb.min[axis] = splitPos;
915 buildLB(2*idx+2, depth+1, base, split+1, rangeEnd, permutation);
916 m_aabb.min[axis] = temp;
922 typename std::vector<IndexType>::iterator base,
923 typename std::vector<IndexType>::iterator rangeStart,
924 typename std::vector<IndexType>::iterator rangeEnd) {
925 m_depth = std::max(depth, m_depth);
932 m_nodes[*rangeStart].setLeaf(
true);
937 typename std::vector<IndexType>::iterator split;
939 switch (m_heuristic) {
941 split = rangeStart + count/2;
942 axis = m_aabb.getLargestAxis();
943 std::nth_element(rangeStart, split, rangeEnd,
944 CoordinateOrdering(m_nodes, axis));
948 case ELeftBalanced: {
949 split = rangeStart + leftSubtreeSize(count);
950 axis = m_aabb.getLargestAxis();
951 std::nth_element(rangeStart, split, rangeEnd,
952 CoordinateOrdering(m_nodes, axis));
956 case ESlidingMidpoint: {
958 axis = m_aabb.getLargestAxis();
961 * (m_aabb.max[axis]+m_aabb.min[axis]);
963 size_t nLT = std::count_if(rangeStart, rangeEnd,
964 LessThanOrEqual(m_nodes, axis, midpoint));
967 split = rangeStart + nLT;
969 if (split == rangeStart)
971 else if (split == rangeEnd)
974 std::nth_element(rangeStart, split, rangeEnd,
975 CoordinateOrdering(m_nodes, axis));
980 Float bestCost = std::numeric_limits<Float>::infinity();
982 for (
int dim=0; dim<PointType::dim; ++dim) {
983 std::sort(rangeStart, rangeEnd,
984 CoordinateOrdering(m_nodes, dim));
986 size_t numLeft = 1, numRight = count-2;
987 AABBType leftAABB(m_aabb), rightAABB(m_aabb);
988 Float invVolume = 1.0f / m_aabb.getVolume();
989 for (
typename std::vector<IndexType>::iterator it = rangeStart+1;
990 it != rangeEnd; ++it) {
991 ++numLeft; --numRight;
992 Float pos = m_nodes[*it].getPosition()[dim];
993 leftAABB.max[dim] = rightAABB.
min[dim] = pos;
995 Float cost = (numLeft * leftAABB.getVolume()
996 + numRight * rightAABB.
getVolume()) * invVolume;
997 if (cost < bestCost) {
1004 std::nth_element(rangeStart, split, rangeEnd,
1005 CoordinateOrdering(m_nodes, axis));
1010 NodeType &splitNode = m_nodes[*split];
1014 if (split+1 != rangeEnd)
1022 std::iter_swap(rangeStart, split);
1025 Scalar temp = m_aabb.max[axis],
1027 m_aabb.max[axis] = splitPos;
1028 build(depth+1, base, rangeStart+1, split+1);
1029 m_aabb.max[axis] = temp;
1031 if (split+1 != rangeEnd) {
1032 temp = m_aabb.min[axis];
1033 m_aabb.min[axis] = splitPos;
1034 build(depth+1, base, split+1, rangeEnd);
1035 m_aabb.min[axis] = temp;
void setLeftIndex(IndexType self, IndexType value)
Given the current node's index, set the left child index.
Definition: kdtree.h:154
uint32_t IndexType
Definition: kdtree.h:47
bool isLeaf() const
Check whether this is a leaf node.
Definition: kdtree.h:172
_PointType PointType
Definition: kdtree.h:129
bool hasRightChild(IndexType index) const
Return whether or not the inner node of the specified index has a right child node.
Definition: kdtree.h:806
void permute_inplace(DataType *data, std::vector< IndexType > &perm)
Apply an arbitrary permutation to an array in linear time.
Definition: util.h:238
Left-balanced kd-tree node for use with PointKDTree.
Definition: kdtree.h:128
size_t executeModifier(const PointType &p, Float searchRadius, Functor &functor)
Execute a search query and run the specified functor on each result; the functor may modify the nodes themselves.
Definition: kdtree.h:606
const PointType & getPosition() const
Return the position associated with this node.
Definition: kdtree.h:187
SimpleKDNode(const DataRecord &data)
Initialize a KD-tree node with the given data record.
Definition: kdtree.h:66
LeftBalancedKDNode(const DataRecord &data)
Initialize a KD-tree node with the given data record.
Definition: kdtree.h:148
void setAABB(const AABBType &aabb)
Set the AABB of the underlying point data.
Definition: kdtree.h:317
TAABB< PointType > AABBType
Definition: kdtree.h:224
size_t nnSearch(const PointType &p, Float &_sqrSearchRadius, size_t k, SearchResult *results) const
Run a k-nearest-neighbor search query.
Definition: kdtree.h:397
IndexType getLeftIndex(IndexType self) const
Given the current node's index, return the index of the left child.
Definition: kdtree.h:75
size_t search(const PointType &p, Float searchRadius, std::vector< IndexType > &results) const
Run a search query.
Definition: kdtree.h:742
void build(bool recomputeAABB=false)
Construct the KD-tree hierarchy.
Definition: kdtree.h:326
bool operator==(const SearchResult &r) const
Definition: kdtree.h:268
Internal data record used by Photon.
Definition: photon.h:38
DataRecord & getData()
Return the data record associated with this node.
Definition: kdtree.h:192
PointType::Scalar Scalar
Definition: kdtree.h:222
void setDepth(size_t depth)
Set the depth of the constructed KD-tree (be careful with this)
Definition: kdtree.h:323
std::vector< NodeType > m_nodes
Definition: kdtree.h:1039
void setAxis(uint8_t axis)
Set the split axis associated with this node.
Definition: kdtree.h:97
IndexType right
Definition: kdtree.h:58
#define SLog(level, fmt,...)
Write a log message to the console (static version - to be used outside of classes that derive from Object)
Definition: logger.h:49
const AABBType & getAABB() const
Return the AABB of the underlying point data.
Definition: kdtree.h:319
size_t m_depth
Definition: kdtree.h:1042
Debug message, usually turned off.
Definition: formatter.h:30
NodeType::PointType PointType
Definition: kdtree.h:220
PointKDTree(size_t nodes=0, EHeuristic heuristic=ESlidingMidpoint)
Create an empty KD-tree that can hold the specified number of points.
Definition: kdtree.h:288
void push_back(const NodeType &node)
Append a kd-tree node to the node array.
Definition: kdtree.h:305
const DataRecord & getData() const
Return the data record associated with this node (const version)
Definition: kdtree.h:107
void setRightIndex(IndexType self, IndexType value)
Given the current node's index, set the right child index.
Definition: kdtree.h:164
DataRecord data
Definition: kdtree.h:142
Result data type for k-nn queries.
Definition: kdtree.h:252
bool operator()(const IndexType &i1, const IndexType &i2) const
Definition: kdtree.h:818
PointType::Scalar Scalar
Definition: kdtree.h:132
_NodeType NodeType
Definition: kdtree.h:219
Use the sliding midpoint tree construction rule. This ensures that cells do not become overly elongated.
Definition: kdtree.h:238
PointType::VectorType VectorType
Definition: kdtree.h:223
void setAxis(uint8_t axis)
Set the split axis associated with this node.
Definition: kdtree.h:184
_PointType PointType
Definition: kdtree.h:45
_DataRecord DataRecord
Definition: kdtree.h:130
Float distSquared
Definition: kdtree.h:253
uint32_t IndexType
Definition: kdtree.h:131
IndexType getRightIndex(IndexType self) const
Given the current node's index, return the index of the right child.
Definition: kdtree.h:162
size_t getDepth() const
Return the depth of the constructed KD-tree.
Definition: kdtree.h:321
bool operator()(const IndexType &i) const
Definition: kdtree.h:830
void buildLB(IndexType idx, size_t depth, typename std::vector< IndexType >::iterator base, typename std::vector< IndexType >::iterator rangeStart, typename std::vector< IndexType >::iterator rangeEnd, typename std::vector< IndexType > &permutation)
Left-balanced tree construction routine.
Definition: kdtree.h:877
NodeType::IndexType IndexType
Definition: kdtree.h:221
void resize(size_t size)
Resize the kd-tree array.
Definition: kdtree.h:297
size_t size() const
Return the size of the kd-tree.
Definition: kdtree.h:301
size_t nnSearchCollectStatistics(const PointType &p, Float &sqrSearchRadius, size_t k, SearchResult *results, size_t &traversalSteps) const
Run a k-nearest-neighbor search query and record statistics.
Definition: kdtree.h:495
#define SAssert(cond)
``Static'' assertion (to be used outside of classes that derive from Object)
Definition: logger.h:79
AABBType m_aabb
Definition: kdtree.h:1040
uint8_t flags
Definition: kdtree.h:143
void setLeaf(bool value)
Specify whether this is a leaf node.
Definition: kdtree.h:174
Reference counting helper.
Definition: ref.h:40
Warning message.
Definition: formatter.h:32
PointType position
Definition: kdtree.h:141
SearchResult()
Definition: kdtree.h:256
Memory-efficient photon representation for use with PointKDTree.
Definition: photon.h:57
IndexType getRightIndex(IndexType self) const
Given the current node's index, return the index of the right child.
Definition: kdtree.h:70
uint16_t getAxis() const
Return the split axis associated with this node.
Definition: kdtree.h:95
bool operator()(const SearchResult &a, const SearchResult &b) const
Definition: kdtree.h:278
SimpleKDNode()
Initialize a KD-tree node.
Definition: kdtree.h:63
size_t executeQuery(const PointType &p, Float searchRadius, Functor &functor) const
Execute a search query and run the specified functor on them.
Definition: kdtree.h:675
MTS_EXPORT_CORE std::string memString(size_t size, bool precise=false)
Turn a memory size into a human-readable string.
const DataRecord & getData() const
Return the data record associated with this node (const version)
Definition: kdtree.h:194
void setData(const DataRecord &val)
Set the data record associated with this node.
Definition: kdtree.h:196
Platform-independent milli/micro/nanosecond timer. This class implements a simple cross-platform timer ...
Definition: timer.h:37
Error message, causes an exception to be thrown.
Definition: formatter.h:33
const PointType & getPosition() const
Return the position associated with this node.
Definition: kdtree.h:100
PointType::Scalar Scalar
Definition: kdtree.h:48
std::string toString() const
Definition: kdtree.h:261
uint16_t getAxis() const
Return the split axis associated with this node.
Definition: kdtree.h:182
Generic multi-dimensional kd-tree data structure for point data.
Definition: kdtree.h:217
void build(size_t depth, typename std::vector< IndexType >::iterator base, typename std::vector< IndexType >::iterator rangeStart, typename std::vector< IndexType >::iterator rangeEnd)
Default tree construction routine.
Definition: kdtree.h:921
LeftBalancedKDNode()
Initialize a KD-tree node.
Definition: kdtree.h:146
LessThanOrEqual(const std::vector< NodeType > &nodes, int axis, Scalar value)
Definition: kdtree.h:828
size_t capacity() const
Return the capacity of the kd-tree.
Definition: kdtree.h:303
PointType min
Component-wise minimum.
Definition: aabb.h:424
void setLeftIndex(IndexType self, IndexType value)
Given the current node's index, set the left child index.
Definition: kdtree.h:77
void setPosition(const PointType &value)
Set the position associated with this node.
Definition: kdtree.h:102
CoordinateOrdering(const std::vector< NodeType > &nodes, int axis)
Definition: kdtree.h:816
DataRecord & getData()
Return the data record associated with this node.
Definition: kdtree.h:105
void setLeaf(bool value)
Specify whether this is a leaf node.
Definition: kdtree.h:87
size_t nnSearch(const PointType &p, size_t k, SearchResult *results) const
Run a k-nearest-neighbor search query without any search radius threshold.
Definition: kdtree.h:588
void setRightIndex(IndexType self, IndexType value)
Given the current node's index, set the right child index.
Definition: kdtree.h:72
Comparison functor for nearest-neighbor search queries.
Definition: kdtree.h:275
uint8_t flags
Definition: kdtree.h:60
bool isLeaf() const
Check whether this is a leaf node.
Definition: kdtree.h:85
Scalar getVolume() const
Calculate the n-dimensional volume of the bounding box.
Definition: aabb.h:107
DataRecord data
Definition: kdtree.h:59
Simple kd-tree node for use with PointKDTree.
Definition: kdtree.h:44
SearchResult(Float distSquared, IndexType index)
Definition: kdtree.h:258
void setPosition(const PointType &value)
Set the position associated with this node.
Definition: kdtree.h:189
_DataRecord DataRecord
Definition: kdtree.h:46
IndexType index
Definition: kdtree.h:254
NodeType & operator[](size_t idx)
Return one of the KD-tree nodes by index.
Definition: kdtree.h:310
const NodeType & operator[](size_t idx) const
Return one of the KD-tree nodes by index (const version)
Definition: kdtree.h:312
void reserve(size_t size)
Reserve a certain amount of memory for the kd-tree array.
Definition: kdtree.h:299
Create a left-balanced tree.
Definition: kdtree.h:232
IndexType leftSubtreeSize(IndexType count) const
Definition: kdtree.h:854
EHeuristic m_heuristic
Definition: kdtree.h:1041
void clear()
Clear the kd-tree array.
Definition: kdtree.h:295
PointType position
Definition: kdtree.h:57
IndexType getLeftIndex(IndexType self) const
Given the current node's index, return the index of the left child.
Definition: kdtree.h:152
void setData(const DataRecord &val)
Set the data record associated with this node.
Definition: kdtree.h:109