Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Commit 
							
							·
						
						79a7cfe
	
1
								Parent(s):
							
							1b92896
								
Add precision parameter
Browse files
- pysr/sr.py +28 -11
    	
        pysr/sr.py
    CHANGED
    
    | @@ -132,6 +132,7 @@ def pysr( | |
| 132 | 
             
                tournament_selection_p=1.0,
         | 
| 133 | 
             
                denoise=False,
         | 
| 134 | 
             
                Xresampled=None,
         | 
|  | |
| 135 | 
             
            ):
         | 
| 136 | 
             
                """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
         | 
| 137 | 
             
                Note: most default parameters have been tuned over several example
         | 
| @@ -250,6 +251,8 @@ def pysr( | |
| 250 | 
             
                :type tournament_selection_p: float
         | 
| 251 | 
             
                :param denoise: Whether to use a Gaussian Process to denoise the data before inputting to PySR. Can help PySR fit noisy data.
         | 
| 252 | 
             
                :type denoise: bool
         | 
|  | |
|  | |
| 253 | 
             
                :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
         | 
| 254 | 
             
                :type: pd.DataFrame/list
         | 
| 255 | 
             
                """
         | 
| @@ -427,6 +430,7 @@ def pysr( | |
| 427 | 
             
                    tournament_selection_n=tournament_selection_n,
         | 
| 428 | 
             
                    tournament_selection_p=tournament_selection_p,
         | 
| 429 | 
             
                    denoise=denoise,
         | 
|  | |
| 430 | 
             
                )
         | 
| 431 |  | 
| 432 | 
             
                kwargs = {**_set_paths(tempdir), **kwargs}
         | 
