1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# split_csv.py
import argparse
import csv
import gzip
import math
import os
import sys
def open_input(path):
    """Open *path* for text reading, decompressing on the fly for ".gz" files.

    Encoding "utf-8-sig" strips a leading UTF-8 BOM when present, and
    newline="" leaves newline handling to the csv module, per its docs.
    """
    opener = gzip.open if path.lower().endswith(".gz") else open
    return opener(path, "rt", newline="", encoding="utf-8-sig")
def open_output(path):
    """Open *path* for text writing with a UTF-8 BOM ("utf-8-sig").

    newline="" defers newline translation to csv.writer, matching open_input.
    """
    return open(path, mode="w", encoding="utf-8-sig", newline="")
def count_data_rows(input_path):
    """Return the number of data rows in *input_path*, excluding the header.

    A completely empty file (not even a header line) counts as 0.
    """
    with open_input(input_path) as fh:
        rows = csv.reader(fh)
        if next(rows, None) is None:
            return 0
        return sum(1 for _ in rows)
def split_by_rows(input_path, out_dir, base_name, rows_per_file, force):
    """Split the CSV at *input_path* into parts of at most *rows_per_file*
    data rows each, repeating the header line in every part.

    Parts are written to *out_dir* (created if missing) and named
    "{base_name}.partNNNN.csv" with a 1-based, zero-padded index.
    Existing output files are only overwritten when *force* is true.

    Returns (total_data_rows_written, list_of_output_paths).

    Raises ValueError for a non-positive *rows_per_file* or a header-less
    input, and FileExistsError when an output exists and *force* is unset.
    """
    if rows_per_file <= 0:
        raise ValueError("--rows-per-file 必须是正整数")
    os.makedirs(out_dir, exist_ok=True)
    with open_input(input_path) as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            raise ValueError("输入 CSV 为空,无法切割")
        part_index = 0
        current_out = None
        writer = None
        row_in_part = 0
        total_rows = 0

        def start_new_part():
            # Close the previous part (if any) and open the next numbered one,
            # writing the shared header as its first row.
            nonlocal part_index, current_out, writer, row_in_part
            if current_out is not None:
                current_out.close()
                current_out = None
            part_index += 1
            out_path = os.path.join(out_dir, f"{base_name}.part{part_index:04d}.csv")
            if os.path.exists(out_path) and not force:
                raise FileExistsError(f"输出文件已存在(如需覆盖请加 --force):{out_path}")
            current_out = open_output(out_path)
            writer = csv.writer(current_out)
            writer.writerow(header)
            row_in_part = 0
            return out_path

        # try/finally guarantees the open part handle is closed even when an
        # exception escapes the loop (FileExistsError from start_new_part, a
        # disk-full write error, ...). The original leaked the handle here.
        try:
            out_paths = [start_new_part()]
            for row in reader:
                if row_in_part >= rows_per_file:
                    out_paths.append(start_new_part())
                writer.writerow(row)
                row_in_part += 1
                total_rows += 1
        finally:
            if current_out is not None:
                current_out.close()
    return total_rows, out_paths
def split_by_parts(input_path, out_dir, base_name, parts, force):
    """Split the input into roughly *parts* equally sized files.

    Derives a per-file row budget (ceil of total/parts) and delegates the
    actual writing to split_by_rows.

    Returns (total_data_rows_written, output_paths, rows_per_file).

    Raises ValueError when *parts* is not positive or the input has no
    data rows.
    """
    if parts <= 0:
        raise ValueError("--parts 必须是正整数")
    available = count_data_rows(input_path)
    if available == 0:
        raise ValueError("输入 CSV 没有数据行,无法切割")
    rows_per_file = math.ceil(available / parts)
    written, out_paths = split_by_rows(
        input_path=input_path,
        out_dir=out_dir,
        base_name=base_name,
        rows_per_file=rows_per_file,
        force=force,
    )
    return written, out_paths, rows_per_file
def main():
    """CLI entry point: parse arguments and run one of the two split modes."""
    # Allow very large CSV fields (e.g. embedded JSON). sys.maxsize can
    # overflow the underlying C limit on some platforms, hence the 1 GiB
    # fallback.
    try:
        csv.field_size_limit(sys.maxsize)
    except (OverflowError, ValueError):
        csv.field_size_limit(1024 * 1024 * 1024)

    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--out-dir", required=True)
    parser.add_argument("--base-name", default=None)
    parser.add_argument("--rows-per-file", type=int, default=None)
    parser.add_argument("--parts", type=int, default=None)
    parser.add_argument("--max-files", type=int, default=2000)
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    # Exactly one of the two split modes must be selected.
    by_rows = args.rows_per_file is not None
    by_parts = args.parts is not None
    if not by_rows and not by_parts:
        raise ValueError("必须指定 --rows-per-file 或 --parts")
    if by_rows and by_parts:
        raise ValueError("--rows-per-file 与 --parts 只能二选一")

    # Default base name: input file name minus its extension; a remaining
    # ".csv" suffix (as produced by "name.csv.gz") is trimmed as well.
    base_name = args.base_name
    if not base_name:
        base_name = os.path.splitext(os.path.basename(args.input))[0]
        if base_name.lower().endswith(".csv"):
            base_name = base_name[:-4]

    def report(rows, paths, per_file=None):
        # Shared success summary for both modes.
        print(f"完成:共写入 {rows} 行数据(不含表头)")
        print(f"输出文件数:{len(paths)}")
        if per_file is not None:
            print(f"每个文件最大数据行:{per_file}")
        print(f"输出目录:{args.out_dir}")

    if by_rows:
        # Guard against accidentally producing an enormous number of files.
        if args.max_files is not None and args.max_files > 0:
            if args.rows_per_file > 0:
                estimated_files = math.ceil(count_data_rows(args.input) / args.rows_per_file)
            else:
                estimated_files = 0
            if estimated_files > args.max_files:
                raise ValueError(
                    f"预计会生成 {estimated_files} 个文件,超过 --max-files={args.max_files}。"
                    f"请增大 --rows-per-file,或提高 --max-files,或使用 --parts。"
                )
        total_rows, out_paths = split_by_rows(
            input_path=args.input,
            out_dir=args.out_dir,
            base_name=base_name,
            rows_per_file=args.rows_per_file,
            force=args.force,
        )
        report(total_rows, out_paths)
        return

    total_rows, out_paths, rows_per_file = split_by_parts(
        input_path=args.input,
        out_dir=args.out_dir,
        base_name=base_name,
        parts=args.parts,
        force=args.force,
    )
    report(total_rows, out_paths, per_file=rows_per_file)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()