27#ifndef EASY3D_CORE_SPLINE_INTERPOLATION_H
28#define EASY3D_CORE_SPLINE_INTERPOLATION_H
35#include <easy3d/util/logging.h>
96 left_value_(0), right_value_(0),
97 linear_extrapolation_(false) {
103 BoundaryType right, FT right_value,
104 bool linear_extrapolation =
false);
109 void set_data(
const std::vector <FT> &x,
const std::vector <FT> &y,
bool cubic_spline =
true);
118 std::vector <FT> x_, y_;
121 std::vector <FT> a_, b_, c_;
123 BoundaryType left_, right_;
124 FT left_value_, right_value_;
125 bool linear_extrapolation_;
137 template<
typename FT>
141 BandMatrix() =
default;
143 BandMatrix(
int dim,
int n_u,
int n_l);
145 ~BandMatrix() =
default;
147 void resize(
int dim,
int n_u,
int n_l);
150 int num_upper()
const {
return m_upper.size() - 1; }
151 int num_lower()
const {
return m_lower.size() - 1; }
154 FT &operator()(
int i,
int j);
156 const FT &operator()(
int i,
int j)
const;
159 FT &saved_diag(
int i);
160 const FT &saved_diag(
int i)
const;
166 std::vector <FT> r_solve(
const std::vector <FT> &b)
const;
168 std::vector <FT> l_solve(
const std::vector <FT> &b)
const;
169 std::vector <FT> lu_solve(
const std::vector <FT> &b,
bool is_lu_decomposed =
false);
172 std::vector <std::vector<FT>> m_upper;
173 std::vector <std::vector<FT>> m_lower;
182 template<
typename FT>
184 SplineInterpolation<FT>::BoundaryType right, FT right_value,
185 bool linear_extrapolation) {
186 assert(x_.size() == 0);
189 left_value_ = left_value;
190 right_value_ = right_value;
191 linear_extrapolation_ = linear_extrapolation;
194 template<
typename FT>
196 if (x.size() != y.size()) {
197 LOG(ERROR) <<
"sizes of x (" << x.size() <<
") and y (" << y.size() <<
") do not match";
201 LOG(ERROR) <<
"too few data (size of x: " << x.size() <<
")";
206 const int n =
static_cast<int>(x.size());
208 for (std::size_t i = 0; i < n - 1; i++) {
209 if (x_[i] >= x_[i + 1]) {
210 LOG_N_TIMES(3, ERROR) <<
"x has to be monotonously increasing (x[" << i <<
"]=" << x_[i] <<
", x[" << i + 1 <<
"]=" << x_[i + 1] <<
"). " << COUNTER;
218 BandMatrix<FT> A(n, 1, 1);
219 std::vector <FT> rhs(n);
220 for (
int i = 1; i < n - 1; i++) {
221 A(i, i - 1) = FT(1.0 / 3.0) * (x[i] - x[i - 1]);
222 A(i, i) = FT(2.0 / 3.0) * (x[i + 1] - x[i - 1]);
223 A(i, i + 1) = FT(1.0 / 3.0) * (x[i + 1] - x[i]);
224 rhs[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - (y[i] - y[i - 1]) / (x[i] - x[i - 1]);
231 rhs[0] = left_value_;
235 A(0, 0) = FT(2.0) * (x[1] - x[0]);
236 A(0, 1) = FT(1.0) * (x[1] - x[0]);
237 rhs[0] = FT(3.0) * ((y[1] - y[0]) / (x[1] - x[0]) - left_value_);
243 A(n - 1, n - 1) = FT(2.0);
244 A(n - 1, n - 2) = FT(0.0);
245 rhs[n - 1] = right_value_;
250 A(n - 1, n - 1) = FT(2.0) * (x[n - 1] - x[n - 2]);
251 A(n - 1, n - 2) = FT(1.0) * (x[n - 1] - x[n - 2]);
252 rhs[n - 1] = FT(3.0) * (right_value_ - (y[n - 1] - y[n - 2]) / (x[n - 1] - x[n - 2]));
258 b_ = A.lu_solve(rhs);
263 for (std::size_t i = 0; i < n - 1; i++) {
264 a_[i] = FT(1.0 / 3.0) * (b_[i + 1] - b_[i]) / (x[i + 1] - x[i]);
265 c_[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i])
266 - FT(1.0 / 3.0) * (FT(2.0) * b_[i] + b_[i + 1]) * (x[i + 1] - x[i]);
272 for (std::size_t i = 0; i < n - 1; i++) {
275 c_[i] = (y_[i + 1] - y_[i]) / (x_[i + 1] - x_[i]);
280 b0_ = linear_extrapolation_ ? FT(0.0) : b_[0];
285 FT h = x[n - 1] - x[n - 2];
288 c_[n - 1] = FT(3.0) * a_[n - 2] * h * h + FT(2.0) * b_[n - 2] * h + c_[n - 2];
289 if (linear_extrapolation_)
293 template<
typename FT>
295 size_t n = x_.size();
297 typename std::vector<FT>::const_iterator it;
298 it = std::lower_bound(x_.begin(), x_.end(), x);
299 int idx = std::max(
int(it - x_.begin()) - 1, 0);
305 interpol = (b0_ * h + c0_) * h + y_[0];
306 }
else if (x > x_[n - 1]) {
308 interpol = (b_[n - 1] * h + c_[n - 1]) * h + y_[n - 1];
311 interpol = ((a_[idx] * h + b_[idx]) * h + c_[idx]) * h + y_[idx];
316 template<
typename FT>
320 size_t n = x_.size();
322 typename std::vector<FT>::const_iterator it;
323 it = std::lower_bound(x_.begin(), x_.end(), x);
324 int idx = std::max(
int(it - x_.begin()) - 1, 0);
332 interpol = FT(2.0) * b0_ * h + c0_;
335 interpol = FT(2.0) * b0_ * h;
341 }
else if (x > x_[n - 1]) {
345 interpol = FT(2.0) * b_[n - 1] * h + c_[n - 1];
348 interpol = FT(2.0) * b_[n - 1];
358 interpol = (FT(3.0) * a_[idx] * h + FT(2.0) * b_[idx]) * h + c_[idx];
361 interpol = FT(6.0) * a_[idx] * h + FT(2.0) * b_[idx];
364 interpol = FT(6.0) * a_[idx];
379 template<
typename FT>
380 BandMatrix<FT>::BandMatrix(
int dim,
int n_u,
int n_l) {
381 resize(dim, n_u, n_l);
384 template<
typename FT>
385 void BandMatrix<FT>::resize(
int dim,
int n_u,
int n_l) {
389 m_upper.resize(n_u + 1);
390 m_lower.resize(n_l + 1);
391 for (
size_t i = 0; i < m_upper.size(); i++) {
392 m_upper[i].resize(dim);
394 for (
size_t i = 0; i < m_lower.size(); i++) {
395 m_lower[i].resize(dim);
399 template<
typename FT>
400 int BandMatrix<FT>::dim()
const {
401 if (m_upper.size() > 0) {
402 return m_upper[0].size();
411 template<
typename FT>
412 FT &BandMatrix<FT>::operator()(
int i,
int j) {
414 assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
415 assert((-num_lower() <= k) && (k <= num_upper()));
417 if (k >= 0)
return m_upper[k][i];
418 else return m_lower[-k][i];
421 template<
typename FT>
422 const FT &BandMatrix<FT>::operator()(
int i,
int j)
const {
424 assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
425 assert((-num_lower() <= k) && (k <= num_upper()));
427 if (k >= 0)
return m_upper[k][i];
428 else return m_lower[-k][i];
432 template<
typename FT>
433 const FT &BandMatrix<FT>::saved_diag(
int i)
const {
434 assert((i >= 0) && (i < dim()));
435 return m_lower[0][i];
438 template<
typename FT>
439 FT &BandMatrix<FT>::saved_diag(
int i) {
440 assert((i >= 0) && (i < dim()));
441 return m_lower[0][i];
445 template<
typename FT>
446 void BandMatrix<FT>::lu_decompose() {
453 for (
int i = 0; i < this->dim(); i++) {
454 assert(this->
operator()(i, i) != 0.0);
455 this->saved_diag(i) = FT(1.0) / this->operator()(i, i);
456 j_min = std::max(0, i - this->num_lower());
457 j_max = std::min(this->dim() - 1, i + this->num_upper());
458 for (
int j = j_min; j <= j_max; j++) {
459 this->operator()(i, j) *= this->saved_diag(i);
461 this->operator()(i, i) = FT(1.0);
465 for (
int k = 0; k < this->dim(); k++) {
466 i_max = std::min(this->dim() - 1, k + this->num_lower());
467 for (
int i = k + 1; i <= i_max; i++) {
468 assert(this->
operator()(k, k) != FT(0.0));
469 x = -this->operator()(i, k) / this->operator()(k, k);
470 this->operator()(i, k) = -x;
471 j_max = std::min(this->dim() - 1, k + this->num_upper());
472 for (
int j = k + 1; j <= j_max; j++) {
474 this->operator()(i, j) = this->operator()(i, j) + x * this->operator()(k, j);
481 template<
typename FT>
482 std::vector <FT> BandMatrix<FT>::l_solve(
const std::vector <FT> &b)
const {
483 assert(this->dim() == (
int) b.size());
484 std::vector <FT> x(this->dim());
487 for (
int i = 0; i < this->dim(); i++) {
489 j_start = std::max(0, i - this->num_lower());
490 for (
int j = j_start; j < i; j++)
sum += this->
operator()(i, j) * x[j];
491 x[i] = (b[i] * this->saved_diag(i)) -
sum;
497 template<
typename FT>
498 std::vector <FT> BandMatrix<FT>::r_solve(
const std::vector <FT> &b)
const {
499 assert(this->dim() == (
int) b.size());
500 std::vector <FT> x(this->dim());
503 for (
int i = this->dim() - 1; i >= 0; i--) {
505 j_stop = std::min(this->dim() - 1, i + this->num_upper());
506 for (
int j = i + 1; j <= j_stop; j++)
sum += this->
operator()(i, j) * x[j];
507 x[i] = (b[i] -
sum) / this->
operator()(i, i);
512 template<
typename FT>
513 std::vector <FT> BandMatrix<FT>::lu_solve(
const std::vector <FT> &b,
bool is_lu_decomposed) {
514 assert(this->dim() == (
int) b.size());
515 std::vector <FT> x, y;
516 if (!is_lu_decomposed) {
517 this->lu_decompose();
519 y = this->l_solve(b);
520 x = this->r_solve(y);
Cubic spline interpolation.
Definition: spline_interpolation.h:85
SplineInterpolation()
Definition: spline_interpolation.h:95
void set_data(const std::vector< FT > &x, const std::vector< FT > &y, bool cubic_spline=true)
Definition: spline_interpolation.h:195
FT operator()(FT x) const
Evaluates the spline at x.
Definition: spline_interpolation.h:294
FT derivative(int order, FT x) const
Returns the order-th derivative of the spline at x.
Definition: spline_interpolation.h:317
void set_boundary(BoundaryType left, FT left_value, BoundaryType right, FT right_value, bool linear_extrapolation=false)
Definition: spline_interpolation.h:183
Definition: collider.cpp:182
std::vector< FT > sum(const Matrix< FT > &)
Definition: matrix.h:1454