Viewing file: csvs.py (10.28 KB) -rw-r--r-- Select action/file-type: (+) | (+) | (+) | Code (+) | Session (+) | (+) | SDB (+) | (+) | (+) | (+) | (+) | (+) |
""" Module for formatting output data into CSV files. """
from __future__ import annotations
from collections.abc import ( Hashable, Iterable, Iterator, Sequence, ) import csv as csvlib import os from typing import ( TYPE_CHECKING, Any, cast, )
import numpy as np
from pandas._libs import writers as libwriters from pandas._typing import SequenceNotStr from pandas.util._decorators import cache_readonly
from pandas.core.dtypes.generic import ( ABCDatetimeIndex, ABCIndex, ABCMultiIndex, ABCPeriodIndex, ) from pandas.core.dtypes.missing import notna
from pandas.core.indexes.api import Index
from pandas.io.common import get_handle
if TYPE_CHECKING: from pandas._typing import ( CompressionOptions, FilePath, FloatFormatType, IndexLabel, StorageOptions, WriteBuffer, npt, )
from pandas.io.formats.format import DataFrameFormatter
# Approximate number of cells (rows x columns) written per chunk when no
# explicit chunksize is supplied; CSVFormatter._initialize_chunksize divides
# it by the column count to obtain the number of rows per chunk.
_DEFAULT_CHUNKSIZE_CELLS = 10**5
class CSVFormatter: cols: npt.NDArray[np.object_]
def __init__( self, formatter: DataFrameFormatter, path_or_buf: FilePath | WriteBuffer[str] | WriteBuffer[bytes] = "", sep: str = ",", cols: Sequence[Hashable] | None = None, index_label: IndexLabel | None = None, mode: str = "w", encoding: str | None = None, errors: str = "strict", compression: CompressionOptions = "infer", quoting: int | None = None, lineterminator: str | None = "\n", chunksize: int | None = None, quotechar: str | None = '"', date_format: str | None = None, doublequote: bool = True, escapechar: str | None = None, storage_options: StorageOptions | None = None, ) -> None: self.fmt = formatter
self.obj = self.fmt.frame
self.filepath_or_buffer = path_or_buf self.encoding = encoding self.compression: CompressionOptions = compression self.mode = mode self.storage_options = storage_options
self.sep = sep self.index_label = self._initialize_index_label(index_label) self.errors = errors self.quoting = quoting or csvlib.QUOTE_MINIMAL self.quotechar = self._initialize_quotechar(quotechar) self.doublequote = doublequote self.escapechar = escapechar self.lineterminator = lineterminator or os.linesep self.date_format = date_format self.cols = self._initialize_columns(cols) self.chunksize = self._initialize_chunksize(chunksize)
@property def na_rep(self) -> str: return self.fmt.na_rep
@property def float_format(self) -> FloatFormatType | None: return self.fmt.float_format
@property def decimal(self) -> str: return self.fmt.decimal
@property def header(self) -> bool | SequenceNotStr[str]: return self.fmt.header
@property def index(self) -> bool: return self.fmt.index
def _initialize_index_label(self, index_label: IndexLabel | None) -> IndexLabel: if index_label is not False: if index_label is None: return self._get_index_label_from_obj() elif not isinstance(index_label, (list, tuple, np.ndarray, ABCIndex)): # given a string for a DF with Index return [index_label] return index_label
def _get_index_label_from_obj(self) -> Sequence[Hashable]: if isinstance(self.obj.index, ABCMultiIndex): return self._get_index_label_multiindex() else: return self._get_index_label_flat()
def _get_index_label_multiindex(self) -> Sequence[Hashable]: return [name or "" for name in self.obj.index.names]
def _get_index_label_flat(self) -> Sequence[Hashable]: index_label = self.obj.index.name return [""] if index_label is None else [index_label]
def _initialize_quotechar(self, quotechar: str | None) -> str | None: if self.quoting != csvlib.QUOTE_NONE: # prevents crash in _csv return quotechar return None
@property def has_mi_columns(self) -> bool: return bool(isinstance(self.obj.columns, ABCMultiIndex))
def _initialize_columns( self, cols: Iterable[Hashable] | None ) -> npt.NDArray[np.object_]: # validate mi options if self.has_mi_columns: if cols is not None: msg = "cannot specify cols with a MultiIndex on the columns" raise TypeError(msg)
if cols is not None: if isinstance(cols, ABCIndex): cols = cols._get_values_for_csv(**self._number_format) else: cols = list(cols) self.obj = self.obj.loc[:, cols]
# update columns to include possible multiplicity of dupes # and make sure cols is just a list of labels new_cols = self.obj.columns return new_cols._get_values_for_csv(**self._number_format)
def _initialize_chunksize(self, chunksize: int | None) -> int: if chunksize is None: return (_DEFAULT_CHUNKSIZE_CELLS // (len(self.cols) or 1)) or 1 return int(chunksize)
@property def _number_format(self) -> dict[str, Any]: """Dictionary used for storing number formatting settings.""" return { "na_rep": self.na_rep, "float_format": self.float_format, "date_format": self.date_format, "quoting": self.quoting, "decimal": self.decimal, }
@cache_readonly def data_index(self) -> Index: data_index = self.obj.index if ( isinstance(data_index, (ABCDatetimeIndex, ABCPeriodIndex)) and self.date_format is not None ): data_index = Index( [x.strftime(self.date_format) if notna(x) else "" for x in data_index] ) elif isinstance(data_index, ABCMultiIndex): data_index = data_index.remove_unused_levels() return data_index
@property def nlevels(self) -> int: if self.index: return getattr(self.data_index, "nlevels", 1) else: return 0
@property def _has_aliases(self) -> bool: return isinstance(self.header, (tuple, list, np.ndarray, ABCIndex))
@property def _need_to_save_header(self) -> bool: return bool(self._has_aliases or self.header)
@property def write_cols(self) -> SequenceNotStr[Hashable]: if self._has_aliases: assert not isinstance(self.header, bool) if len(self.header) != len(self.cols): raise ValueError( f"Writing {len(self.cols)} cols but got {len(self.header)} aliases" ) return self.header else: # self.cols is an ndarray derived from Index._get_values_for_csv, # so its entries are strings, i.e. hashable return cast(SequenceNotStr[Hashable], self.cols)
@property def encoded_labels(self) -> list[Hashable]: encoded_labels: list[Hashable] = []
if self.index and self.index_label: assert isinstance(self.index_label, Sequence) encoded_labels = list(self.index_label)
if not self.has_mi_columns or self._has_aliases: encoded_labels += list(self.write_cols)
return encoded_labels
def save(self) -> None: """ Create the writer & save. """ # apply compression and byte/text conversion with get_handle( self.filepath_or_buffer, self.mode, encoding=self.encoding, errors=self.errors, compression=self.compression, storage_options=self.storage_options, ) as handles: # Note: self.encoding is irrelevant here self.writer = csvlib.writer( handles.handle, lineterminator=self.lineterminator, delimiter=self.sep, quoting=self.quoting, doublequote=self.doublequote, escapechar=self.escapechar, quotechar=self.quotechar, )
self._save()
def _save(self) -> None: if self._need_to_save_header: self._save_header() self._save_body()
def _save_header(self) -> None: if not self.has_mi_columns or self._has_aliases: self.writer.writerow(self.encoded_labels) else: for row in self._generate_multiindex_header_rows(): self.writer.writerow(row)
def _generate_multiindex_header_rows(self) -> Iterator[list[Hashable]]: columns = self.obj.columns for i in range(columns.nlevels): # we need at least 1 index column to write our col names col_line = [] if self.index: # name is the first column col_line.append(columns.names[i])
if isinstance(self.index_label, list) and len(self.index_label) > 1: col_line.extend([""] * (len(self.index_label) - 1))
col_line.extend(columns._get_level_values(i)) yield col_line
# Write out the index line if it's not empty. # Otherwise, we will print out an extraneous # blank line between the mi and the data rows. if self.encoded_labels and set(self.encoded_labels) != {""}: yield self.encoded_labels + [""] * len(columns)
def _save_body(self) -> None: nrows = len(self.data_index) chunks = (nrows // self.chunksize) + 1 for i in range(chunks): start_i = i * self.chunksize end_i = min(start_i + self.chunksize, nrows) if start_i >= end_i: break self._save_chunk(start_i, end_i)
def _save_chunk(self, start_i: int, end_i: int) -> None: # create the data for a chunk slicer = slice(start_i, end_i) df = self.obj.iloc[slicer]
res = df._get_values_for_csv(**self._number_format) data = list(res._iter_column_arrays())
ix = self.data_index[slicer]._get_values_for_csv(**self._number_format) libwriters.write_csv_rows( data, ix, self.nlevels, self.cols, self.writer, )
|