Pandasで詰まったところ
初めに
今月はオライリーから出ている”scikit-learn、Keras、TensorFlowによる実践機械学習”を使用して勉強していました。
演習問題であるタイタニックの生存予想をする前にPandasで戸惑った場面があったのでメモ程度に残しておきます。
詰まったところ
csvファイルから読み込んだデータのクリーニング最中に以下の関数を作成していました。
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.compose import ColumnTransformer
import pandas as pd
def make_datasets(data,is_test_data=False):
col_keys = {"Pclass", "Age","Sex","Fare"}
X = data.drop((set(data.columns.values) - col_keys), axis=1)
if (not is_test_data):
y = data["Survived"].copy()
pipe_line = ColumnTransformer([("num_pc", SimpleImputer(strategy="mean"),
["Pclass","Age","Fare"]),
("cat", OrdinalEncoder(), ["Sex"])])
X = pipe_line.fit_transform(X)
X = pd.DataFrame(X, columns=col_keys) # <==該当箇所
return (X, y) if (not is_test_data) else X
>>>
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 4 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Pclass 891 non-null float64
1 Fare 891 non-null float64
2 Sex 891 non-null float64
3 Age 891 non-null float64
dtypes: float64(4)
memory usage: 28.0 KB
上のコードの該当箇所のcolumnsでcolumnを指定しているのにも関わらず、何故かPclass,Fare,Sex,Ageの順でデータフレームが作成されています。
pipe_lineの順番が悪いのかと思い、col_keysの順番と同じにしてみても変化なし、col_keysをリストにしてみてもダメ、おまけにデバックするたびにcolumnの順番がランダムに変更されてしまう。
解決方法
pandasのdocumentを読んでみると、”columns : Index or array-like”って書いてあったのでとりあえず以下のように変更
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.compose import ColumnTransformer
def make_datasets(data,is_test_data=False):
col_keys = {"Pclass", "Age","Sex","Fare"}
X = data.drop((set(data.columns.values) - col_keys), axis=1)
if (not is_test_data):
y = data["Survived"].copy()
pipe_line = ColumnTransformer([("num_pc", SimpleImputer(strategy="mean"),
["Pclass","Age","Fare"]),
("cat", OrdinalEncoder(), ["Sex"])])
X = pipe_line.fit_transform(X)
X = pd.DataFrame(X, columns=["Pclass", "Age","Fare", "Sex"])
return (X, y) if (not is_test_data) else X
>>>
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 4 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Pclass 891 non-null float64
1 Age 891 non-null float64
2 Fare 891 non-null float64
3 Sex 891 non-null float64
dtypes: float64(4)
memory usage: 28.0 KB
とりあえずcol_keysと同じ並びにはなったけど、col_keysをリストにしてcolumns=col_keysにしてもダメだったのはなんでだろう?
コメント入力