77import matplotlib .pyplot as plt
88import math
99
10- class MeltingPointClusterAnalyzer ():
1110
11+ class MeltingPointClusterAnalyzer :
def _get_clusters(self, points):
    """Split ``points`` into two groups with agglomerative clustering.

    Args:
        points: 2-D NumPy array with one row per point; the callers in
            this file build it from (temperature, volume) pairs.

    Returns:
        Tuple ``(cluster1, cluster2)``: the transposed rows labelled 1 and
        the transposed rows labelled 0, each with one column per member.
    """
    labels = AgglomerativeClustering(n_clusters=2).fit(points).labels_
    # Boolean masks keep the point axis even when a cluster has exactly one
    # member; an argwhere(...).squeeze() lookup collapses that case to a
    # single flat row, so the result shape would depend on cluster size.
    cluster1 = points[labels == 1].T
    cluster2 = points[labels == 0].T
    return cluster1, cluster2
1717
def plot_vol_vs_temp(self, ts, vs, plot_title=None):
    """Scatter the two clusters and mark the estimated melting temperature.

    Args:
        ts: Temperatures, one per sample (x axis).
        vs: Volumes, one per sample (y axis).
        plot_title: Optional title; a default title is used when ``None``.
    """
    samples = np.array(list(zip(ts, vs)))
    first_group, second_group = self._get_clusters(samples)
    for group in (first_group, second_group):
        plt.scatter(*group)

    plt.xlabel("Temperature (K)")
    plt.ylabel("Volume (A^3)")

    # Vertical red line at the melting estimate, spanning the volume range.
    melt_temp = self.estimate_melting_temp(ts, vs)
    volume_span = [min(vs), max(vs)]
    plt.plot([melt_temp, melt_temp], volume_span, color="r")

    heading = (
        "Volume vs Temperature by Clustering"
        if plot_title is None
        else plot_title
    )
    plt.title(heading)
32-
32+
3333 def estimate_melting_temp (self , temps , vols ):
3434 points = np .array (list (zip (temps , vols )))
3535 cluster1 , cluster2 = self ._get_clusters (points )
@@ -42,8 +42,8 @@ def estimate_melting_temp(self, temps, vols):
4242
4343 return np .mean ([max (solid_range ), min (liquid_range )])
4444
45- class MeltingPointSlopeAnalyzer ():
4645
46+ class MeltingPointSlopeAnalyzer :
def split_dset(self, pts, split_idx):
    """Cut ``pts`` in two at ``split_idx``.

    Args:
        pts: Sliceable sequence to divide.
        split_idx: Position of the first element of the second part.

    Returns:
        Tuple ``(head, tail)``: the elements before ``split_idx`` and the
        elements from ``split_idx`` onwards.
    """
    head = pts[:split_idx]
    tail = pts[split_idx:]
    return head, tail
4949
@@ -56,7 +56,7 @@ def assess_splits(self, xs, ys):
5656 for idx in pt_idxs :
5757 _ , _ , _ , _ , total_err = self .get_split_fit (xs , ys , idx )
5858 errs .append (total_err )
59-
59+
6060 return list (zip (pt_idxs , errs ))
6161
6262 def get_linear_ys (self , m , b , xs ):
@@ -79,56 +79,61 @@ def plot_split(self, xs, ys, split_idx):
7979
8080 plt .scatter (rightxs , rightys )
8181 plt .plot (rightxs , right_fit_ys )
82-
82+
def get_best_split(self, xs, ys):
    """Return the candidate split index whose fit has the lowest error.

    Args:
        xs: x values handed to ``assess_splits``.
        ys: y values handed to ``assess_splits``.

    Returns:
        The index part of the ``(index, error)`` pair with the smallest
        error; the first such pair wins on ties.
    """
    scored = self.assess_splits(xs, ys)
    lowest = np.argmin([pair[1] for pair in scored])
    return scored[lowest][0]
