27#ifndef EASY3D_CORE_SPLINE_INTERPOLATION_H
28#define EASY3D_CORE_SPLINE_INTERPOLATION_H
35#include <easy3d/util/logging.h>
98 left_value_(0), right_value_(0),
99 linear_extrapolation_(false) {
106 bool linear_extrapolation =
false);
111 void set_data(
const std::vector<FT> &x,
const std::vector<FT> &y,
bool cubic_spline =
true);
120 std::vector<FT> x_, y_;
123 std::vector<FT> a_, b_, c_;
126 FT left_value_, right_value_;
127 bool linear_extrapolation_;
139 template<
typename FT>
143 BandMatrix() =
default;
145 BandMatrix(
int dim,
int n_u,
int n_l);
147 ~BandMatrix() =
default;
149 void resize(
int dim,
int n_u,
int n_l);
152 int num_upper()
const {
return m_upper.size() - 1; }
153 int num_lower()
const {
return m_lower.size() - 1; }
156 FT &operator()(
int i,
int j);
158 const FT &operator()(
int i,
int j)
const;
161 FT &saved_diag(
int i);
162 const FT &saved_diag(
int i)
const;
168 std::vector<FT> r_solve(
const std::vector<FT> &b)
const;
170 std::vector<FT> l_solve(
const std::vector<FT> &b)
const;
171 std::vector<FT> lu_solve(
const std::vector<FT> &b,
bool is_lu_decomposed =
false);
174 std::vector<std::vector<FT>> m_upper;
175 std::vector<std::vector<FT>> m_lower;
184 template<
typename FT>
187 bool linear_extrapolation) {
188 assert(x_.size() == 0);
191 left_value_ = left_value;
192 right_value_ = right_value;
193 linear_extrapolation_ = linear_extrapolation;
196 template<
typename FT>
198 if (x.size() != y.size()) {
199 LOG(ERROR) <<
"sizes of x (" << x.size() <<
") and y (" << y.size() <<
") do not match";
203 LOG(ERROR) <<
"too few data (size of x: " << x.size() <<
")";
208 const int n =
static_cast<int>(x.size());
210 for (std::size_t i = 0; i < n - 1; i++) {
211 if (x_[i] >= x_[i + 1]) {
212 LOG_N_TIMES(3, ERROR) <<
"x has to be monotonously increasing (x[" << i <<
"]=" << x_[i] <<
", x[" << i + 1 <<
"]=" << x_[i + 1] <<
"). " << COUNTER;
220 BandMatrix<FT> A(n, 1, 1);
221 std::vector<FT> rhs(n);
222 for (
int i = 1; i < n - 1; i++) {
223 A(i, i - 1) = FT(1.0 / 3.0) * (x[i] - x[i - 1]);
224 A(i, i) = FT(2.0 / 3.0) * (x[i + 1] - x[i - 1]);
225 A(i, i + 1) = FT(1.0 / 3.0) * (x[i + 1] - x[i]);
226 rhs[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - (y[i] - y[i - 1]) / (x[i] - x[i - 1]);
233 rhs[0] = left_value_;
237 A(0, 0) = FT(2.0) * (x[1] - x[0]);
238 A(0, 1) = FT(1.0) * (x[1] - x[0]);
239 rhs[0] = FT(3.0) * ((y[1] - y[0]) / (x[1] - x[0]) - left_value_);
245 A(n - 1, n - 1) = FT(2.0);
246 A(n - 1, n - 2) = FT(0.0);
247 rhs[n - 1] = right_value_;
252 A(n - 1, n - 1) = FT(2.0) * (x[n - 1] - x[n - 2]);
253 A(n - 1, n - 2) = FT(1.0) * (x[n - 1] - x[n - 2]);
254 rhs[n - 1] = FT(3.0) * (right_value_ - (y[n - 1] - y[n - 2]) / (x[n - 1] - x[n - 2]));
260 b_ = A.lu_solve(rhs);
265 for (std::size_t i = 0; i < n - 1; i++) {
266 a_[i] = FT(1.0 / 3.0) * (b_[i + 1] - b_[i]) / (x[i + 1] - x[i]);
267 c_[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i])
268 - FT(1.0 / 3.0) * (FT(2.0) * b_[i] + b_[i + 1]) * (x[i + 1] - x[i]);
274 for (std::size_t i = 0; i < n - 1; i++) {
277 c_[i] = (y_[i + 1] - y_[i]) / (x_[i + 1] - x_[i]);
282 b0_ = linear_extrapolation_ ? FT(0.0) : b_[0];
287 FT h = x[n - 1] - x[n - 2];
290 c_[n - 1] = FT(3.0) * a_[n - 2] * h * h + FT(2.0) * b_[n - 2] * h + c_[n - 2];
291 if (linear_extrapolation_)
295 template<
typename FT>
297 size_t n = x_.size();
299 typename std::vector<FT>::const_iterator it;
300 it = std::lower_bound(x_.begin(), x_.end(), x);
301 int idx = std::max(
int(it - x_.begin()) - 1, 0);
307 interpol = (b0_ * h + c0_) * h + y_[0];
308 }
else if (x > x_[n - 1]) {
310 interpol = (b_[n - 1] * h + c_[n - 1]) * h + y_[n - 1];
313 interpol = ((a_[idx] * h + b_[idx]) * h + c_[idx]) * h + y_[idx];
318 template<
typename FT>
322 size_t n = x_.size();
324 typename std::vector<FT>::const_iterator it;
325 it = std::lower_bound(x_.begin(), x_.end(), x);
326 int idx = std::max(
int(it - x_.begin()) - 1, 0);
334 interpol = FT(2.0) * b0_ * h + c0_;
337 interpol = FT(2.0) * b0_ * h;
343 }
else if (x > x_[n - 1]) {
347 interpol = FT(2.0) * b_[n - 1] * h + c_[n - 1];
350 interpol = FT(2.0) * b_[n - 1];
360 interpol = (FT(3.0) * a_[idx] * h + FT(2.0) * b_[idx]) * h + c_[idx];
363 interpol = FT(6.0) * a_[idx] * h + FT(2.0) * b_[idx];
366 interpol = FT(6.0) * a_[idx];
381 template<
typename FT>
382 BandMatrix<FT>::BandMatrix(
int dim,
int n_u,
int n_l) {
383 resize(dim, n_u, n_l);
386 template<
typename FT>
387 void BandMatrix<FT>::resize(
int dim,
int n_u,
int n_l) {
391 m_upper.resize(n_u + 1);
392 m_lower.resize(n_l + 1);
393 for (
size_t i = 0; i < m_upper.size(); i++) {
394 m_upper[i].resize(dim);
396 for (
size_t i = 0; i < m_lower.size(); i++) {
397 m_lower[i].resize(dim);
401 template<
typename FT>
402 int BandMatrix<FT>::dim()
const {
403 if (m_upper.size() > 0) {
404 return m_upper[0].size();
413 template<
typename FT>
414 FT &BandMatrix<FT>::operator()(
int i,
int j) {
416 assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
417 assert((-num_lower() <= k) && (k <= num_upper()));
419 if (k >= 0)
return m_upper[k][i];
420 else return m_lower[-k][i];
423 template<
typename FT>
424 const FT &BandMatrix<FT>::operator()(
int i,
int j)
const {
426 assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
427 assert((-num_lower() <= k) && (k <= num_upper()));
429 if (k >= 0)
return m_upper[k][i];
430 else return m_lower[-k][i];
434 template<
typename FT>
435 const FT &BandMatrix<FT>::saved_diag(
int i)
const {
436 assert((i >= 0) && (i < dim()));
437 return m_lower[0][i];
440 template<
typename FT>
441 FT &BandMatrix<FT>::saved_diag(
int i) {
442 assert((i >= 0) && (i < dim()));
443 return m_lower[0][i];
447 template<
typename FT>
448 void BandMatrix<FT>::lu_decompose() {
455 for (
int i = 0; i < this->dim(); i++) {
456 assert(this->
operator()(i, i) != 0.0);
457 this->saved_diag(i) = FT(1.0) / this->operator()(i, i);
458 j_min = std::max(0, i - this->num_lower());
459 j_max = std::min(this->dim() - 1, i + this->num_upper());
460 for (
int j = j_min; j <= j_max; j++) {
461 this->operator()(i, j) *= this->saved_diag(i);
463 this->operator()(i, i) = FT(1.0);
467 for (
int k = 0; k < this->dim(); k++) {
468 i_max = std::min(this->dim() - 1, k + this->num_lower());
469 for (
int i = k + 1; i <= i_max; i++) {
470 assert(this->
operator()(k, k) != FT(0.0));
471 x = -this->operator()(i, k) / this->operator()(k, k);
472 this->operator()(i, k) = -x;
473 j_max = std::min(this->dim() - 1, k + this->num_upper());
474 for (
int j = k + 1; j <= j_max; j++) {
476 this->operator()(i, j) = this->operator()(i, j) + x * this->operator()(k, j);
483 template<
typename FT>
484 std::vector<FT> BandMatrix<FT>::l_solve(
const std::vector<FT> &b)
const {
485 assert(this->dim() == (
int) b.size());
486 std::vector<FT> x(this->dim());
489 for (
int i = 0; i < this->dim(); i++) {
491 j_start = std::max(0, i - this->num_lower());
492 for (
int j = j_start; j < i; j++)
sum += this->
operator()(i, j) * x[j];
493 x[i] = (b[i] * this->saved_diag(i)) -
sum;
499 template<
typename FT>
500 std::vector<FT> BandMatrix<FT>::r_solve(
const std::vector<FT> &b)
const {
501 assert(this->dim() == (
int) b.size());
502 std::vector<FT> x(this->dim());
505 for (
int i = this->dim() - 1; i >= 0; i--) {
507 j_stop = std::min(this->dim() - 1, i + this->num_upper());
508 for (
int j = i + 1; j <= j_stop; j++)
sum += this->
operator()(i, j) * x[j];
509 x[i] = (b[i] -
sum) / this->
operator()(i, i);
514 template<
typename FT>
515 std::vector<FT> BandMatrix<FT>::lu_solve(
const std::vector<FT> &b,
bool is_lu_decomposed) {
516 assert(this->dim() == (
int) b.size());
517 std::vector<FT> x, y;
518 if (!is_lu_decomposed) {
519 this->lu_decompose();
521 y = this->l_solve(b);
522 x = this->r_solve(y);
SplineInterpolation()
Definition spline_interpolation.h:97
void set_data(const std::vector< FT > &x, const std::vector< FT > &y, bool cubic_spline=true)
Definition spline_interpolation.h:197
FT operator()(FT x) const
Evaluates the spline at x.
Definition spline_interpolation.h:296
FT derivative(int order, FT x) const
Returns the order -th derivative of the spline at x.
Definition spline_interpolation.h:319
void set_boundary(BoundaryType left, FT left_value, BoundaryType right, FT right_value, bool linear_extrapolation=false)
Definition spline_interpolation.h:185
BoundaryType
Boundary condition type.
Definition spline_interpolation.h:89
@ first_deriv
first derivative
Definition spline_interpolation.h:90
@ second_deriv
second derivative
Definition spline_interpolation.h:91
Definition collider.cpp:182
std::vector< FT > sum(const Matrix< FT > &)
Definition matrix.h:1485