最小二乘法多项式拟合的Java实现
背景
由于项目中需要根据已有数据学习出一个一元一次函数 y=ax+b:给定 x、y 的若干样本数据,通过梯度下降或最小二乘法做多项式拟合得到 a、b。解决该问题时,首先想到的是通过 Spark MLlib 去学习,可是结果并不理想:相关文档较少,参数也很难调整。于是转变了解决问题的方式,采用最小二乘法做多项式拟合。
最小二乘法多项式拟合描述下: (以下参考:https://blog.csdn.net/funnyrand/article/details/46742561)
假设给定的数据点和其对应的函数值为 $(x_1, y_1), (x_2, y_2), \dots, (x_m, y_m)$,需要做的就是得到一个多项式函数 $f(x) = a_0 + a_1 x + a_2 x^2 + \dots + a_n x^n$,使其对所有给定 $x_i$ 计算出的 $f(x_i)$ 与实际对应的 $y_i$ 之差的平方和 $\sum_{i=1}^{m} \left(f(x_i) - y_i\right)^2$ 最小,也就是计算多项式的各项系数 $a_0, a_1, \dots, a_n$。
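根据最小二乘法的原理,令上述平方和对每个系数 $a_j$ 的偏导数为零,该问题可转换为求以下线性方程组(正规方程组)的解:$G a = B$,其中 $a = (a_0, a_1, \dots, a_n)^T$,
$$G_{jk} = \sum_{i=1}^{m} x_i^{\,j+k}, \qquad B_j = \sum_{i=1}^{m} x_i^{\,j} y_i, \qquad j, k = 0, 1, \dots, n.$$
(若每个样本带有权重 $w_i$,则两个求和式中的每一项都再乘以 $w_i$。)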
所以从编程的角度来说需要做两件事情:
1)确定线性方程组的各个系数:
确定系数比较简单,对给定的 (x1, y1), (x2, y2), ... (xm, ym) 做相应的计算即可,相关代码:
private void compute() {
...
}
2)解线性方程组:
解线性方程组稍微复杂,这里用到了高斯消元法,基本思想是通过递归做矩阵转换,逐渐减少求解的多项式系数的个数,相关代码:
private double[] calcLinearEquation(double[][] a, double[] b) {
...
}
Java代码
/**
 * Weighted least-squares polynomial fitting.
 *
 * <p>Fits {@code f(x) = c[0] + c[1] * x + ... + c[n-1] * x^(n-1)} to the sample points by solving
 * the normal equations {@code G c = B}, where {@code G[j][k] = sum_i w_i * x_i^(j+k)} and
 * {@code B[j] = sum_i w_i * x_i^j * y_i}. The fit is computed once, in the constructor.
 */
public class JavaLeastSquare {
    /** Newton's method stops once two successive estimates differ by at most this much. */
    private static final double EPS = 0.0000001d;

    /** Upper bound on Newton iterations so that a non-converging start point cannot hang. */
    private static final int MAX_NEWTON_ITERATIONS = 1000;

    private final double[] x;
    private final double[] y;
    private final double[] weight;
    /** Number of coefficients to fit, i.e. polynomial degree + 1 (2 fits a straight line). */
    private final int n;
    /** Fitted coefficients in ascending power order, or {@code null} if no fit was possible. */
    private double[] coefficient;

    /**
     * Fits a polynomial with every sample weighted equally.
     *
     * @param x Array of x
     * @param y Array of y, same length as {@code x}
     * @param n Number of coefficients to fit (polynomial degree + 1); must be at least 2
     * @throws IllegalArgumentException if an array is null, has fewer than 2 elements, the
     *     lengths differ, or {@code n < 2}
     */
    public JavaLeastSquare(double[] x, double[] y, int n) {
        this(x, y, unitWeights(x), n);
    }

    /**
     * Fits a polynomial using per-sample weights.
     *
     * @param x Array of x
     * @param y Array of y, same length as {@code x}
     * @param weight Array of weights, same length as {@code x}
     * @param n Number of coefficients to fit (polynomial degree + 1); must be at least 2
     * @throws IllegalArgumentException if an array is null, has fewer than 2 elements, the
     *     lengths differ, or {@code n < 2}
     */
    public JavaLeastSquare(double[] x, double[] y, double[] weight, int n) {
        if (x == null || y == null || weight == null || x.length < 2
                || x.length != y.length || x.length != weight.length || n < 2) {
            throw new IllegalArgumentException(
                    "IllegalArgumentException occurred.");
        }
        this.x = x;
        this.y = y;
        this.n = n;
        this.weight = weight;
        compute();
    }

    /*
     * Builds an all-ones weight array matching x, or returns null when x is null so that the
     * delegated constructor reports the problem as an IllegalArgumentException.
     */
    private static double[] unitWeights(double[] x) {
        if (x == null) {
            return null;
        }
        double[] ones = new double[x.length];
        java.util.Arrays.fill(ones, 1.0d);
        return ones;
    }

    /**
     * Returns the fitted coefficients in ascending power order ({@code c[0]} is the constant
     * term).
     *
     * @return a copy of the coefficients, or {@code null} if no fit could be computed
     */
    public double[] getCoefficient() {
        return coefficient == null ? null : coefficient.clone();
    }

    /**
     * Evaluates the fitted polynomial at {@code x} (Horner's scheme).
     *
     * @param x x
     * @return f(x), or 0 if no fit could be computed
     */
    public double fit(double x) {
        if (coefficient == null) {
            return 0;
        }
        double sum = 0;
        for (int i = coefficient.length - 1; i >= 0; i--) {
            sum = sum * x + coefficient[i];
        }
        return sum;
    }

    /**
     * Uses Newton's method, starting at x = 1, to find an x with f(x) = y.
     *
     * @param y y
     * @return x, {@code NaN} if Newton's method fails, or 0 if no fit could be computed
     */
    public double solve(double y) {
        return solve(y, 1.0d);
    }

    /**
     * Uses Newton's method to find an x with f(x) = y.
     *
     * @param y y
     * @param startX The start point of x
     * @return x; {@code NaN} if the derivative vanishes or the iteration does not converge
     *     within {@value #MAX_NEWTON_ITERATIONS} steps; 0 if no fit could be computed
     */
    public double solve(double y, double startX) {
        if (coefficient == null) {
            return 0;
        }
        double current = startX;
        for (int iteration = 0; iteration < MAX_NEWTON_ITERATIONS; iteration++) {
            double residual = fit(current) - y;
            if (residual == 0.0d) {
                return current; // already an exact root
            }
            double slope = calcDerivative(current);
            if (slope == 0.0d) {
                return Double.NaN; // flat tangent: the Newton step is undefined
            }
            double next = current - residual / slope;
            if (Double.isNaN(next) || Math.abs(current - next) <= EPS) {
                return next;
            }
            current = next;
        }
        return Double.NaN; // did not converge (e.g. y is outside the polynomial's range)
    }

    /*
     * Evaluates the derivative f'(x) = sum_{i>=1} i * c[i] * x^(i-1) (Horner's scheme).
     */
    private double calcDerivative(double x) {
        if (coefficient == null) {
            return 0;
        }
        double sum = 0;
        for (int i = coefficient.length - 1; i >= 1; i--) {
            sum = sum * x + i * coefficient[i];
        }
        return sum;
    }

    /*
     * Builds the normal equations G c = B from the weighted samples and solves them.
     * Leaves coefficient null when there are fewer samples than coefficients or the
     * system is singular.
     */
    private void compute() {
        if (x == null || y == null || x.length <= 1 || x.length != y.length
                || x.length < n || n < 2) {
            return;
        }
        // s[k] = sum_j w_j * x_j^k for k = 0 .. 2(n-1);  b[k] = sum_j w_j * x_j^k * y_j.
        double[] s = new double[2 * n - 1];
        double[] b = new double[n];
        for (int j = 0; j < x.length; j++) {
            double term = weight[j]; // w_j * x_j^0, multiplied by x_j once per power below
            for (int k = 0; k < s.length; k++) {
                s[k] += term;
                if (k < n) {
                    b[k] += term * y[j];
                }
                term *= x[j];
            }
        }
        // G is a Hankel matrix: each entry depends only on i + j.
        double[][] a = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = s[i + j];
            }
        }
        coefficient = calcLinearEquation(a, b);
    }

    /*
     * Solves the square linear system A x = B by Gaussian elimination with partial pivoting.
     * The inputs are not modified.
     * @param a two-dimensional array (square, same size as b)
     * @param b one-dimensional array
     * @return x, or null if the inputs are malformed or the matrix is singular
     */
    private double[] calcLinearEquation(double[][] a, double[] b) {
        if (a == null || b == null || a.length == 0 || a.length != b.length) {
            return null;
        }
        int size = a.length;
        for (double[] row : a) {
            if (row == null || row.length != size) {
                return null;
            }
        }

        // Work on copies so the caller's matrix and vector stay untouched.
        double[][] m = new double[size][];
        for (int i = 0; i < size; i++) {
            m[i] = a[i].clone();
        }
        double[] rhs = b.clone();

        // Forward elimination. Picking the largest remaining pivot in each column keeps the
        // multipliers bounded by 1, which limits rounding-error growth.
        for (int col = 0; col < size; col++) {
            int pivotRow = col;
            for (int row = col + 1; row < size; row++) {
                if (Math.abs(m[row][col]) > Math.abs(m[pivotRow][col])) {
                    pivotRow = row;
                }
            }
            if (m[pivotRow][col] == 0.0d) {
                return null; // singular: no unique solution
            }
            if (pivotRow != col) {
                double[] rowTmp = m[col];
                m[col] = m[pivotRow];
                m[pivotRow] = rowTmp;
                double rhsTmp = rhs[col];
                rhs[col] = rhs[pivotRow];
                rhs[pivotRow] = rhsTmp;
            }
            for (int row = col + 1; row < size; row++) {
                double factor = m[row][col] / m[col][col];
                if (factor == 0.0d) {
                    continue;
                }
                for (int k = col; k < size; k++) {
                    m[row][k] -= factor * m[col][k];
                }
                rhs[row] -= factor * rhs[col];
            }
        }

        // Back substitution on the upper-triangular system.
        double[] result = new double[size];
        for (int row = size - 1; row >= 0; row--) {
            double sum = rhs[row];
            for (int k = row + 1; k < size; k++) {
                sum -= m[row][k] * result[k];
            }
            result[row] = sum / m[row][row];
        }
        return result;
    }

    /**
     * Demo: fits a straight line (2 coefficients) to noisy samples of y = 0.7 * x + 20, then
     * prints the coefficients followed by the fitted value at x = 4.
     */
    public static void main(String[] args) {
        JavaLeastSquare leastSquare = new JavaLeastSquare(
                new double[]{
                    2, 14, 20, 25, 26, 34,
                    47, 87, 165, 265, 365, 465,
                    565, 665
                },
                new double[]{
                    0.7 * 2 + 20 + 0.4,
                    0.7 * 14 + 20 + 0.5,
                    0.7 * 20 + 20 + 3.4,
                    0.7 * 25 + 20 + 5.8,
                    0.7 * 26 + 20 + 8.27,
                    0.7 * 34 + 20 + 0.4,

                    0.7 * 47 + 20 + 0.1,
                    0.7 * 87 + 20,
                    0.7 * 165 + 20,
                    0.7 * 265 + 20,
                    0.7 * 365 + 20,
                    0.7 * 465 + 20,

                    0.7 * 565 + 20,
                    0.7 * 665 + 20
                },
                2); // 2 coefficients: y = c[0] + c[1] * x

        double[] coefficients = leastSquare.getCoefficient();
        if (coefficients == null) {
            System.out.println("No fit could be computed.");
            return;
        }
        for (double c : coefficients) {
            System.out.println(c);
        }

        // Fitted value at x = 4.
        System.out.println(leastSquare.fit(4));
    }
}
输出结果:
com.datangmobile.biz.leastsquare.JavaLeastSquare
22.27966881467629
0.6952475907448203
25.06065917765557
Process finished with exit code 0
使用开源库
也可使用Apache开源库commons math(http://commons.apache.org/proper/commons-math/userguide/fitting.html),提供的功能更强大:
<!-- Apache Commons Math 3: provides PolynomialCurveFitter and WeightedObservedPoints. -->
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.5</version>
</dependency>
实现代码:
import org.apache.commons.math3.fitting.PolynomialCurveFitter;import org.apache.commons.math3.fitting.WeightedObservedPoints;
/**
 * Fits a degree-2 polynomial with Commons Math to noisy samples of y = 0.7 * x + 20 and prints
 * the fitted coefficients (constant term first).
 */
public class WeightedObservedPointsTest {
    public static void main(String[] args) {
        // Each row is {x, noise}; the observed y is 0.7 * x + 20 + noise.
        final double[][] points = {
            {2, 0.4}, {12, 0.3}, {32, 3.4}, {34, 5.8}, {58, 8.4}, {43, 0.28}, {27, 0.4}
        };

        final WeightedObservedPoints observations = new WeightedObservedPoints();
        for (final double[] point : points) {
            final double x = point[0];
            observations.add(x, 0.7 * x + 20 + point[1]);
        }

        // Degree-2 fitter; fit(...) returns the coefficients in ascending power order.
        final PolynomialCurveFitter quadraticFitter = PolynomialCurveFitter.create(2);
        final double[] fitted = quadraticFitter.fit(observations.toList());
        for (int i = 0; i < fitted.length; i++) {
            System.out.println(fitted[i]);
        }
    }
}
测试输出结果:
20.47425047847121
0.6749744063035112
0.002523043547711147
Process finished with exit code 0
使用org.ujmp(矩阵)实现最小二乘法:
pom.xml中需要引入org.ujmp
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.dtgroup</groupId>
    <artifactId>dtgroup</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <repositories>
        <!-- Aliyun public mirror. Must be HTTPS: Maven 3.8.1+ blocks plain-HTTP repositories by default. -->
        <repository>
            <id>limaven</id>
            <name>aliyun maven</name>
            <url>https://maven.aliyun.com/repository/public</url>
            <layout>default</layout>
            <releases>
                <enabled>true</enabled>
            </releases>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>
    <dependencies>
        <!-- UJMP core: provides org.ujmp.core.Matrix and DenseMatrix. -->
        <dependency>
            <groupId>org.ujmp</groupId>
            <artifactId>ujmp-core</artifactId>
            <version>0.3.0</version>
        </dependency>
    </dependencies>
</project>
java代码:
/**
 * Fits a polynomial to the samples by ordinary least squares, using the normal equations
 * theta = (X^T X)^-1 X^T Y, and prints the intermediate matrices.
 *
 * <p>Column j of the design matrix X holds x^j, so {@code fetureCout} is the number of
 * coefficients (degree + 1). With 2 the model is y = theta[0] + theta[1] * x.
 *
 * @param sampleCount number of sample points; must equal {@code samples.size()}
 * @param fetureCout number of polynomial coefficients (degree + 1)
 * @param samples the sample points
 * @throws IllegalArgumentException if {@code samples} is null or its size differs from
 *     {@code sampleCount}
 */
private static void leastsequare(int sampleCount, int fetureCout, List<Sample> samples) {
    if (samples == null || samples.size() != sampleCount) {
        throw new IllegalArgumentException(
                "sampleCount (" + sampleCount + ") must equal samples.size()");
    }
    // Design matrix X (sampleCount x fetureCout): X[i][j] = x_i^j. Column 0 keeps the
    // initial 1.0 values (the intercept term).
    Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout);
    for (int i = 0; i < sampleCount; i++) {
        double power = 1.0d;
        for (int j = 1; j < fetureCout; j++) {
            power *= samples.get(i).getX();
            matrixX.setAsDouble(power, i, j);
        }
    }
    System.out.println("--------------------------------------");
    // Observation vector Y (sampleCount x 1).
    Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);
    for (int i = 0; i < sampleCount; i++) {
        matrixY.setAsDouble(samples.get(i).getY(), i, 0);
    }
    // X^T X
    Matrix matrixXTrans = matrixX.transpose();
    Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);
    System.out.println(matrixMtimes);
    System.out.println("--------------------------------------");
    // (X^T X)^-1. NOTE(review): singular when there are fewer distinct x values than
    // coefficients; confirm how ujmp's inv() reports that.
    Matrix matrixMtimesInv = matrixMtimes.inv();
    System.out.println(matrixMtimesInv);
    System.out.println("--------------------------------------");
    // (X^T X)^-1 X^T
    Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans);
    System.out.println(matrixMtimesInvMtimes);
    System.out.println("--------------------------------------");
    // theta = (X^T X)^-1 X^T Y
    Matrix theta = matrixMtimesInvMtimes.mtimes(matrixY);
    System.out.println(theta);
}
测试代码:
public static void main(String[] args) {
    // Target model y = a * x + b with a in (0,1], b in [5,20], x in [0,500], y >= 5.
    // The data scatter around y = 0.8 * x + 15. The notes say that when several samples share
    // the same x, their y values are meant to be averaged.
    // Each row is {x used to generate y, intercept, noise, recorded x}. The last row's y is
    // generated from x = 1000 with intercept 23 but recorded at x = 70, so it lies far off the line.
    final double[][] rows = {
        {1, 15, 1, 1},
        {4, 15, 0.8, 4},
        {3, 15, 0.7, 3},
        {24, 15, 0.4, 24},
        {5, 15, 0.3, 5},
        {10, 15, 0.4, 10},
        {14, 15, 0.2, 14},
        {7, 15, 0.3, 7},
        {1000, 23, 0.3, 70},
    };
    List<Sample> samples = new ArrayList<Sample>();
    for (double[] row : rows) {
        samples.add(new Sample(0.8d * row[0] + row[1] + row[2], row[3]));
    }
    final int featureColumns = 2;
    leastsequare(samples.size(), featureColumns, samples);
}
过滤样本中的噪点:
public static void main(String[] args) {
    // Same synthetic data as the previous example: y = 0.8 * x + 15 plus noise, with the last
    // row's y generated from x = 1000 (intercept 23) but recorded at x = 70.
    // Target ranges: a in (0,1], b in [5,20], x in [0,500], y >= 5.
    // Each row is {x used to generate y, intercept, noise, recorded x}.
    final double[][] rows = {
        {1, 15, 1, 1},
        {4, 15, 0.8, 4},
        {3, 15, 0.7, 3},
        {24, 15, 0.4, 24},
        {5, 15, 0.3, 5},
        {10, 15, 0.4, 10},
        {14, 15, 0.2, 14},
        {7, 15, 0.3, 7},
        {1000, 23, 0.3, 70},
    };
    List<Sample> samples = new ArrayList<Sample>();
    for (double[] row : rows) {
        samples.add(new Sample(0.8d * row[0] + row[1] + row[2], row[3]));
    }

    // Each pass of the gradient filter makes at most one change and reports where it stopped;
    // re-sort and resume from that position until a pass reports completion.
    sortSample(samples);
    FilterSampleByGradientResult pass = filterSampleByGradient(0, samples);
    while (!pass.isComplete()) {
        List<Sample> remaining = pass.getSamples();
        sortSample(remaining);
        pass = filterSampleByGradient(pass.getIndex(), remaining);
    }

    List<Sample> kept = pass.getSamples();
    for (Sample sample : kept) {
        System.out.println(sample);
    }
    leastsequare(kept.size(), 2, kept);
}
/**
 * Sorts the samples in place by x, ascending.
 *
 * <p>Uses {@link Double#compare}, which is a consistent total order. The previous hand-written
 * comparator returned -1 for equal x values, which breaks the {@code Comparator} contract
 * (sgn(compare(a, b)) == -sgn(compare(b, a))). List.sort may then throw
 * "Comparison method violates its general contract!".
 *
 * @param samples sample list to sort; modified in place
 */
private static void sortSample(List<Sample> samples) {
    samples.sort(Comparator.comparingDouble(Sample::getX));
}
/**
 * Performs at most one noise-removal step, judged by the slope between neighbouring samples.
 *
 * <p>Scans adjacent pairs (i, i+1) starting at {@code index}:
 * <ul>
 *   <li>if their x values differ by less than 1, both are replaced by one sample at x[i] whose
 *       y is their mean. The merged sample is appended at the end, so the caller must re-sort.</li>
 *   <li>otherwise, if the slope (y[i] - y[i+1]) / (x[i] - x[i+1]) exceeds 1.5 and i > 0,
 *       sample i+1 is removed. At i == 0 a steep pair is left alone, because it is ambiguous
 *       which of the two points is the outlier, and scanning continues.</li>
 * </ul>
 *
 * @param index position to resume scanning from (the index returned by the previous call)
 * @param samples samples sorted ascending by x; modified in place
 * @return {@code complete == false} with the position of the change, or {@code complete == true}
 *     when a full scan made no change
 */
private static FilterSampleByGradientResult filterSampleByGradient(int index, List<Sample> samples) {
    int sampleSize = samples.size();
    for (int i = index; i < sampleSize - 1; i++) {
        Sample current = samples.get(i);
        Sample next = samples.get(i + 1);
        double deltaX = current.getX() - next.getX();
        double deltaY = current.getY() - next.getY();
        // x values closer than 1 unit (the old comment said "< 2 m" but the threshold is 1).
        if (Math.abs(deltaX) < 1) {
            double newY = (current.getY() + next.getY()) / 2;
            double newX = current.getX();
            // Remove the higher index first so index i still refers to the original sample i.
            samples.remove(i + 1);
            samples.remove(i);
            samples.add(new Sample(newY, newX));
            return new FilterSampleByGradientResult(false, i, samples);
        }
        double gradient = deltaY / deltaX;
        if (gradient > 1.5 && i > 0) {
            samples.remove(i + 1);
            return new FilterSampleByGradientResult(false, i, samples);
        }
        // A steep pair at i == 0 is skipped rather than reported as an unfinished step.
        // Reporting it made the caller re-invoke with the same index and unchanged list forever.
    }
    return new FilterSampleByGradientResult(true, 0, samples);
}
使用距离来处理过滤:
private static List<Sample> filterSample(List<Sample> samples) {
    // Outlier rule from the notes: for a quantity with mean u and spread s, a value outside
    // (u - k*s, u + k*s) is treated as noise. The quantity here is each sample's Euclidean
    // distance to the centroid of all samples, with k = 0.5. Another option from the notes is to
    // apply the rule to x and y separately and combine the two verdicts with AND or OR.
    // NOTE(review): the spread computed below is sqrt(sum of squared deviations), i.e. sqrt(n)
    // times the population standard deviation. The band is also two-sided, so samples unusually
    // close to the centroid are removed too. Confirm both are intended.
    final int count = samples.size();

    double totalX = 0d;
    double totalY = 0d;
    for (Sample s : samples) {
        totalX += s.getX();
        totalY += s.getY();
    }
    final double meanX = (totalX / count);
    final double meanY = (totalY / count);

    // Distance of every sample to the centroid, plus their running total.
    final double[] distances = new double[count];
    double totalDistance = 0d;
    for (int k = 0; k < count; k++) {
        Sample s = samples.get(k);
        double squared = Math.pow(s.getX() - meanX, 2) + Math.pow(s.getY() - meanY, 2);
        distances[k] = Math.sqrt(squared);
        totalDistance += distances[k];
    }
    final double meanDistance = totalDistance / count;

    double squaredDeviations = 0d;
    for (double d : distances) {
        squaredDeviations += Math.pow((d - meanDistance), 2);
    }
    final double spread = Math.sqrt(squaredDeviations);
    final double lower = meanDistance - 0.5 * spread;
    final double upper = meanDistance + 0.5 * spread;

    // Walk backwards so removing index k never shifts an index still to be visited.
    for (int k = count - 1; k >= 0; k--) {
        if (distances[k] <= lower || distances[k] >= upper) {
            System.out.println("will be remove " + k);
            samples.remove(k);
        }
    }
    return samples;
}
实际业务测试:
package com.zjanalyse.spark.maths;import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
/**
 * Fits y = a * x + b to sample data with two rounds of outlier removal.
 *
 * <p>The line parameters are reported as "wear loss" (the intercept b) and "path loss"
 * (the slope a).
 */
public class LastSquare {
    /**
     * Demo entry point.
     *
     * <p>Expected ranges (from the original notes): a in (0, 1], b in [5, 20], x in [0, 500),
     * y >= 5. Steps:
     * <ol>
     *   <li>sort the samples and drop those outside the business range;</li>
     *   <li>fit once;</li>
     *   <li>drop samples whose distance to the first fitted line is far from the mean distance;</li>
     *   <li>fit again with the remaining samples.</li>
     * </ol>
     */
    public static void main(String[] args) {
        // Most samples scatter around y = 0.8 * x + 15. The rows at x = 70, 80 and 40 use larger
        // intercepts (60, 60, 30). The notes say that when several samples share the same x,
        // their y values are meant to be averaged.
        List<Sample> samples = new ArrayList<Sample>();
        samples.add(new Sample(0.8d * 11 + 15 + 1, 11d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.8, 24d));
        samples.add(new Sample(0.8d * 33 + 15 + 0.7, 33d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
        samples.add(new Sample(0.8d * 47 + 15 + 0.3, 47d));
        samples.add(new Sample(0.8d * 60 + 15 + 0.4, 60d));
        samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
        samples.add(new Sample(0.8d * 57 + 15 + 0.3, 57d));
        samples.add(new Sample(0.8d * 70 + 60 + 0.3, 70d));
        samples.add(new Sample(0.8d * 80 + 60 + 0.3, 80d));
        samples.add(new Sample(0.8d * 40 + 30 + 0.3, 40d));
        sortSample(samples);
        System.out.println("原始样本数据");
        for (Sample sample : samples) {
            System.out.println(sample);
        }
        System.out.println("开始“所有点”通过“业务数据取值范围”剔除:");
        // Drop samples outside the business value ranges.
        filterByBusiness(samples);
        System.out.println("结束“所有点”通过“业务数据取值范围”剔除:");
        for (Sample sample : samples) {
            System.out.println(sample);
        }
        int fetureCout = 2;
        System.out.println("第一次拟合。。。");
        // With fewer samples than coefficients, X^T X is singular and there is nothing to fit.
        if (samples.size() < fetureCout) {
            System.out.println("complete...");
            return;
        }
        Matrix theta = leastsequare(samples.size(), fetureCout, samples);
        // Row 0 pairs with the column of ones (intercept), row 1 with the column of x (slope).
        double wearLoss = theta.getAsDouble(0, 0);
        double pathLoss = theta.getAsDouble(1, 0);
        System.out.println("wear loss " + wearLoss);
        System.out.println("path loss " + pathLoss);
        System.out.println("开始“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:");
        samples = filterSample(wearLoss, pathLoss, samples);
        System.out.println("结束“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:");
        for (Sample sample : samples) {
            System.out.println(sample);
        }
        System.out.println("第二次拟合。。。");
        if (samples.size() >= fetureCout) {
            theta = leastsequare(samples.size(), fetureCout, samples);
            wearLoss = theta.getAsDouble(0, 0);
            pathLoss = theta.getAsDouble(1, 0);
            System.out.println("wear loss " + wearLoss);
            System.out.println("path loss " + pathLoss);
        }
        System.out.println("complete...");
    }

    /**
     * Removes, in place, samples outside the business value ranges: x must lie in [0, 500) and
     * y in [0 * x + 5, 1 * x + 30]. Each removed sample is reported on stdout.
     *
     * @param samples samples to filter; modified in place
     */
    private static void filterByBusiness(List<Sample> samples) {
        int i = 0;
        while (i < samples.size()) {
            double x = samples.get(i).getX();
            double y = samples.get(i).getY();
            if (x < 0 || x >= 500) {
                System.out.println(x + " x值超出有效值范围[0,500)");
                samples.remove(i);
            } else if (y < 0 * x + 5 || y > 1 * x + 30) {
                // Reference line: y = 0.8 * x + 15.
                System.out.println(
                        y + " y值超出有效值范围[(0*x+5),(1*x+30)]其中x=" + x + ",也就是:[" + (0 * x + 5) + "," + (1 * x + 30) + "]");
                samples.remove(i);
            } else {
                i++;
            }
        }
    }

    /**
     * Returns the perpendicular distance from the point (x1, y1) to the line
     * A * x + B * y + C = 0, i.e. |A * x1 + B * y1 + C| / sqrt(A^2 + B^2).
     *
     * <p>Example: the distance from (0, 1) to the line y = x (A = -1, B = 1, C = 0) is
     * {@code getDistanceOfPerpendicular(0, 1, -1, 1, 0)}, about 0.7071.
     *
     * @param x1 x coordinate of the point
     * @param y1 y coordinate of the point
     * @param A coefficient A of the line's general form
     * @param B coefficient B of the line's general form
     * @param C constant term C of the line's general form
     * @return the distance; not finite when A and B are both 0
     */
    private static double getDistanceOfPerpendicular(double x1, double y1, double A, double B, double C) {
        return Math.abs((A * x1 + B * y1 + C) / Math.sqrt(A * A + B * B));
    }

    /**
     * Removes samples whose distance to the fitted line y = path_loss * x + wear_loss is far
     * from the mean distance, and returns the same (mutated) list.
     *
     * <p>With u the mean distance and theta = sqrt(sum of (d_i - u)^2), a sample is removed when
     * its distance d_i lies outside the open band (u - 0.25 * theta, u + 0.25 * theta).
     * NOTE(review): theta is sqrt(n) times the population standard deviation, not the standard
     * deviation itself. The band is also two-sided, so samples unusually close to the line are
     * removed too. Confirm both are intended.
     *
     * @param wear_loss intercept of the fitted line
     * @param path_loss slope of the fitted line
     * @param samples samples to filter; modified in place
     * @return {@code samples}
     */
    private static List<Sample> filterSample(double wear_loss, double path_loss, List<Sample> samples) {
        int sampleCount = samples.size();
        // Perpendicular distance of every sample to the line, written in general form as
        // path_loss * x - y + wear_loss = 0. The distance is used as-is: unlike the
        // centroid-based filter, there is no squared quantity here to take a square root of.
        List<Double> distanItems = new ArrayList<Double>();
        for (Sample sample : samples) {
            distanItems.add(getDistanceOfPerpendicular(sample.getX(), sample.getY(), path_loss, -1, wear_loss));
        }
        double sumDistan = 0d;
        for (Double distance : distanItems) {
            sumDistan += distance;
        }
        double distanceU = sumDistan / sampleCount;
        double deltaPowSum = 0d;
        for (Double distance : distanItems) {
            deltaPowSum += Math.pow((distance - distanceU), 2);
        }
        double distanceTheta = Math.sqrt(deltaPowSum);
        double minDistance = distanceU - 0.25 * distanceTheta;
        double maxDistance = distanceU + 0.25 * distanceTheta;
        // Collect indices from the end so the removals below never shift a pending index.
        List<Integer> willbeRemoveIdxs = new ArrayList<Integer>();
        for (int i = distanItems.size() - 1; i >= 0; i--) {
            Double distance = distanItems.get(i);
            if (distance <= minDistance || distance >= maxDistance) {
                System.out.println(distance + " out of range [" + minDistance + "," + maxDistance + "]");
                willbeRemoveIdxs.add(i);
            } else {
                System.out.println(distance);
            }
        }
        for (int willbeRemoveIdx : willbeRemoveIdxs) {
            Sample sample = samples.get(willbeRemoveIdx);
            System.out.println("remove " + sample);
            samples.remove(willbeRemoveIdx);
        }
        return samples;
    }

    /**
     * Sorts the samples in place by x, ascending. Uses a consistent total order; the previous
     * comparator returned -1 for equal x values, which violates the {@code Comparator} contract.
     *
     * @param samples sample list to sort; modified in place
     */
    private static void sortSample(List<Sample> samples) {
        samples.sort(Comparator.comparingDouble(Sample::getX));
    }

    /**
     * Fits a polynomial to the samples by ordinary least squares, using the normal equations
     * theta = (X^T X)^-1 X^T Y.
     *
     * <p>Column j of the design matrix X holds x^j, so {@code fetureCout} is the number of
     * coefficients (degree + 1). With 2 the result is [intercept; slope].
     *
     * @param sampleCount number of sample points; must equal {@code samples.size()}
     * @param fetureCout number of polynomial coefficients (degree + 1)
     * @param samples the sample points
     * @return a {@code fetureCout x 1} column of coefficients in ascending power order
     * @throws IllegalArgumentException if {@code samples} is null or its size differs from
     *     {@code sampleCount}
     */
    private static Matrix leastsequare(int sampleCount, int fetureCout, List<Sample> samples) {
        if (samples == null || samples.size() != sampleCount) {
            throw new IllegalArgumentException(
                    "sampleCount (" + sampleCount + ") must equal samples.size()");
        }
        // Design matrix X (sampleCount x fetureCout): X[i][j] = x_i^j. Column 0 keeps the
        // initial 1.0 values. Observation vector Y (sampleCount x 1).
        Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout);
        Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);
        for (int i = 0; i < sampleCount; i++) {
            Sample sample = samples.get(i);
            double power = 1.0d;
            for (int j = 1; j < fetureCout; j++) {
                power *= sample.getX();
                matrixX.setAsDouble(power, i, j);
            }
            matrixY.setAsDouble(sample.getY(), i, 0);
        }
        // theta = (X^T X)^-1 X^T Y
        Matrix matrixXTrans = matrixX.transpose();
        Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);
        Matrix matrixMtimesInv = matrixMtimes.inv();
        Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans);
        return matrixMtimesInvMtimes.mtimes(matrixY);
    }
}
View Code
以上是 最小二乘法多项式拟合的Java实现 的全部内容, 来源链接: utcz.com/z/390259.html