89-
89+
def plot_vol_vs_temp(self, temps, vols):
    """Plot the best two-segment split and mark the melting temperature.

    Args:
        temps: Temperatures, one per sample (x axis).
        vols: Volumes, one per sample (y axis).
    """
    self.plot_split(temps, vols, self.get_best_split(temps, vols))

    melt_temp = self.estimate_melting_temp(temps, vols)
    print(melt_temp)

    # Vertical red line at the melting estimate, spanning the volume range.
    volume_span = [min(vols), max(vols)]
    plt.plot([melt_temp, melt_temp], volume_span, color="r")
9796
def estimate_melting_temp(self, temps, vols):
    """Estimate the melting temperature as the midpoint across the best split.

    Args:
        temps: Indexable sequence of temperatures.
        vols: Volumes matching ``temps``; only used to locate the split.

    Returns:
        Mean of the temperature at the best split index and the one just
        before it. NOTE: a split index of 0 makes the "previous" entry
        wrap around to the last temperature (negative indexing).
    """
    boundary = self.get_best_split(temps, vols)
    at_split = temps[boundary]
    before_split = temps[boundary - 1]
    return np.mean([at_split, before_split])
101100
class MeltingPointSlopeRMSEAnalyzer(MeltingPointSlopeAnalyzer):
    """Slope analyzer that scores a split by the RMSE of its two line fits."""

    def get_split_fit(self, xs, ys, split_idx):
        """Fit a line to each side of the split and score the pair by RMSE.

        Args:
            xs: x values of the full data set.
            ys: y values of the full data set.
            split_idx: Index at which ``split_dset`` divides ``xs``/``ys``.

        Returns:
            Tuple ``(left_slope, left_intercept, right_slope,
            right_intercept, combined_err)`` where ``combined_err`` is the
            sum of the left and right root-mean-squared errors.
        """
        leftx, rightx = self.split_dset(xs, split_idx)
        lefty, righty = self.split_dset(ys, split_idx)

        lslope, lintercept, lefterr = self._fit_with_rmse(leftx, lefty)
        rslope, rintercept, righterr = self._fit_with_rmse(rightx, righty)

        # The score is the plain sum of both errors. A quadrature sum used
        # to be computed here too, but it was overwritten before use.
        combined_err = lefterr + righterr
        return lslope, lintercept, rslope, rintercept, combined_err

    @staticmethod
    def _fit_with_rmse(xs, ys):
        """Return ``(slope, intercept, rmse)`` of a least-squares line fit."""
        fit = linregress(xs, ys)
        predicted = fit.intercept + fit.slope * np.asarray(xs)
        residuals = np.asarray(ys) - predicted
        # RMSE is computed with NumPy so the result does not depend on the
        # ``squared=False`` flag of sklearn's mean_squared_error, which
        # newer scikit-learn releases deprecate and then drop.
        rmse = np.sqrt(np.mean(residuals ** 2))
        return fit.slope, fit.intercept, rmse
119118
class MeltingPointSlopeStdErrAnalyzer(MeltingPointSlopeAnalyzer):
    """Slope analyzer that scores a split by the fits' slope standard errors."""

    def get_split_fit(self, xs, ys, split_idx):
        """Fit a line to each side of the split and score the pair.

        Args:
            xs: x values of the full data set.
            ys: y values of the full data set.
            split_idx: Index at which ``split_dset`` divides ``xs``/``ys``.

        Returns:
            Tuple ``(left_slope, left_intercept, right_slope,
            right_intercept, combined_err)`` where ``combined_err`` is the
            sum of the ``stderr`` values reported by ``linregress`` for the
            left and right fits.
        """
        leftx, rightx = self.split_dset(xs, split_idx)
        lefty, righty = self.split_dset(ys, split_idx)

        leftfit = linregress(leftx, lefty)
        rightfit = linregress(rightx, righty)

        # The score is the plain sum of both errors. A quadrature sum used
        # to be computed here too, but it was overwritten before use.
        combined_err = leftfit.stderr + rightfit.stderr
        return (
            leftfit.slope,
            leftfit.intercept,
            rightfit.slope,
            rightfit.intercept,
            combined_err,
        )
0 commit comments