最小二乘法多项式拟合的Java实现

java

背景

由于项目中需要根据一些已有数据学习出一个 y=ax+b 形式的一元一次函数(线性模型):给定 x、y 的一些样本数据,通过梯度下降或最小二乘法做多项式拟合得到 a、b。解决该问题时,首先想到的是通过 spark mllib 去学习,可是结果并不理想:文档很少,参数也很难调整。于是转变了解决问题的方式:采用最小二乘法做多项式拟合。

最小二乘法多项式拟合描述下: (以下参考:https://blog.csdn.net/funnyrand/article/details/46742561)

假设给定的数据点和其对应的函数值为 (x1, y1), (x2, y2), ... (xm, ym),需要做的就是得到一个多项式函数 f(x) = a0 + a1 * x + a2 * pow(x, 2) + ... + an * pow(x, n),使其对所有给定 x 所计算出的 f(x) 与实际对应的 y 值的差的平方和最小,也就是计算多项式的各项系数 a0, a1, ... an。

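根据最小二乘法的原理,该问题可转换为求以下线性方程组(正规方程组)的解:Ga = B,其中 G[i][j] = Σk wk * pow(xk, i + j),B[i] = Σk wk * pow(xk, i) * yk(k = 1..m;wk 为第 k 个样本的权重,不加权时均为 1;i、j 为系数下标)。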

所以从编程的角度来说需要做两件事情:

1)确定线性方程组的各个系数:

确定系数比较简单,对给定的 (x1, y1), (x2, y2), ... (xm, ym) 做相应的计算即可,相关代码:

private void compute() {
  ...
}

2)解线性方程组:

解线性方程组稍微复杂,这里用到了高斯消元法,基本思想是通过递归做矩阵转换,逐渐减少求解的多项式系数的个数,相关代码:

private double[] calcLinearEquation(double[][] a, double[] b) {
  ...
}

Java代码

  /**
   * Fits the polynomial {@code f(x) = c[0] + c[1]*x + ... + c[n-1]*x^(n-1)} to sample points by
   * (optionally weighted) least squares.
   *
   * <p>The coefficients solve the normal equations {@code G c = B} with
   * {@code G[i][j] = sum_k w_k * x_k^(i+j)} and {@code B[i] = sum_k w_k * x_k^i * y_k}; the system
   * is solved by recursive Gaussian elimination.
   *
   * <p>When no unique fit exists (fewer samples than coefficients, or a singular system)
   * {@link #getCoefficient()} returns {@code null} and {@link #fit(double)} /
   * {@link #solve(double)} return {@code 0}.
   */
  public class JavaLeastSquare {

      /** Newton iteration stops once two successive estimates differ by at most this much. */
      private static final double EPS = 0.0000001d;

      /** Upper bound on Newton iterations so that a non-converging search cannot hang. */
      private static final int MAX_NEWTON_ITERATIONS = 1000;

      private final double[] x;
      private final double[] y;
      private final double[] weight;
      /** Number of polynomial coefficients, i.e. polynomial degree + 1. */
      private final int n;
      /** Fitted coefficients, constant term first; {@code null} when no fit could be computed. */
      private double[] coefficient;

      /**
       * Fits an unweighted polynomial (every sample has weight 1).
       *
       * @param x sample x values, at least 2
       * @param y sample y values, same length as {@code x}
       * @param n number of coefficients to fit (polynomial degree + 1), at least 2;
       *          {@code n = 2} fits a straight line
       * @throws IllegalArgumentException if an array is {@code null}, the lengths differ,
       *         fewer than 2 samples are given, or {@code n < 2}
       */
      public JavaLeastSquare(double[] x, double[] y, int n) {
          this(x, y, unitWeights(x), n);
      }

      /**
       * Fits a weighted polynomial.
       *
       * @param x      sample x values, at least 2
       * @param y      sample y values, same length as {@code x}
       * @param weight per-sample weights, same length as {@code x}
       * @param n      number of coefficients to fit (polynomial degree + 1), at least 2
       * @throws IllegalArgumentException if an array is {@code null}, the lengths differ,
       *         fewer than 2 samples are given, or {@code n < 2}
       */
      public JavaLeastSquare(double[] x, double[] y, double[] weight, int n) {
          if (x == null || y == null || weight == null) {
              throw new IllegalArgumentException("x, y and weight must not be null");
          }
          if (x.length < 2 || x.length != y.length || x.length != weight.length) {
              throw new IllegalArgumentException(
                      "x, y and weight need the same length (>= 2) but were x=" + x.length
                              + ", y=" + y.length + ", weight=" + weight.length);
          }
          if (n < 2) {
              throw new IllegalArgumentException("n must be at least 2 but was " + n);
          }
          // Defensive copies: later changes to the caller's arrays must not affect this fit.
          this.x = x.clone();
          this.y = y.clone();
          this.weight = weight.clone();
          this.n = n;
          compute();
      }

      /** Builds an all-ones weight array matching {@code x}; {@code null} in gives {@code null} out. */
      private static double[] unitWeights(double[] x) {
          if (x == null) {
              return null; // rejected by the main constructor's validation
          }
          double[] ones = new double[x.length];
          for (int i = 0; i < ones.length; i++) {
              ones[i] = 1;
          }
          return ones;
      }

      /**
       * Returns the fitted coefficients, constant term first.
       *
       * @return a copy of the coefficients, or {@code null} if no fit could be computed
       */
      public double[] getCoefficient() {
          return coefficient == null ? null : coefficient.clone();
      }

      /**
       * Evaluates the fitted polynomial.
       *
       * @param x point at which to evaluate
       * @return {@code f(x)}, or {@code 0} if no fit is available
       */
      public double fit(double x) {
          if (coefficient == null) {
              return 0;
          }
          // Horner's scheme: one multiply-add per coefficient, no Math.pow calls.
          double value = 0;
          for (int i = coefficient.length - 1; i >= 0; i--) {
              value = value * x + coefficient[i];
          }
          return value;
      }

      /**
       * Solves {@code f(x) = y} for x with Newton's method, starting from {@code x = 1}.
       *
       * @param y target value
       * @return see {@link #solve(double, double)}
       */
      public double solve(double y) {
          return solve(y, 1.0d);
      }

      /**
       * Solves {@code f(x) = y} for x with Newton's method.
       *
       * @param y      target value
       * @param startX initial guess for x
       * @return the x found; {@code 0} if no fit is available; {@code NaN} if the iteration hits
       *         a zero derivative or has not converged after {@value #MAX_NEWTON_ITERATIONS} steps
       */
      public double solve(double y, double startX) {
          if (coefficient == null) {
              return 0;
          }
          double current = startX;
          for (int iteration = 0; iteration < MAX_NEWTON_ITERATIONS; iteration++) {
              double next = current - (fit(current) - y) / calcDerivative(current);
              // Written as a negated '>' so that a NaN estimate also terminates the search.
              if (!(Math.abs(current - next) > EPS)) {
                  return next;
              }
              current = next;
          }
          return Double.NaN;
      }

      /**
       * Evaluates the first derivative {@code f'(x)} of the fitted polynomial.
       *
       * @param x point at which to evaluate
       * @return {@code f'(x)}, or {@code 0} if no fit is available
       */
      private double calcDerivative(double x) {
          if (coefficient == null) {
              return 0;
          }
          double value = 0;
          for (int i = coefficient.length - 1; i >= 1; i--) {
              value = value * x + i * coefficient[i];
          }
          return value;
      }

      /**
       * Builds the normal equations from the samples and stores their solution in
       * {@link #coefficient}. Leaves it {@code null} when there are fewer samples than coefficients.
       */
      private void compute() {
          if (x.length < n) {
              return;
          }
          // s[k] = sum_j w_j * x_j^k for k = 0 .. 2(n-1); these are the entries of G.
          double[] s = new double[(n - 1) * 2 + 1];
          for (int i = 0; i < s.length; i++) {
              for (int j = 0; j < x.length; j++) {
                  s[i] += Math.pow(x[j], i) * weight[j];
              }
          }
          // b[i] = sum_j w_j * x_j^i * y_j
          double[] b = new double[n];
          for (int i = 0; i < b.length; i++) {
              for (int j = 0; j < x.length; j++) {
                  b[i] += Math.pow(x[j], i) * y[j] * weight[j];
              }
          }
          double[][] a = new double[n][n];
          for (int i = 0; i < n; i++) {
              for (int j = 0; j < n; j++) {
                  a[i][j] = s[i + j];
              }
          }
          coefficient = calcLinearEquation(a, b);
      }

      /**
       * Solves the square linear system {@code a * result = b} by recursive Gaussian elimination.
       *
       * <p>Each level picks the entry of largest magnitude as pivot, eliminates the pivot's
       * unknown from all other rows, solves the remaining smaller system recursively and then
       * back-substitutes.
       *
       * @param a square coefficient matrix
       * @param b right-hand side, same length as {@code a}
       * @return the solution vector, or {@code null} if the input is malformed or the system is
       *         singular
       */
      private double[] calcLinearEquation(double[][] a, double[] b) {
          if (a == null || b == null || a.length == 0 || a.length != b.length) {
              return null;
          }
          for (double[] row : a) {
              if (row == null || row.length != a.length) {
                  return null;
              }
          }

          int size = a.length;
          int pivotRow = -1;
          int pivotCol = -1;
          double largest = 0.0d;
          for (int i = 0; i < size; i++) {
              for (int j = 0; j < size; j++) {
                  double magnitude = Math.abs(a[i][j]);
                  if (magnitude > largest) {
                      largest = magnitude;
                      pivotRow = i;
                      pivotCol = j;
                  }
              }
          }
          if (pivotRow == -1) {
              return null; // every entry is zero: the system is singular
          }

          double pivot = a[pivotRow][pivotCol];
          double[] result = new double[size];
          if (size == 1) {
              result[0] = b[0] / pivot;
              return result;
          }

          // Eliminate the pivot unknown from every other row, dropping the pivot row and column.
          double[][] reducedA = new double[size - 1][size - 1];
          double[] reducedB = new double[size - 1];
          int targetRow = 0;
          for (int i = 0; i < size; i++) {
              if (i == pivotRow) {
                  continue;
              }
              // Scaling by the pivot keeps magnitudes bounded across recursion levels.
              double factor = a[i][pivotCol] / pivot;
              reducedB[targetRow] = b[i] - factor * b[pivotRow];
              int targetCol = 0;
              for (int j = 0; j < size; j++) {
                  if (j == pivotCol) {
                      continue;
                  }
                  reducedA[targetRow][targetCol] = a[i][j] - factor * a[pivotRow][j];
                  targetCol++;
              }
              targetRow++;
          }

          double[] partial = calcLinearEquation(reducedA, reducedB);
          if (partial == null) {
              return null; // the reduced system is singular, so this one is too
          }

          // Back-substitute into the pivot row to recover the eliminated unknown.
          double remainder = b[pivotRow];
          int k = 0;
          for (int j = 0; j < size; j++) {
              if (j == pivotCol) {
                  continue;
              }
              remainder -= a[pivotRow][j] * partial[k];
              result[j] = partial[k];
              k++;
          }
          result[pivotCol] = remainder / pivot;
          return result;
      }

      /** Demo: fits a straight line to noisy samples of {@code y = 0.7x + 20}. */
      public static void main(String[] args) {
          JavaLeastSquare leastSquare = new JavaLeastSquare(
                  new double[]{
                      2, 14, 20, 25, 26, 34,
                      47, 87, 165, 265, 365, 465,
                      565, 665
                  },
                  new double[]{
                      0.7 * 2 + 20 + 0.4,
                      0.7 * 14 + 20 + 0.5,
                      0.7 * 20 + 20 + 3.4,
                      0.7 * 25 + 20 + 5.8,
                      0.7 * 26 + 20 + 8.27,
                      0.7 * 34 + 20 + 0.4,

                      0.7 * 47 + 20 + 0.1,
                      0.7 * 87 + 20,
                      0.7 * 165 + 20,
                      0.7 * 265 + 20,
                      0.7 * 365 + 20,
                      0.7 * 465 + 20,

                      0.7 * 565 + 20,
                      0.7 * 665 + 20
                  },
                  2);

          double[] coefficients = leastSquare.getCoefficient();
          for (double c : coefficients) {
              System.out.println(c);
          }

          // Evaluate the fitted line at x = 4.
          System.out.println(leastSquare.fit(4));
      }
  }

输出结果:

com.datangmobile.biz.leastsquare.JavaLeastSquare
22.27966881467629
0.6952475907448203
25.06065917765557

Process finished with exit code 0

使用开源库

也可使用Apache开源库commons math(http://commons.apache.org/proper/commons-math/userguide/fitting.html),提供的功能更强大:

<dependency>  

<groupId>org.apache.commons</groupId>

<artifactId>commons-math3</artifactId>

<version>3.5</version>

</dependency>

实现代码:

import org.apache.commons.math3.fitting.PolynomialCurveFitter;

import org.apache.commons.math3.fitting.WeightedObservedPoints;

/**
 * Demo: fits a degree-2 polynomial to noisy samples of {@code y = 0.7x + 20} with the
 * Apache Commons Math curve fitter and prints the fitted coefficients, one per line.
 */
public class WeightedObservedPointsTest {

    public static void main(String[] args) {
        // Each row is {x, noise}; the observed value is 0.7 * x + 20 + noise.
        final double[][] xAndNoise = {
            {2, 0.4},
            {12, 0.3},
            {32, 3.4},
            {34, 5.8},
            {58, 8.4},
            {43, 0.28},
            {27, 0.4},
        };

        final WeightedObservedPoints observations = new WeightedObservedPoints();
        for (double[] row : xAndNoise) {
            observations.add(row[0], 0.7 * row[0] + 20 + row[1]);
        }

        // Degree 2 means three coefficients are fitted.
        final PolynomialCurveFitter quadraticFitter = PolynomialCurveFitter.create(2);
        final double[] fitted = quadraticFitter.fit(observations.toList());

        for (int i = 0; i < fitted.length; i++) {
            System.out.println(fitted[i]);
        }
    }
}

测试输出结果:

20.47425047847121
0.6749744063035112
0.002523043547711147

Process finished with exit code 0

使用org.ujmp(矩阵)实现最小二乘法:

pom.xml中需要引入org.ujmp

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.dtgroup</groupId>
    <artifactId>dtgroup</artifactId>
    <version>0.0.1-SNAPSHOT</version>

    <repositories>
        <!-- Aliyun mirror. HTTPS is required: Maven 3.8.1+ blocks plain-http repositories by default. -->
        <repository>
            <id>limaven</id>
            <name>aliyun maven</name>
            <url>https://maven.aliyun.com/repository/public</url>
            <layout>default</layout>
            <releases>
                <enabled>true</enabled>
            </releases>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>

    <dependencies>
        <!-- UJMP dense-matrix library used for the matrix-based least-squares code. -->
        <dependency>
            <groupId>org.ujmp</groupId>
            <artifactId>ujmp-core</artifactId>
            <version>0.3.0</version>
        </dependency>
    </dependencies>
</project>

java代码:

    /**
     * Fits {@code y = theta[0] + theta[1]*x + ... + theta[fetureCout-1]*x^(fetureCout-1)} by
     * ordinary least squares using the normal equations
     * {@code theta = (X^T X)^-1 X^T y}, printing the intermediate matrices and the result.
     *
     * @param sampleCount number of samples; must equal {@code samples.size()}
     * @param fetureCout  number of coefficients to fit (2 fits a straight line), at least 1
     * @param samples     the sample points
     * @throws IllegalArgumentException if {@code samples} is {@code null}, {@code sampleCount}
     *         does not match its size, or {@code fetureCout < 1}
     */
    private static void leastsequare(int sampleCount, int fetureCout, List<Sample> samples) {
        if (samples == null || sampleCount != samples.size()) {
            // A mismatch would either leave phantom all-ones rows in the matrices or
            // index past their end.
            throw new IllegalArgumentException(
                    "sampleCount (" + sampleCount + ") must equal the number of samples");
        }
        if (fetureCout < 1) {
            throw new IllegalArgumentException("fetureCout must be at least 1 but was " + fetureCout);
        }

        // Design matrix X (sampleCount x fetureCout): column 0 stays 1 (intercept term),
        // column j holds x^j.
        Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout);
        for (int i = 0; i < sampleCount; i++) {
            double x = samples.get(i).getX();
            for (int j = 1; j < fetureCout; j++) {
                matrixX.setAsDouble(Math.pow(x, j), i, j);
            }
        }
        System.out.println("--------------------------------------");

        // Observation vector y (sampleCount x 1).
        Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);
        for (int i = 0; i < sampleCount; i++) {
            matrixY.setAsDouble(samples.get(i).getY(), i, 0);
        }

        // X^T
        Matrix matrixXTrans = matrixX.transpose();

        // X^T * X
        Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);
        System.out.println(matrixMtimes);
        System.out.println("--------------------------------------");

        // (X^T * X)^-1
        Matrix matrixMtimesInv = matrixMtimes.inv();
        System.out.println(matrixMtimesInv);
        System.out.println("--------------------------------------");

        // (X^T * X)^-1 * X^T
        Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans);
        System.out.println(matrixMtimesInvMtimes);
        System.out.println("--------------------------------------");

        // theta = (X^T * X)^-1 * X^T * y
        Matrix theta = matrixMtimesInvMtimes.mtimes(matrixY);
        System.out.println(theta);
    }

测试代码:

    /**
     * Test driver: builds noisy samples around {@code y = 0.8x + 15} and fits a line to them.
     */
    public static void main(String[] args) {
        // Model assumptions: y = a*x + b with a in (0,1], b in [5,20], x in [0,500], y >= 5.
        // When several y values share the same x, the y values are meant to be averaged.
        // Sample arguments appear to be (y, x) - confirm against the Sample class.
        final List<Sample> points = new ArrayList<>();
        points.add(new Sample(0.8d * 1 + 15 + 1, 1d));
        points.add(new Sample(0.8d * 4 + 15 + 0.8, 4d));
        points.add(new Sample(0.8d * 3 + 15 + 0.7, 3d));
        points.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
        points.add(new Sample(0.8d * 5 + 15 + 0.3, 5d));
        points.add(new Sample(0.8d * 10 + 15 + 0.4, 10d));
        points.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
        points.add(new Sample(0.8d * 7 + 15 + 0.3, 7d));
        // This point does not follow the model (y computed for x = 1000, reported at x = 70).
        points.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d));

        // Two coefficients: intercept and slope.
        leastsequare(points.size(), 2, points);
    }

过滤样本中的噪点:

    /**
     * Test driver: builds noisy samples around {@code y = 0.8x + 15}, repeatedly applies the
     * gradient-based noise filter until it reports completion, prints the surviving samples
     * and fits a line to them.
     */
    public static void main(String[] args) {
        // Model assumptions: y = a*x + b with a in (0,1], b in [5,20], x in [0,500], y >= 5.
        // When several y values share the same x, the y values are meant to be averaged.
        // Sample arguments appear to be (y, x) - confirm against the Sample class.
        List<Sample> points = new ArrayList<>();
        points.add(new Sample(0.8d * 1 + 15 + 1, 1d));
        points.add(new Sample(0.8d * 4 + 15 + 0.8, 4d));
        points.add(new Sample(0.8d * 3 + 15 + 0.7, 3d));
        points.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
        points.add(new Sample(0.8d * 5 + 15 + 0.3, 5d));
        points.add(new Sample(0.8d * 10 + 15 + 0.4, 10d));
        points.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
        points.add(new Sample(0.8d * 7 + 15 + 0.3, 7d));
        // This point does not follow the model (y computed for x = 1000, reported at x = 70).
        points.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d));

        // Alternative (disabled): distance-based filtering via filterSample(points).
        sortSample(points);

        // Each filter call performs one step; re-sort and resume from the reported index
        // until the filter says it is done.
        FilterSampleByGradientResult outcome = filterSampleByGradient(0, points);
        while (!outcome.isComplete()) {
            List<Sample> remaining = outcome.getSamples();
            sortSample(remaining);
            outcome = filterSampleByGradient(outcome.getIndex(), remaining);
        }

        points = outcome.getSamples();
        for (Sample point : points) {
            System.out.println(point);
        }

        // Two coefficients: intercept and slope.
        leastsequare(points.size(), 2, points);
    }

/**
 * Sorts the samples in place by ascending x.
 *
 * @param samples samples to sort; the list itself is reordered
 */
private static void sortSample(List<Sample> samples) {
    // comparingDouble returns 0 for equal x values, as the Comparator contract requires.
    // A comparator that answers -1 for ties is asymmetric (a < b and b < a at once), which
    // List.sort may reject with "Comparison method violates its general contract!".
    samples.sort(Comparator.comparingDouble(Sample::getX));
}

/**
 * Performs one noise-filtering step on samples sorted by ascending x.
 *
 * <p>Neighbouring samples are scanned starting at {@code index}. The first pair that needs a
 * change is handled and the method returns immediately with {@code complete == false}:
 * <ul>
 *   <li>if the two x values differ by less than 1, the pair is merged into a single sample
 *       at the first sample's x with the mean of both y values;</li>
 *   <li>otherwise, if the slope {@code (y2 - y1) / (x2 - x1)} exceeds 1.5, the second sample
 *       of the pair is removed as noise.</li>
 * </ul>
 * A steep first pair (index 0) is skipped because there is no earlier sample to tell which of
 * the two is the outlier. Every {@code complete == false} result therefore shrinks the list by
 * one, so a caller that loops until completion always terminates.
 *
 * @param index   position at which to resume scanning
 * @param samples samples sorted by ascending x; modified in place
 * @return {@code complete == true} with index 0 when no pair needed a change, otherwise
 *         {@code complete == false} with the index at which scanning should resume
 */
private static FilterSampleByGradientResult filterSampleByGradient(int index, List<Sample> samples) {
    for (int i = index; i < samples.size() - 1; i++) {
        Sample current = samples.get(i);
        Sample next = samples.get(i + 1);
        double deltaX = current.getX() - next.getX();
        double deltaY = current.getY() - next.getY();

        // NOTE(review): the original comment said "closer than 2 m" but the threshold in
        // code is 1 - confirm which is intended.
        if (Math.abs(deltaX) < 1) {
            double mergedY = (current.getY() + next.getY()) / 2;
            // Replace the pair in place so the list stays sorted. Removing index i and then
            // index i + 1 would skip the partner, which has already shifted down to i.
            samples.set(i, new Sample(mergedY, current.getX()));
            samples.remove(i + 1);
            // Step back one position: the merged y may change the slope to the previous sample.
            return new FilterSampleByGradientResult(false, Math.max(i - 1, 0), samples);
        }

        double gradient = deltaY / deltaX;
        if (gradient > 1.5 && i > 0) {
            samples.remove(i + 1);
            return new FilterSampleByGradientResult(false, i, samples);
        }
        // A steep pair at i == 0 falls through: returning "not complete" without changing
        // the list would make the caller repeat the same step forever.
    }
    return new FilterSampleByGradientResult(true, 0, samples);
}

 使用距离来处理过滤:

    /**
     * Removes samples whose distance from the centroid of all samples lies outside a band
     * around the mean distance, and returns the same (modified) list.
     *
     * <p>Steps: compute the centroid (mean x, mean y); compute every sample's Euclidean
     * distance to it; take the mean distance {@code u} and the spread
     * {@code s = sqrt(sum((d_i - u)^2))}; drop every sample with {@code d_i <= u - 0.5*s} or
     * {@code d_i >= u + 0.5*s}, printing the index of each dropped sample.
     *
     * <p>NOTE(review): the original comments describe a 3-sigma rule on the standard
     * deviation, but the code does not divide the squared deviations by the sample count and
     * uses a factor of 0.5 - confirm the intended threshold.
     *
     * @param samples samples to filter; modified in place
     * @return {@code samples}, after removal
     */
    private static List<Sample> filterSample(List<Sample> samples) {
        final int total = samples.size();

        // Centroid of all samples.
        double xTotal = 0d;
        double yTotal = 0d;
        for (Sample s : samples) {
            xTotal += s.getX();
            yTotal += s.getY();
        }
        final double centroidX = xTotal / total;
        final double centroidY = yTotal / total;

        // Distance of every sample to the centroid.
        final double[] radii = new double[total];
        for (int i = 0; i < total; i++) {
            Sample s = samples.get(i);
            radii[i] = Math.sqrt(Math.pow(s.getX() - centroidX, 2) + Math.pow(s.getY() - centroidY, 2));
        }

        // Mean distance.
        double radiusTotal = 0d;
        for (double radius : radii) {
            radiusTotal += radius;
        }
        final double meanRadius = radiusTotal / total;

        // Spread: square root of the summed squared deviations (not divided by the count).
        double squaredDeviationTotal = 0d;
        for (double radius : radii) {
            squaredDeviationTotal += Math.pow((radius - meanRadius), 2);
        }
        final double spread = Math.sqrt(squaredDeviationTotal);

        final double lowerBound = meanRadius - 0.5 * spread;
        final double upperBound = meanRadius + 0.5 * spread;

        // Collect indices from high to low so that removing one does not shift the others.
        List<Integer> doomed = new ArrayList<Integer>();
        for (int i = total - 1; i >= 0; i--) {
            boolean insideBand = radii[i] > lowerBound && radii[i] < upperBound;
            if (!insideBand) {
                doomed.add(i);
                System.out.println("will be remove " + i);
            }
        }

        for (int position : doomed) {
            samples.remove(position); // int argument: removes by index
        }
        return samples;
    }

实际业务测试:

package com.zjanalyse.spark.maths;

import java.util.ArrayList;

import java.util.Comparator;

import java.util.List;

import org.ujmp.core.DenseMatrix;

import org.ujmp.core.Matrix;

/**
 * Demo of a two-pass least-squares line fit {@code y = pathLoss * x + wearLoss}.
 *
 * <p>Pipeline: sort the samples by x, drop samples outside the business value range, fit a
 * line, drop samples whose distance to that line falls outside a band around the mean
 * distance, then fit again on the survivors.
 *
 * <p>Assumed model: y = a*x + b with a in (0,1], b in [5,20], x in [0,500), y &gt;= 5.
 */
public class LastSquare {

    /** A line cannot be fitted to fewer samples than this. */
    private static final int MIN_SAMPLES = 2;

    /** Exclusive upper bound of the valid x range. */
    private static final double MAX_X = 500;

    /** Lower bound of the valid y range (y = 0*x + 5). */
    private static final double MIN_Y = 5;

    /** Intercept of the upper y bound (y = 1*x + 30). */
    private static final double MAX_Y_INTERCEPT = 30;

    public static void main(String[] args) {
        // Samples are generated around y = 0.8*x + 15; the last three use other intercepts
        // (60, 60, 30), presumably to act as outliers.
        // When several y values share the same x, the y values are meant to be averaged.
        // Sample arguments appear to be (y, x) - confirm against the Sample class.
        List<Sample> samples = new ArrayList<Sample>();
        samples.add(new Sample(0.8d * 11 + 15 + 1, 11d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.8, 24d));
        samples.add(new Sample(0.8d * 33 + 15 + 0.7, 33d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
        samples.add(new Sample(0.8d * 47 + 15 + 0.3, 47d));
        samples.add(new Sample(0.8d * 60 + 15 + 0.4, 60d));
        samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
        samples.add(new Sample(0.8d * 57 + 15 + 0.3, 57d));
        samples.add(new Sample(0.8d * 70 + 60 + 0.3, 70d));
        samples.add(new Sample(0.8d * 80 + 60 + 0.3, 80d));
        samples.add(new Sample(0.8d * 40 + 30 + 0.3, 40d));

        sortSample(samples);
        System.out.println("原始样本数据");
        printSamples(samples);

        System.out.println("开始“所有点”通过“业务数据取值范围”剔除:");
        filterByBusiness(samples);
        System.out.println("结束“所有点”通过“业务数据取值范围”剔除:");
        printSamples(samples);

        int fetureCout = 2;
        if (samples.size() >= MIN_SAMPLES) {
            System.out.println("第一次拟合。。。");
            Matrix theta = leastsequare(samples.size(), fetureCout, samples);
            double wearLoss = theta.getAsDouble(0, 0);
            double pathLoss = theta.getAsDouble(1, 0);
            System.out.println("wear loss " + wearLoss);
            System.out.println("path loss " + pathLoss);

            System.out.println("开始“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:");
            samples = filterSample(wearLoss, pathLoss, samples);
            System.out.println("结束“所有点”与“第一多项式拟合结果直线方式距离方差”剔除:");
            printSamples(samples);

            System.out.println("第二次拟合。。。");
            if (samples.size() >= MIN_SAMPLES) {
                theta = leastsequare(samples.size(), fetureCout, samples);
                wearLoss = theta.getAsDouble(0, 0);
                pathLoss = theta.getAsDouble(1, 0);
                System.out.println("wear loss " + wearLoss);
                System.out.println("path loss " + pathLoss);
            }
        } else {
            // Too few samples survived the range filter: X^T X would not be invertible.
            System.out.println("样本数量不足,无法拟合");
        }
        System.out.println("complete...");
    }

    /** Prints every sample on its own line. */
    private static void printSamples(List<Sample> samples) {
        for (Sample sample : samples) {
            System.out.println(sample);
        }
    }

    /**
     * Removes, in place, every sample outside the business value range: x must lie in
     * [0, 500) and y in [5, x + 30]. Each removal is reported on standard output.
     *
     * @param samples samples to filter; modified in place
     */
    private static void filterByBusiness(List<Sample> samples) {
        int i = 0;
        while (i < samples.size()) {
            double x = samples.get(i).getX();
            double y = samples.get(i).getY();
            double maxY = x + MAX_Y_INTERCEPT;

            if (x < 0 || x >= MAX_X) {
                System.out.println(x + " x值超出有效值范围[0,500)");
                samples.remove(i);
            } else if (y < MIN_Y || y > maxY) {
                System.out.println(
                        y + " y值超出有效值范围[(0*x+5),(1*x+30)]其中x=" + x + ",也就是:[" + MIN_Y + "," + maxY + "]");
                samples.remove(i);
            } else {
                // Advance only when nothing was removed; after a removal the next
                // sample has moved into position i.
                i++;
            }
        }
    }

    /**
     * Returns the perpendicular distance from point {@code (x1, y1)} to the line
     * {@code A*x + B*y + C = 0}.
     *
     * <p>Example: the distance from (0, 1) to y = x (that is {@code -x + y = 0}) is
     * {@code getDistanceOfPerpendicular(0, 1, -1, 1, 0)}, about 0.7071.
     *
     * @param x1 x coordinate of the point
     * @param y1 y coordinate of the point
     * @param A  coefficient A of the line's general form
     * @param B  coefficient B of the line's general form
     * @param C  coefficient C of the line's general form
     * @return the distance from the point to the line
     */
    private static double getDistanceOfPerpendicular(double x1, double y1, double A, double B, double C) {
        return Math.abs((A * x1 + B * y1 + C) / Math.sqrt(A * A + B * B));
    }

    /**
     * Removes samples whose distance measure to the fitted line
     * {@code y = pathLoss * x + wearLoss} lies outside a band around the mean measure, and
     * returns the same (modified) list. Each decision is reported on standard output.
     *
     * <p>NOTE(review): three things look unintended but are kept because the thresholds
     * were presumably tuned against them - confirm before changing:
     * the measure is the square root of the perpendicular distance rather than the distance;
     * the spread is {@code sqrt(sum((d_i - u)^2))} without dividing by the sample count and
     * is scaled by 0.25 although the original comments describe a 3-sigma rule; and samples
     * that are unusually close to the line (below the lower bound) are removed as well.
     *
     * @param wearLoss intercept of the fitted line
     * @param pathLoss slope of the fitted line
     * @param samples  samples to filter; modified in place
     * @return {@code samples}, after removal
     */
    private static List<Sample> filterSample(double wearLoss, double pathLoss, List<Sample> samples) {
        int sampleCount = samples.size();

        // Distance measure of every sample to the line pathLoss*x - y + wearLoss = 0.
        double[] measures = new double[sampleCount];
        for (int i = 0; i < sampleCount; i++) {
            Sample sample = samples.get(i);
            double distance = getDistanceOfPerpendicular(sample.getX(), sample.getY(), pathLoss, -1, wearLoss);
            measures[i] = Math.sqrt(distance);
        }

        // Mean of the measures.
        double measureSum = 0d;
        for (double measure : measures) {
            measureSum += measure;
        }
        double meanMeasure = measureSum / sampleCount;

        // Spread of the measures around their mean.
        double squaredDeviationSum = 0d;
        for (double measure : measures) {
            squaredDeviationSum += Math.pow((measure - meanMeasure), 2);
        }
        double spread = Math.sqrt(squaredDeviationSum);

        double minDistance = meanMeasure - 0.25 * spread;
        double maxDistance = meanMeasure + 0.25 * spread;

        // Collect indices from high to low so that removing one does not shift the others.
        List<Integer> indexesToRemove = new ArrayList<Integer>();
        for (int i = sampleCount - 1; i >= 0; i--) {
            double measure = measures[i];
            if (measure <= minDistance || measure >= maxDistance) {
                System.out.println(measure + " out of range [" + minDistance + "," + maxDistance + "]");
                indexesToRemove.add(i);
            } else {
                System.out.println(measure);
            }
        }

        for (int indexToRemove : indexesToRemove) {
            System.out.println("remove " + samples.get(indexToRemove));
            samples.remove(indexToRemove); // int argument: removes by index
        }
        return samples;
    }

    /**
     * Sorts the samples in place by ascending x.
     *
     * @param samples samples to sort; the list itself is reordered
     */
    private static void sortSample(List<Sample> samples) {
        // comparingDouble returns 0 for equal x values, as the Comparator contract requires.
        // A comparator that answers -1 for ties is asymmetric, which List.sort may reject
        // with "Comparison method violates its general contract!".
        samples.sort(Comparator.comparingDouble(Sample::getX));
    }

    /**
     * Fits {@code y = theta[0] + theta[1]*x + ... + theta[fetureCout-1]*x^(fetureCout-1)} by
     * ordinary least squares using the normal equations {@code theta = (X^T X)^-1 X^T y}.
     *
     * @param sampleCount number of samples; must equal {@code samples.size()}
     * @param fetureCout  number of coefficients to fit (2 fits a straight line), at least 1
     * @param samples     the sample points
     * @return a {@code fetureCout x 1} matrix; row j holds the coefficient of {@code x^j}
     * @throws IllegalArgumentException if {@code samples} is {@code null}, {@code sampleCount}
     *         does not match its size, or {@code fetureCout < 1}
     */
    private static Matrix leastsequare(int sampleCount, int fetureCout, List<Sample> samples) {
        if (samples == null || sampleCount != samples.size()) {
            // A mismatch would either leave phantom all-ones rows in the matrices or
            // index past their end.
            throw new IllegalArgumentException(
                    "sampleCount (" + sampleCount + ") must equal the number of samples");
        }
        if (fetureCout < 1) {
            throw new IllegalArgumentException("fetureCout must be at least 1 but was " + fetureCout);
        }

        // Design matrix X (sampleCount x fetureCout): column 0 stays 1 (intercept term),
        // column j holds x^j.
        Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, fetureCout);
        for (int i = 0; i < sampleCount; i++) {
            double x = samples.get(i).getX();
            for (int j = 1; j < fetureCout; j++) {
                matrixX.setAsDouble(Math.pow(x, j), i, j);
            }
        }

        // Observation vector y (sampleCount x 1).
        Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);
        for (int i = 0; i < sampleCount; i++) {
            matrixY.setAsDouble(samples.get(i).getY(), i, 0);
        }

        Matrix matrixXTrans = matrixX.transpose();                           // X^T
        Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);                  // X^T * X
        Matrix matrixMtimesInv = matrixMtimes.inv();                         // (X^T * X)^-1
        Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans); // (X^T * X)^-1 * X^T
        return matrixMtimesInvMtimes.mtimes(matrixY);                        // theta
    }
}

View Code

以上是 最小二乘法多项式拟合的Java实现 的全部内容, 来源链接: utcz.com/z/390259.html

回到顶部