| @@ -582,40 +586,53 @@ def _create_julia_files( | |
| 582 |  | 
| 583 |  | 
| 584 | 
             
            def _make_datasets_julia_str(
         | 
| 585 | 
            -
                X, | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 586 | 
             
            ):
         | 
| 587 | 
             
                def_datasets = """using DelimitedFiles"""
         | 
| 588 | 
            -
                 | 
|  | |
|  | |
|  | |
| 589 | 
             
                if multioutput:
         | 
| 590 | 
            -
                    np.savetxt(y_filename, y.astype( | 
| 591 | 
             
                else:
         | 
| 592 | 
            -
                    np.savetxt(y_filename, y.reshape(-1, 1).astype( | 
|  | |
| 593 | 
             
                if weights is not None:
         | 
| 594 | 
             
                    if multioutput:
         | 
| 595 | 
            -
                        np.savetxt(weights_filename, weights.astype( | 
| 596 | 
             
                    else:
         | 
| 597 | 
             
                        np.savetxt(
         | 
| 598 | 
             
                            weights_filename,
         | 
| 599 | 
            -
                            weights.reshape(-1, 1).astype( | 
| 600 | 
             
                            delimiter=",",
         | 
| 601 | 
             
                        )
         | 
|  | |
| 602 | 
             
                def_datasets += f"""
         | 
| 603 | 
            -
            X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',',  | 
| 604 |  | 
| 605 | 
             
                if multioutput:
         | 
| 606 | 
             
                    def_datasets += f"""
         | 
| 607 | 
            -
            y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',',  | 
| 608 | 
             
                else:
         | 
| 609 | 
             
                    def_datasets += f"""
         | 
| 610 | 
            -
            y = readdlm("{_escape_filename(y_filename)}", ',',  | 
| 611 |  | 
| 612 | 
             
                if weights is not None:
         | 
| 613 | 
             
                    if multioutput:
         | 
| 614 | 
             
                        def_datasets += f"""
         | 
| 615 | 
            -
            weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',',  | 
| 616 | 
             
                    else:
         | 
| 617 | 
             
                        def_datasets += f"""
         | 
| 618 | 
            -
            weights = readdlm("{_escape_filename(weights_filename)}", ',',  | 
| 619 | 
             
                return def_datasets
         | 
| 620 |  | 
| 621 |  | 
|  | |
| 132 | 
             
                tournament_selection_p=1.0,
         | 
| 133 | 
             
                denoise=False,
         | 
| 134 | 
             
                Xresampled=None,
         | 
| 135 | 
            +
                precision=32,
         | 
| 136 | 
             
            ):
         | 
| 137 | 
             
                """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
         | 
| 138 | 
             
                Note: most default parameters have been tuned over several example
         | 
|  | |
| 251 | 
             
                :type tournament_selection_p: float
         | 
| 252 | 
             
                :param denoise: Whether to use a Gaussian Process to denoise the data before inputting to PySR. Can help PySR fit noisy data.
         | 
| 253 | 
             
                :type denoise: bool
         | 
| 254 | 
            +
                :param precision: What precision to use for the data. By default this is 32 (float32), but you can select 64 or 16 as well.
         | 
| 255 | 
            +
                :type precision: int
         | 
| 256 | 
             
                :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
         | 
| 257 | 
             
                :type: pd.DataFrame/list
         | 
| 258 | 
             
                """
         | 
|  | |
| 430 | 
             
                    tournament_selection_n=tournament_selection_n,
         | 
| 431 | 
             
                    tournament_selection_p=tournament_selection_p,
         | 
| 432 | 
             
                    denoise=denoise,
         | 
| 433 | 
            +
                    precision=precision,
         | 
| 434 | 
             
                )
         | 
| 435 |  | 
| 436 | 
             
                kwargs = {**_set_paths(tempdir), **kwargs}
         | 
|  | |
| 586 |  | 
| 587 |  | 
| 588 | 
             
            def _make_datasets_julia_str(
         | 
| 589 | 
            +
                X,
         | 
| 590 | 
            +
                X_filename,
         | 
| 591 | 
            +
                weights,
         | 
| 592 | 
            +
                weights_filename,
         | 
| 593 | 
            +
                y,
         | 
| 594 | 
            +
                y_filename,
         | 
| 595 | 
            +
                multioutput,
         | 
| 596 | 
            +
                precision,
         | 
| 597 | 
            +
                **kwargs,
         | 
| 598 | 
             
            ):
         | 
| 599 | 
             
                def_datasets = """using DelimitedFiles"""
         | 
| 600 | 
            +
                julia_dtype = {16: "Float16", 32: "Float32", 64: "Float64"}[precision]
         | 
| 601 | 
            +
                np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                np.savetxt(X_filename, X.astype(np_dtype), delimiter=",")
         | 
| 604 | 
             
                if multioutput:
         | 
| 605 | 
            +
                    np.savetxt(y_filename, y.astype(np_dtype), delimiter=",")
         | 
| 606 | 
             
                else:
         | 
| 607 | 
            +
                    np.savetxt(y_filename, y.reshape(-1, 1).astype(np_dtype), delimiter=",")
         | 
| 608 | 
            +
             | 
| 609 | 
             
                if weights is not None:
         | 
| 610 | 
             
                    if multioutput:
         | 
| 611 | 
            +
                        np.savetxt(weights_filename, weights.astype(np_dtype), delimiter=",")
         | 
| 612 | 
             
                    else:
         | 
| 613 | 
             
                        np.savetxt(
         | 
| 614 | 
             
                            weights_filename,
         | 
| 615 | 
            +
                            weights.reshape(-1, 1).astype(np_dtype),
         | 
| 616 | 
             
                            delimiter=",",
         | 
| 617 | 
             
                        )
         | 
| 618 | 
            +
             | 
| 619 | 
             
                def_datasets += f"""
         | 
| 620 | 
            +
            X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', {julia_dtype}, '\\n')))"""
         | 
| 621 |  | 
| 622 | 
             
                if multioutput:
         | 
| 623 | 
             
                    def_datasets += f"""
         | 
| 624 | 
            +
            y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', {julia_dtype}, '\\n')))"""
         | 
| 625 | 
             
                else:
         | 
| 626 | 
             
                    def_datasets += f"""
         | 
| 627 | 
            +
            y = readdlm("{_escape_filename(y_filename)}", ',', {julia_dtype}, '\\n')[:, 1]"""
         | 
| 628 |  | 
| 629 | 
             
                if weights is not None:
         | 
| 630 | 
             
                    if multioutput:
         | 
| 631 | 
             
                        def_datasets += f"""
         | 
| 632 | 
            +
            weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', {julia_dtype}, '\\n')))"""
         | 
| 633 | 
             
                    else:
         | 
| 634 | 
             
                        def_datasets += f"""
         | 
| 635 | 
            +
            weights = readdlm("{_escape_filename(weights_filename)}", ',', {julia_dtype}, '\\n')[:, 1]"""
         | 
| 636 | 
             
                return def_datasets
         | 
| 637 |  | 
| 638 |  |