Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 15 additions & 59 deletions pygam/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Core Classes"""

import copy
import numpy as np

from pygam.utils import flatten, round_to_n_decimal_places
Expand Down Expand Up @@ -90,24 +91,6 @@ class name


class Core:
"""
Creates an instance of the Core class.

comes loaded with useful methods

Parameters
----------
name : str, default: None
line_width : int, default: 70
number of characters to print on a line
line_offset : int, default: 3
number of characters to indent after the first line

Returns
-------
self
"""

def __init__(self, name=None, line_width=70, line_offset=3):
self._name = name
self._line_width = line_width
Expand All @@ -120,13 +103,11 @@ def __init__(self, name=None, line_width=70, line_offset=3):
self._include = []

def __str__(self):
"""__str__ method."""
if self._name is None:
return self.__repr__()
return self._name

def __repr__(self):
"""__repr__ method."""
name = self.__class__.__name__
return nice_repr(
name,
Expand All @@ -137,50 +118,25 @@ def __repr__(self):
args=None,
)

# ✅ FIXED: NOW INSIDE CLASS
def get_params(self, deep=False):
"""
Returns a dict of all of the object's user-facing parameters.

Parameters
----------
deep : boolean, default: False
when True, also gets non-user-facing parameters

Returns
-------
dict
"""
attrs = self.__dict__
attrs = self.__dict__.copy()

for attr in self._include:
attrs[attr] = getattr(self, attr)

if deep is True:
return attrs
return dict(
[
(k, v)
for k, v in list(attrs.items())
if (k[0] != "_") and (k[-1] != "_") and (k not in self._exclude)
]
)
if deep:
return copy.deepcopy(attrs)

params = {
k: v
for k, v in attrs.items()
if (k[0] != "_") and (k[-1] != "_") and (k not in self._exclude)
}

return copy.deepcopy(params)

def set_params(self, deep=False, force=False, **parameters):
"""
Sets an object's parameters.

Parameters
----------
deep : boolean, default: False
when True, also sets non-user-facing parameters
force : boolean, default: False
when True, also sets parameters that the object does not already
have
**parameters : parameters to set

Returns
-------
self
"""
param_names = self.get_params(deep=deep).keys()
for parameter, value in parameters.items():
if (
Expand All @@ -189,4 +145,4 @@ def set_params(self, deep=False, force=False, **parameters):
or (hasattr(self, parameter) and parameter == parameter.strip("_"))
):
setattr(self, parameter, value)
return self
return self
13 changes: 13 additions & 0 deletions pygam/tests/test_get_params_mutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np
from pygam import LinearGAM, s

def test_get_params_does_not_mutate_model():
X = np.random.rand(50, 1)
y = np.random.rand(50)

gam = LinearGAM(s(0)).fit(X, y)

params = gam.get_params()
params["terms"][0].lam = 999

assert gam.terms[0].lam != 999
Loading