ClickHouse: Principles and Hands-On Code for the Component That Pushes Data to Local Tables

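1. Component workflow

1.1 Cluster checks

Metadata completion

The component reads the Hive table's schema and automatically maps each column type to the corresponding ClickHouse (CK) type. It supports custom partitions and column-level functions that produce a new DataFrame, and it creates the CK local table and the distributed table with ON CLUSTER statements.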
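As a rough illustration of the type mapping and ON CLUSTER table creation described here, the sketch below builds the two DDL strings in plain Scala. The object name `CkDdlSketch`, the type table, and the choice of a plain MergeTree engine are assumptions made for the example; the component's real mapping lives in classes such as `CkSqlUtil`, whose source is not shown in this article, and a replicated cluster would normally use ReplicatedMergeTree.

```scala
object CkDdlSketch {
  // Hypothetical Hive-to-ClickHouse type mapping; unmapped types fall back to String.
  def toCkType(hiveType: String): String = hiveType.trim.toLowerCase match {
    case "tinyint"   => "Int8"
    case "smallint"  => "Int16"
    case "int"       => "Int32"
    case "bigint"    => "Int64"
    case "float"     => "Float32"
    case "double"    => "Float64"
    case "date"      => "Date"
    case "timestamp" => "DateTime"
    case _           => "String"
  }

  // DDL for the local table. MergeTree is an assumption made for this sketch.
  def localTableDdl(db: String, table: String, cluster: String,
                    hiveColumns: Seq[(String, String)],
                    partitionBy: String, orderBy: Seq[String]): String = {
    val cols = hiveColumns.map { case (name, tpe) => s"`$name` ${toCkType(tpe)}" }.mkString(", ")
    s"CREATE TABLE IF NOT EXISTS $db.$table ON CLUSTER $cluster ($cols) " +
      s"ENGINE = MergeTree PARTITION BY $partitionBy ORDER BY (${orderBy.mkString(", ")})"
  }

  // DDL for the distributed table that fans out over the local tables.
  def distributedTableDdl(db: String, localTable: String, distTable: String, cluster: String): String =
    s"CREATE TABLE IF NOT EXISTS $db.$distTable ON CLUSTER $cluster AS $db.$localTable " +
      s"ENGINE = Distributed($cluster, $db, $localTable, rand())"
}
```

Generating the statements as strings keeps the sketch independent of any JDBC driver; in the real job they would be executed through the ClickHouse client.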

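Cluster load

The component calls the cluster's Prometheus monitoring API, or queries the relevant system tables directly, to find out whether any replica exceeds 80% CPU, 70% memory, or 60% of peak load. If such replicas exist, it lowers concurrency step by step according to the resource level, down to pausing the push entirely, so that the CK cluster stays stable while data is being pushed.

The system.query_log table records executed SQL statements together with their memory_usage (system.processes lists the queries that are running right now), and system.asynchronous_metrics exposes CPU load metrics.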
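The throttling decision can be sketched as a pure function over per-replica load readings. This is only a sketch: `ReplicaLoad`, `nextConcurrency`, and the halve-or-pause policy are hypothetical, and the thresholds (80% CPU, 70% memory, 60% of peak load) are read here as upper limits that mark a replica as busy. How the readings are obtained (Prometheus or system tables) is left out.

```scala
object LoadThrottle {
  // One load reading per replica; all values are percentages.
  final case class ReplicaLoad(host: String, cpuPct: Double, memPct: Double, loadVsPeakPct: Double)

  // A replica counts as busy once it exceeds any of the three thresholds.
  def overloaded(r: ReplicaLoad): Boolean =
    r.cpuPct > 80.0 || r.memPct > 70.0 || r.loadVsPeakPct > 60.0

  // Hypothetical policy: keep concurrency when nothing is busy, halve it when
  // some replicas are busy, and pause (0) when at least half of them are.
  def nextConcurrency(current: Int, replicas: Seq[ReplicaLoad]): Int = {
    val hot = replicas.count(overloaded)
    if (hot == 0) current
    else if (hot * 2 >= replicas.size) 0
    else math.max(1, current / 2)
  }
}
```

A driver loop would call `nextConcurrency` between batches and sleep while the result is 0.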

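Pre-write deletion check

Before writing, the partitions to be pushed must be emptied. After the delete is issued, run a count against the distributed table to confirm that the data is really gone.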
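A minimal sketch of this delete-then-verify step follows. The names `PartitionCleanup`, `dropPartitionSql`, `countSql`, and `waitUntilEmpty` are hypothetical, and the count is injected as a function so the polling logic can be shown without a live cluster; the real component does this inside `CommonUtils.deleteDataByPartition`, whose body is not part of this article.

```scala
object PartitionCleanup {
  // Statement that drops one partition of the local table on every node.
  def dropPartitionSql(db: String, localTable: String, cluster: String, partition: String): String =
    s"ALTER TABLE $db.$localTable ON CLUSTER $cluster DROP PARTITION '$partition'"

  // Statement that counts the remaining rows through the distributed table.
  def countSql(db: String, distTable: String, partitionColumn: String, partition: String): String =
    s"SELECT count() FROM $db.$distTable WHERE $partitionColumn = '$partition'"

  // Polls `count` until it returns 0 or the attempts run out.
  @annotation.tailrec
  def waitUntilEmpty(count: () => Long, attemptsLeft: Int, sleepMs: Long): Boolean =
    if (attemptsLeft <= 0) false
    else if (count() == 0L) true
    else {
      if (attemptsLeft > 1) Thread.sleep(sleepMs)
      waitUntilEmpty(count, attemptsLeft - 1, sleepMs)
    }
}
```

Polling is used because an ON CLUSTER statement may take a moment to finish on every node.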

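Metadata update

To update the CK local table, the component executes the matching DDL statements ON CLUSTER for columns the user has added, retyped, or dropped. The distributed table is then dropped automatically and recreated on top of the updated local table.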
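One way to derive those DDL statements is to diff the old and new column sets. The sketch below is an assumption about how such a diff could look, not the component's actual code; `SchemaDiff` and `alterStatements` are hypothetical names, and both schemas are given as maps from column name to ClickHouse type.

```scala
object SchemaDiff {
  // Returns ADD, MODIFY, and DROP COLUMN statements that turn oldCols into newCols.
  def alterStatements(db: String, table: String, cluster: String,
                      oldCols: Map[String, String],
                      newCols: Map[String, String]): Seq[String] = {
    val prefix = s"ALTER TABLE $db.$table ON CLUSTER $cluster"
    // Columns that exist only in the new schema.
    val added = (newCols.keySet -- oldCols.keySet).toSeq.sorted
      .map(c => s"$prefix ADD COLUMN IF NOT EXISTS `$c` ${newCols(c)}")
    // Columns present in both schemas whose type changed.
    val modified = (newCols.keySet & oldCols.keySet).toSeq.sorted
      .filter(c => newCols(c) != oldCols(c))
      .map(c => s"$prefix MODIFY COLUMN `$c` ${newCols(c)}")
    // Columns that exist only in the old schema.
    val dropped = (oldCols.keySet -- newCols.keySet).toSeq.sorted
      .map(c => s"$prefix DROP COLUMN IF EXISTS `$c`")
    added ++ modified ++ dropped
  }
}
```

Sorting the column names makes the generated statement order deterministic, which helps when the DDL is logged or reviewed.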

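1.2 Data transfer

Data splitting

Pick a high-cardinality column as the shard key; skuId and user pin are common choices. In the Spark RDD, repartition the data by the shard key so that each partition writes to the local table on a single node. This keeps data balanced across nodes, reduces network overhead, and speeds up writes.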
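The routing idea can be shown without Spark: hash the shard key and take it modulo the shard count, so that rows with the same key always land on the same shard. `ShardRouting` below is a hypothetical helper written for illustration; in the actual job the equivalent step is `df.repartition(numPartitions, new Column(shardColumn))`.

```scala
object ShardRouting {
  // Non-negative shard index for a key; floorMod avoids negative results for negative hashes.
  def shardIndex(shardKey: Any, shardCount: Int): Int = {
    require(shardCount > 0, "shardCount must be positive")
    java.lang.Math.floorMod(shardKey.hashCode, shardCount)
  }

  // Groups rows by the shard they should be written to.
  def groupByShard[A](rows: Seq[A], key: A => Any, shardCount: Int): Map[Int, Seq[A]] =
    rows.groupBy(r => shardIndex(key(r), shardCount))
}
```

A key with many distinct values spreads rows evenly; a low-cardinality key such as a single date would send everything to one shard.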

Concurrent writes

When necessary, split the data further and write to multiple replicas at the same time.
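A simplified picture of how write tasks could be spread over shards and their replicas is sketched below. It assumes the number of Spark partitions is the shard count multiplied by a per-shard task count, in the spirit of `numPartitions = shardTotalNum * currencyTaskNum` in the code later in this article; `ReplicaAssignment` and `target` are hypothetical names, and the exact formula in the real job differs.

```scala
object ReplicaAssignment {
  // shardReplicas(i) lists the reachable replica hosts of shard i (0-based).
  def target(partitionId: Int, tasksPerShard: Int,
             shardReplicas: IndexedSeq[IndexedSeq[String]]): String = {
    // Consecutive blocks of tasksPerShard partitions go to the same shard.
    val shard = (partitionId / tasksPerShard) % shardReplicas.size
    val replicas = shardReplicas(shard)
    // Partitions of one shard rotate over that shard's replicas.
    replicas(partitionId % replicas.size)
  }
}
```

With two shards, two replicas each, and two tasks per shard, the four partitions end up on four different hosts.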

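Exception handling

Data that is valid in Hive may still fail to load into CK, for example because of data-type mismatches. The component catches such rows, prints them, and counts them; upstream can then either fix the data or discard the rows that CK cannot accept. The counter value feeds the final validation step, which checks that source row count = rows in CK + rejected rows.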
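The catch-print-count pattern can be sketched in plain Scala as follows. `BadRowCounter`, `LoadStats`, and `load` are hypothetical names; in a Spark job the two counters would more likely be accumulators (for example `SparkContext.longAccumulator`) so that executors can report back to the driver.

```scala
import scala.util.{Failure, Success, Try}

object BadRowCounter {
  final case class LoadStats(inserted: Long, rejected: Long)

  // Converts each row, writes the ones that convert cleanly, and counts the rest.
  def load[A, B](rows: Seq[A])(convert: A => B)(write: B => Unit): LoadStats = {
    var inserted = 0L
    var rejected = 0L
    rows.foreach { row =>
      Try(convert(row)) match {
        case Success(value) =>
          write(value)
          inserted += 1
        case Failure(e) =>
          // Print the offending row so upstream can fix or discard it.
          println(s"rejected row $row: ${e.getMessage}")
          rejected += 1
      }
    }
    LoadStats(inserted, rejected)
  }
}
```

After the load, `inserted + rejected` should equal the source row count, which is exactly the invariant the validation step relies on.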

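1.3 Data validation

Row-count check

Rows pushed into CK = source rows - rejected rows

Metric sanity check

Choose the metrics to validate. Outside the Double 11 and 618 promotions, metrics such as UV and PV normally do not grow beyond the promotion peak. A value above that threshold triggers an alert, and upstream should check whether the data is correct.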
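The metric check reduces to a threshold comparison. The helper below is a hypothetical sketch: `MetricSanity`, `exceedsPeak`, and the optional `tolerance` parameter are assumptions, and where the promotion peak comes from (a config value or a historical query) is not covered.

```scala
object MetricSanity {
  // True when a metric such as UV or PV is above the promotion peak, allowing
  // an optional relative tolerance (0.1 means 10% above the peak is still fine).
  def exceedsPeak(value: Double, promoPeak: Double, tolerance: Double = 0.0): Boolean =
    value > promoPeak * (1.0 + tolerance)
}
```

A caller would raise an alert whenever `exceedsPeak` returns true on a day that is not part of a promotion.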

2. Code walkthrough

Main class

package com.jd.clickhouse

import com.jd.clickhouse.client.NormalClickHouseClient
import com.jd.clickhouse.common.{Configuration, MyDateClass, Utils}
import com.jd.clickhouse.ext.ClickHouseProperties
import com.jd.clickhouse.spark.{CkSqlUtil, CkXmlUtil, CommonUtils, UsersFunctions}
import org.apache.spark.SparkContext
import org.apache.spark.sql._

import scala.collection.mutable.ListBuffer
import scala.util.control.Breaks._


object Spark2ClickHouse {

def main(args: Array[String]): Unit = {

val sparkSession = SparkSession.builder.appName("hdfs 2 clickhouse through spark").enableHiveSupport.getOrCreate
val props = ClickHouseProperties(Configuration.props)
// Register UDFs
UsersFunctions().registerBsFunction(sparkSession)
// main() arguments: source Hive database and table, target ClickHouse database, and target ClickHouse table (clickhouse_tableName)
val hive_database = args(0)//app
val hive_table = args(1)//app_zh_industry_catg_brand_shop
val clickhouse_database = args(2)//bc_online
val clickhouse_tableName = args(3)//ck_industry_catg_brand_shop
val shardColumn = args(4).toString//dateTime
// Start date: use args(5) unless it is the placeholder "--", in which case default to yesterday
var sdate = if (args(5) != "--") args(5) else Utils.yesterday
// End date: use args(6) unless it is the placeholder "--", in which case default to yesterday
var edate = if (args(6) != "--") args(6) else Utils.yesterday
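// Default to "false" (no XML config); the previous default Nil stringified to "List()", which dataDeleteAndSave would try to parse as XML
val xml = if (args.length >= 8) args(7) else "false"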
val clusterName: String = if (args.length >= 9) args(8) else null//LF0_CK_Pub_18
val clientHosts: String = if (args.length >= 10) args(9) else null//ckpub18.olap.jd.com
val clusterUser: String = if (args.length >= 11) args(10) else props.getUser//ads
val clusterPass: String = if (args.length >= 12) args(11) else props.getPassword//1qaz^RFV
val executorNum = if (args.length >= 13) args(12).toInt else 10//30
val deleteFlag = if (args.length >= 14) args(13).toInt else 1//1
val shardGroupNum: Int = if (args.length >= 15) args(14).toInt else 1//1
val hiveSql = if (args.length >= 16) args(15).toString else null
// "
// select
// dt as dateTime,
// item_first_cate_cd as firstIndId,
// item_second_cate_cd as secondIndId,
// item_third_cate_cd as thirdIndId,
// trade_type as shopType,
// shop_id as shopId,
// brand_code as sonBrandId
// from app.app_zh_industry_catg_brand_shop
// where dt>='2022-11-03' and dt<='2022-11-03'
// "
val primaryKey = if (args.length >= 17) args(16).toString else null //dateTime,secondIndId,shopType,thirdIndId,firstIndId
val partitionColumns = if (args.length >= 18) args(17).toString else null //dateTime
val clusterModeCurrency: Int = if (args.length >= 19) args(18).toInt else 0 //1
val clusterInfo = if (args.length >= 20) props.info.setProperty("hostTable", args(19).toString) else props.info.setProperty("hostTable", "clusters_host")
val port: Int = if (args.length >= 21) args(20).split(",")(0).toInt else 8123
val storagePolicy = if (args.length >= 22) args(21) else "jdob_ha"
val nullCheck = if (args.length >= 23) args(22) else "on"
val failedTryAgain: Int = if (args.length >= 24) args(23).toInt else 3
val linkTableHash = if (args.length >= 25) args(24) else "on"
var portMap = Map[String, Int]();
if (args.length >= 21) {
val portInfo = args(20).split(",")
for (i <- 1 to portInfo.length - 1) {
portMap += (portInfo(i).split("-")(0).toString -> portInfo(i).split("-")(1).toInt)
}
}
val versionFlag = if (args.length >= 26) args(25).toInt else 0
var update_partitions = if (args.length >= 27) args(26) else ""

val ckTmpTable = clickhouse_tableName + "_tmp"
// Collect every partition that has to be refreshed
val dropParamList = new ListBuffer[String]
val MD = new MyDateClass(Utils.yesterday)
var start_date = sdate

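// Check whether the table uses a two-column (compound) partition key
if (update_partitions != "" && partitionColumns.split(",").length>1) {
// Compound partitions passed in explicitly as tuples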
if(update_partitions.contains("(")){
update_partitions = update_partitions.replaceAll("\\),\\(", ")/u0019(")
val params = update_partitions.split("/u0019")
for (param <- params) {
dropParamList.append(param)
}
sdate = update_partitions
edate = update_partitions
}else{// Build one compound partition per day of the date range
var partition:StringBuilder = new StringBuilder();
while (start_date <= edate) {
partition.append("('"+start_date+s"','$update_partitions')"+"/u0019")
dropParamList.append("('"+start_date+s"','$update_partitions')")
start_date = MD.addNDays(start_date, 1)
}
sdate = partition.toString()
edate = partition.toString()
}
}else if(sdate != edate && sdate.length > 7) {
while (start_date <= edate) {
dropParamList.append(start_date)
start_date = MD.addNDays(start_date, 1)
}
} else {
dropParamList.append(sdate)
}


println(s"Hive source database: $hive_database")
println(s"Hive source table: $hive_table")
println(s"ClickHouse database: $clickhouse_database")
println(s"ClickHouse table: $clickhouse_tableName")
println(s"shardColumn:$shardColumn")
println(s"sdate:$sdate")
println(s"edate:$edate")
println(s"clusterName:$clusterName")
println(s"clientHosts:$clientHosts")
println(s"executorNum:$executorNum")
println(s"deleteFlag:$deleteFlag")
println(s"shardGroupNum (shards per group): $shardGroupNum")
println(s"hivesql:$hiveSql")
println(s"primaryKey:$primaryKey")
println(s"partitionColumns:$partitionColumns")
println(s"clusterInfo:" + props.info.getProperty("hostTable"))
println(s"clusterModeCurrency:$clusterModeCurrency")
println(s"port:$port")
println(s"nullCheck:$nullCheck")
println(s"failedTryAgain:$failedTryAgain")
println(s"linkTableHash:$linkTableHash")
println("Current timeList: " + dropParamList)
println("Current versionFlag: " + versionFlag)
println(s"update_partitions:$update_partitions")


// Populate ClickHouseProperties
val spark = sparkSession
val sc = spark.sparkContext
props.setHost(clientHosts)
props.setPort(port)
props.setUser(clusterUser)
props.setPassword(clusterPass)
props.setDatabase(clickhouse_database)
//implicit val client: NormalClickHouseClient = NormalClickHouseClient()(sc.broadcast(props))
//import com.jd.clickhouse.spark.ClickHouseSparkExt._


// Delete stale partitions and load the data; the result is (partition column, rows loaded)
implicit val commonUtils = new CommonUtils()
var res = dataDeleteAndSave(spark, sc, xml.toString, hive_database, hive_table, sdate, edate, clusterName, clientHosts, clickhouse_database, clickhouse_tableName, props, deleteFlag, shardColumn, clusterUser, clusterPass, shardGroupNum, executorNum, clusterModeCurrency, partitionColumns, hiveSql, portMap, primaryKey, storagePolicy, port, linkTableHash, versionFlag, dropParamList)
var ckPartition: String = res._1
var totalLoadingNum: Long = res._2
// Validate the load: compare the rows Spark reports as written with the rows actually present in ClickHouse
var insertNum: Long = 0
import scala.sys.process._
// insertNum: row count actually present in ClickHouse
println("Waiting 5s before validation ........................")
Thread.sleep(5000)
insertNum = CommonUtils().selectTotalNumByDay(clickhouse_database, clickhouse_tableName, ckPartition, sdate, edate, clientHosts, sc, props)
println("+++++++++++++++++++++++++ rows actually written to ClickHouse = " + insertNum)
// Send an alert when the Hive source returned zero rows
if (totalLoadingNum == 0 && nullCheck != "off") {
val commond = Seq("curl", "http://signal-api.jd.local/sendTimeline", "-d", s"groups=黄金眼es推数监控&title=Ch Offline Loading Warning&msg= $clickhouse_database.$clickhouse_tableName loadingNum=0 &source=ge_es推数条数校验@af707c99")
commond !!
}
// Check that the Hive row count matches the row count written to ClickHouse
if (totalLoadingNum != insertNum) {
breakable {
for (tryNUm <- 1 to failedTryAgain) {
println(s"Row counts differ, starting retry #$tryNUm ........................")
var forRes = dataDeleteAndSave(spark, sc, xml.toString, hive_database, hive_table, sdate, edate, clusterName, clientHosts, clickhouse_database, clickhouse_tableName, props, deleteFlag, shardColumn, clusterUser, clusterPass, shardGroupNum, executorNum, clusterModeCurrency, partitionColumns, hiveSql, portMap, primaryKey, storagePolicy, port, linkTableHash, versionFlag, dropParamList)
ckPartition = forRes._1
totalLoadingNum = forRes._2
Thread.sleep(10000)
insertNum = CommonUtils().selectTotalNumByDay(clickhouse_database, clickhouse_tableName, ckPartition, sdate, edate, clientHosts, sc, props)
if (totalLoadingNum == insertNum) {
break()
}
}
}
if (totalLoadingNum != insertNum) {
//val commond = Seq("curl", "http://jenkins.jd.com/job/send_ump_alarm/buildWithParameters?token=123456", "-d", s"ALARM_INFO=[JD] Warning: the row count loaded into the CK table does not match the Hive table, please check! $hive_database.$hive_table totalNum=$totalLoadingNum $clickhouse_database.$clickhouse_tableName loadingNum=$insertNum &UMP_KEY=ppzh.spark.data.monitor")
val commond = Seq("curl", "http://signal-api.jd.local/sendTimeline", "-d", s"groups=黄金眼es推数监控&title=Ch Offline Loading Warning&msg= The row count loaded into the CH table does not match the Hive table (retried $failedTryAgain times), please check! $hive_database.$hive_table totalNum=$totalLoadingNum $clickhouse_database.$clickhouse_tableName loadingNum=$insertNum &source=ge_es推数条数校验@af707c99")
commond !!

throw new RuntimeException("Row counts differ after loading")
}
}

sparkSession.stop()

}

def dataDeleteAndSave(spark: SparkSession, sc: SparkContext, xml: String, hive_database: String, hive_table: String, sdate: String, edate: String, clusterName: String, clientHosts: String, clickhouse_database: String, clickhouse_tableName: String, props: ClickHouseProperties, deleteFlag: Int, shardColumn: String, clusterUser: String, clusterPass: String, shardGroupNum: Int, executorNum: Int, clusterModeCurrency: Int, partitionColumns: String, hiveSql: String, portMap: Map[String, Int], primaryKey: String, storagePolicy: String, port: Int, linkTableHash: String, versionFlag: Int, dropParamList: ListBuffer[String]): (String, Long) = {
// Read the Hive table and filter by the dt partition; the result is a DataFrame
implicit val client: NormalClickHouseClient = NormalClickHouseClient()(sc.broadcast(props))
var dataFrame: Dataset[Row] = null
implicit val commonUtils = new CommonUtils()
var result: Map[String, Long] = Map()
var ckPartition: String = null;
var hiveNum: Long = 0L;
import com.jd.clickhouse.spark.ClickHouseSparkExt._
if (!xml.equals("false")) {
val xmlElem = scala.xml.XML.loadString(xml)
implicit val ckXmlUtil: CkXmlUtil = CkXmlUtil(xmlElem)
// Parse the XML string to get the Hive-to-ClickHouse column mapping
var reflect: Map[String, Seq[String]] = null;
var Conditions: Seq[String] = null;
reflect = ckXmlUtil.getReflect()
// Read the filter conditions from the XML
Conditions = ckXmlUtil.getFilterCondition()

dataFrame = spark.table(s"$hive_database.$hive_table").filter(s"dt<='$edate'").filter(s"dt>='$sdate'")
if (Conditions.size > 0) {
for (value <- Conditions) {
dataFrame = dataFrame.filter(s"$value")
}
}
hiveNum = dataFrame.count()
println("Rows read from Hive: " + hiveNum)
dataFrame = dataFrame.select(reflect.keys.toSeq.map(key => new Column(key).as(reflect(key)(0))): _*)
dataFrame.printSchema()
val schema = dataFrame.schema
ckXmlUtil.createCkTablesByXml(clusterName, clientHosts, schema, clickhouse_database, clickhouse_tableName, xmlElem, reflect, sc, props)
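// Get the partition column. Assign the outer ckPartition: redeclaring it with var here would shadow it, and the method would return null for this branch
ckPartition = ckXmlUtil.getPartitionColumns()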
//before insert data delete already exists data
commonUtils.deleteDataByPartition(deleteFlag, clickhouse_database, clickhouse_tableName, ckPartition, sdate, edate, clientHosts, clusterName, sc, props)
//insert into clickhouse
try {
result = dataFrame.saveToClickHouseBalancedByXml(clickhouse_database, clickhouse_tableName, clusterName, shardColumn, clientHosts, clusterUser, clusterPass, shardGroupNum, executorNum = executorNum, reflect = reflect, linkTableHash = linkTableHash)
} catch {
case e: Exception => {
e.printStackTrace();
throw e
}
}
} else if (clusterModeCurrency > 0) {
ckPartition = partitionColumns//dateTime
// Run the Hive SQL directly to produce the data
implicit val ckSqlUtil: CkSqlUtil = new CkSqlUtil()
dataFrame = spark.sql(hiveSql)

// Check whether the CK cluster runs multiple instances per host
val isMulity = commonUtils.isMultiInstance(clientHosts, sc, props, clusterName)

ckSqlUtil.createCkTablesBySql(clusterName, clientHosts, dataFrame.schema, clickhouse_database, clickhouse_tableName, primaryKey, partitionColumns, sc, props, portMap, storagePolicy, versionFlag)
//before insert data delete already exists data
if (isMulity) {
props.setPort(port);
if (versionFlag == 1) {
dropParamList.map(partitions => {
commonUtils.deleteDataByPartitionMulity(deleteFlag, clickhouse_database, clickhouse_tableName + "_tmp", partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props, port, portMap)
})
} else {
dropParamList.map(partitions => {
commonUtils.deleteDataByPartitionMulity(deleteFlag, clickhouse_database, clickhouse_tableName, partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props, port, portMap)
})
}
} else {
props.setPort(port);
if (versionFlag == 1) {
dropParamList.map(partitions => {
commonUtils.deleteDataByPartition(deleteFlag, clickhouse_database, clickhouse_tableName + "_tmp", partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props)
})
} else {
props.setPort(port);
dropParamList.map(partitions => {
commonUtils.deleteDataByPartition(deleteFlag, clickhouse_database, clickhouse_tableName, partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props)
})
}

}

try {
// On a multi-instance cluster the write must carry the instance port
if (isMulity) {
if (versionFlag == 1) {
result = dataFrame.saveToClickHouseBalancedBySqlMulityCluster(clickhouse_database, clickhouse_tableName + "_tmp", clusterName, shardColumn, clientHosts, clusterUser, clusterPass, shardGroupNum, portMap, clusterModeCurrency = clusterModeCurrency, linkTableHash = linkTableHash)
val totalLoadingNum: Long = result.valuesIterator.sum
Thread.sleep(10000)
val insertNum = CommonUtils().selectTotalNumByDay(clickhouse_database, clickhouse_tableName + "_tmp", ckPartition, sdate, edate, clientHosts, sc, props)
if (totalLoadingNum == insertNum) {
dropParamList.map(partitions => {
commonUtils.deleteDataByPartitionMulity(deleteFlag, clickhouse_database, clickhouse_tableName, partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props, port, portMap)
println("Partition being synced: " + partitions)
commonUtils.attachDataByPartitionMulity(versionFlag, clickhouse_database, clickhouse_tableName, clickhouse_tableName + "_tmp", partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props, port, portMap)
commonUtils.deleteDataByPartitionMulity(deleteFlag, clickhouse_database, clickhouse_tableName + "_tmp", partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props, port, portMap)
})
} else {
throw new RuntimeException(s"Row counts differ: totalLoadingNum=${totalLoadingNum}, rows found in ClickHouse tmp table=${insertNum}")
}
} else {
result = dataFrame.saveToClickHouseBalancedBySqlMulityCluster(clickhouse_database, clickhouse_tableName, clusterName, shardColumn, clientHosts, clusterUser, clusterPass, shardGroupNum, portMap, clusterModeCurrency = clusterModeCurrency, linkTableHash = linkTableHash)
}
} else {
if (versionFlag == 1) {
result = dataFrame.saveToClickHouseBalancedBySqlCluster(clickhouse_database, clickhouse_tableName + "_tmp", clusterName, shardColumn, clientHosts, clusterUser, clusterPass, shardGroupNum, clusterModeCurrency = clusterModeCurrency, linkTableHash = linkTableHash)
val totalLoadingNum: Long = result.valuesIterator.sum
Thread.sleep(10000)
val insertNum = CommonUtils().selectTotalNumByDay(clickhouse_database, clickhouse_tableName + "_tmp", ckPartition, sdate, edate, clientHosts, sc, props)
if (totalLoadingNum == insertNum) {
dropParamList.map(partitions => {
commonUtils.deleteDataByPartition(deleteFlag, clickhouse_database, clickhouse_tableName, partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props)
println(s"[Partition being synced: ${partitions}]")
commonUtils.attachDataByPartition(versionFlag, clickhouse_database, clickhouse_tableName, clickhouse_tableName + "_tmp", partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props)
commonUtils.deleteDataByPartition(deleteFlag, clickhouse_database, clickhouse_tableName + "_tmp", partitionColumns, partitions, partitions, clientHosts, clusterName, sc, props)
})
} else {
throw new RuntimeException(s"Row counts differ: totalLoadingNum=${totalLoadingNum}, rows found in ClickHouse tmp table=${insertNum}")
}
} else {
result = dataFrame.saveToClickHouseBalancedBySqlCluster(clickhouse_database, clickhouse_tableName, clusterName, shardColumn, clientHosts, clusterUser, clusterPass, shardGroupNum, clusterModeCurrency = clusterModeCurrency, linkTableHash = linkTableHash)
}
}

} catch {
case e: Exception => {
e.printStackTrace();
throw e
}
}
} else {
ckPartition = partitionColumns
// Run the Hive SQL directly to produce the data
implicit val ckSqlUtil: CkSqlUtil = new CkSqlUtil()
dataFrame = spark.sql(hiveSql)
dataFrame.printSchema()
hiveNum = dataFrame.count()
println("Rows read from Hive: " + hiveNum)
val schema = dataFrame.schema
ckSqlUtil.createCkTablesBySql(clusterName, clientHosts, schema, clickhouse_database, clickhouse_tableName, primaryKey, partitionColumns, sc, props, portMap, storagePolicy, versionFlag)
//before insert data delete already exists data
commonUtils.deleteDataByPartition(deleteFlag, clickhouse_database, clickhouse_tableName, partitionColumns, sdate, edate, clientHosts, clusterName, sc, props)
try {
result = dataFrame.saveToClickHouseBalancedBySql(clickhouse_database, clickhouse_tableName, clusterName, shardColumn, clientHosts, clusterUser, clusterPass, shardGroupNum, executorNum = executorNum, linkTableHash = linkTableHash)
} catch {
case e: Exception => {
e.printStackTrace();
throw e
}
}
}
// totalLoadNum: total rows written by Spark, which should equal hiveNum
var totalLoadNum: Long = 0
result.foreach { case (m, n) => {
totalLoadNum += n
println("host=" + m + "\n" + "loading: totalNum=" + n)
}
}
println("totalLoadNum:" + totalLoadNum)
(ckPartition, totalLoadNum);
}
}
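The partition list that main builds with its while loops over start_date can be illustrated in isolation. The sketch below uses java.time instead of the article's `MyDateClass`, whose source is not shown, so `PartitionRange`, `days`, and `compound` are hypothetical stand-ins rather than the component's own helpers.

```scala
import java.time.LocalDate

object PartitionRange {
  // Every day from start to end inclusive, formatted as yyyy-MM-dd.
  def days(start: String, end: String): List[String] = {
    val first = LocalDate.parse(start)
    val last = LocalDate.parse(end)
    Iterator.iterate(first)(_.plusDays(1))
      .takeWhile(d => !d.isAfter(last))
      .map(_.toString)
      .toList
  }

  // Compound partition tuples such as ('2022-11-03','x') for a two-column partition key.
  def compound(start: String, end: String, second: String): List[String] =
    days(start, end).map(d => s"('$d','$second')")
}
```

Comparing dates as `LocalDate` values rather than as strings also guards against inputs that are not zero-padded.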

RDD write logic

package com.jd.clickhouse.spark

import java.io.{IOException, PrintStream}
import java.net.{InetSocketAddress, Socket}
import java.sql.{PreparedStatement, ResultSet, Statement}
import java.text.SimpleDateFormat
import java.util
import java.util.HashMap

import com.jd.clickhouse.client.ClickHouseClient
import com.jd.clickhouse.common.MergeTreeType
import com.jd.clickhouse.common.MergeTreeType.MergeTreeType
import com.jd.clickhouse.common.Utils._
import org.apache.spark.HashPartitioner
import org.apache.spark.sql.{Column, Encoders}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
//import com.sun.prism.impl.Disposer.Target
import org.apache.hadoop.yarn.webapp.hamlet.HamletSpec.VAR
import org.apache.spark.sql.hive.client
import org.apache.spark.sql.{DataFrame, Row, hive}
import org.apache.spark.sql.types.{DataType, StructField, StructType}

object ClickHouseSparkExt {
implicit def extraOperations(df: DataFrame) = DataFrameExt(df)
}

case class DataFrameExt(df: DataFrame) extends Serializable {
// Date formatter; "MM" is the month field (lowercase "mm" would mean minutes)
val sdf = new SimpleDateFormat("yyyy-MM-dd")

def saveToClickHouseBalanced[T <: ClickHouseClient](dbName: String, tableName: String,
clusterName: String,
batchSize: Int = 100000, executorNum: Int = 10, reflect: Map[String, Seq[String]])
(implicit client: T): Map[String, Long] = {
val schema = df.schema
import df.sparkSession.implicits._
// following code is going to be run on executors
val insertResults = df.repartition(executorNum).mapPartitions((partition: Iterator[org.apache.spark.sql.Row]) => {
// Build the INSERT statement
val insertSql = client.generateInsertStatement(schema, dbName, tableName)
var statements: Seq[PreparedStatement] = Seq()
var total = 0
val connections = client.getBalancedConnection(clusterName, dbName)
for (conn <- connections) {
statements = statements :+ conn.prepareStatement(insertSql)
total += 1
}

var totalInsert: Long = 0
var counter = 0
var batchNum = 0

var mapper: Map[String, Seq[String]] = Map()
reflect.map(f => mapper += (f._2(0) -> Seq(f._2(1),f._2(2))))
for (row <- partition) {
counter += 1
val statement = statements(batchNum % total)
schema.foreach { f =>
val fieldName = f.name
val originalType = f.dataType
val targetType = mapper(f.name)(0)
val splitType = mapper(f.name)(1)
val fieldIdx = row.fieldIndex(fieldName)
var fieldVal = row.get(fieldIdx)
try {
if (fieldVal != null) {
fieldVal = change(originalType, targetType,splitType, fieldIdx, row)
statement.setObject(fieldIdx + 1, fieldVal)
} else {
val defVal = client.defaultNullValue(client.xml2SparkType(mapper(f.name)(0)), fieldVal)
statement.setObject(fieldIdx + 1, defVal)
}
} catch {
case e: Exception => {
throw new RuntimeException("Hive field " + fieldName + ":" + fieldVal + ":" + f.dataType + " could not be converted to type " + mapper(f.name), e)
}
}
}

statement.addBatch()

if (counter >= batchSize) {
batchNum += 1
val r = statement.executeBatch()
totalInsert += r.sum
counter = 0
}
}

if (counter > 0) {
val statement: Statement = statements(batchNum % total)
val r = statement.executeBatch()
totalInsert += r.sum
counter = 0
}

for (conn <- connections) {
conn.close()
}

// return: Seq((host, insertCount))
List((client.getTargetHost, totalInsert)).toIterator

}).persist

println("rdd_count:" + insertResults.count())

val results = insertResults.collect()
insertResults.unpersist()

// aggregate insert results by hosts
results.groupBy(_._1).map(x => (x._1, x._2.map(_._2).sum))
}


def saveToClickHouseBalancedBySql[T <: ClickHouseClient](dbName: String, tableName: String,
clusterName: String, shardColumn: String,
targetHost: String, clusterUser: String, clusterPassword: String,
shardGroupNum:Int,
batchSize: Int = 100000,
executorNum: Int = 20,linkTableHash: String)
(implicit client: T): Map[String, Long] = {

val schema = df.schema
import df.sparkSession.implicits._
// Hash of the table name
var tableNameHashCode:Long = 1
if(linkTableHash != "off") {
tableNameHashCode = math.abs(tableName.hashCode.toLong)
}
val insertResults = df.repartition(executorNum).mapPartitions((partition: Iterator[org.apache.spark.sql.Row]) => {
// INSERT statement used for batch writes
val insertSql = client.generateInsertStatement(schema, dbName, tableName)
// Map each shardNum to its prepared statement
var clusterShardNumStatementMap: Map[Int, PreparedStatement] = Map()
val connections = client.getBalancedConnectionMaptoShard(clusterName, dbName, targetHost, clusterUser, clusterPassword)
for (conn <- connections) {
clusterShardNumStatementMap += (conn._1 -> conn._2.prepareStatement(insertSql))
}

// Total number of shards in the cluster
val clusterTotalShardNum = clusterShardNumStatementMap.size

println(s"======> clusterTotalShardNum=$clusterTotalShardNum")

// shardGroupNum is the user-supplied number of shards per group and must be sanity-checked; groupNum is the validated value used from here on
var groupNum = shardGroupNum
if (shardGroupNum > clusterTotalShardNum) {
groupNum = clusterTotalShardNum
} else if (shardGroupNum <= 0) {
groupNum = 1
}

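// Total rows processed
var totalInsert: Long = 0
// Row counter for the current batch
var counter = 0

// usedShardNumStatementMap tracks the shardNum -> statement pairs touched by the current batch
var usedShardNumStatementMap: Map[Int, PreparedStatement] = Map()

// Process the records one by one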
for (row <- partition) {
counter += 1

// Work out which shard this record belongs to
val finalShardIndex = calShardIndex(shardColumn, tableNameHashCode, clusterTotalShardNum, groupNum, row)

if (!usedShardNumStatementMap.contains(finalShardIndex)) {
usedShardNumStatementMap += (finalShardIndex -> clusterShardNumStatementMap(finalShardIndex))
}

val currentPreparedStatement = usedShardNumStatementMap(finalShardIndex)
schema.foreach { f =>
val fieldName = f.name
val fieldType = f.dataType
val fieldIdx = row.fieldIndex(fieldName)
var fieldVal = row.get(fieldIdx)

if (fieldVal != null) {
try{
if(fieldVal.isInstanceOf[mutable.WrappedArray.ofRef[_]]){
val dataArray = fieldVal.asInstanceOf[mutable.WrappedArray.ofRef[_]].toArray
currentPreparedStatement.setObject(fieldIdx + 1, dataArray)
}else if (fieldType.isInstanceOf[MapType]) {
val data_map=fieldVal.asInstanceOf[scala.collection.immutable.Map[_,_]].toMap
currentPreparedStatement.setObject(fieldIdx + 1, scala.collection.JavaConversions.mapAsJavaMap(data_map))
}else{
currentPreparedStatement.setObject(fieldIdx + 1, fieldVal)
}
}catch {
case e: Exception => {
e.printStackTrace()
throw new Exception("Hive field " + fieldName + ":" + fieldVal + ":" + f.dataType + " failed to write", e)
}
}
} else {
val defVal = client.defaultNullValue(f.dataType, fieldVal)
currentPreparedStatement.setObject(fieldIdx + 1, defVal)
}
}

currentPreparedStatement.addBatch()

if (counter >= batchSize) {
for (s <- usedShardNumStatementMap) {
val r = s._2.executeBatch()
totalInsert += r.sum
}
counter = 0
usedShardNumStatementMap = Map()
}
}

if (counter > 0) {
for (s <- usedShardNumStatementMap) {
val r = s._2.executeBatch()
totalInsert += r.sum
}
counter = 0
usedShardNumStatementMap = Map()
}

for (conn <- connections) {
conn._2.close()
}

// return: Seq((host, insertCount))
List((client.getTargetHost, totalInsert)).toIterator

}).persist
println("rdd_count:" + insertResults.count())

val results = insertResults.collect()
insertResults.unpersist()

// aggregate insert results by hosts
results.groupBy(_._1).map(x => (x._1, x._2.map(_._2).sum))
}


def saveToClickHouseBalancedBySqlCluster[T <: ClickHouseClient](dbName: String, tableName: String,
clusterName: String, shardColumn: String,
targetHost: String, clusterUser: String, clusterPassword: String,
shardGroupNum:Int,
batchSize: Int = 100000,
clusterModeCurrency: Int,linkTableHash: String)
(implicit client: T): Map[String, Long] = {
val schema = df.schema
import df.sparkSession.implicits._
// Hash of the table name
var tableNameHashCode:Long = 1
if(linkTableHash != "off") {
tableNameHashCode = math.abs(tableName.hashCode.toLong)
}
// Fetch the shard -> [replica IPs ...] layout of the cluster

val clusterInfoList:List[(Int, Int, String)] =
client.getClusterInfo(clusterName, dbName, targetHost, clusterUser, clusterPassword)//(shardNo, replicaNo, Ip)
.sortWith((a,b)=> (a._1 * 100 + a._2)<(b._1 * 100 + b._2))// sort by shard first (weighted by 100), then by replica
// Convert List[(shardNo, replicaNo, Ip)] into Map[shardNo -> [Ip1,Ip2,...,Ipn]], keeping only reachable IPs
val clusterInfoMap:Map[Int,List[String]] =
clusterInfoList.groupBy(x => x._1)
.map(x => (x._1 ->x._2.groupBy(y=>y._2).map(z => z._2(0)._3).filter(x => isHostConnectable(x, 8123)).toList))
.filter(x => x._2.size != 0)// drop shards whose IPs are all unreachable
val shardTotalNum = clusterInfoMap.size// total number of shards in the cluster
val currencyTaskNum = clusterModeCurrency
val numPartitions = shardTotalNum * currencyTaskNum// the partition count should be an integer multiple of the shard count

clusterInfoMap.map(a => println("shard_replica_ip map:" + a))

var insertRdd = df.repartition(numPartitions, new Column(shardColumn))
var dateKey=""
if(shardColumn.equals("dt")||shardColumn.equals("dateTime")){
if(df.take(1).length>0){
dateKey =df.take(1)(0).getAs(shardColumn).toString
println("------------------------------hashKey"+dateKey)
insertRdd=df.repartition(numPartitions)
}
}

val insertResults = insertRdd.rdd.mapPartitionsWithIndex(
(partitionKey,rows) => {
var key:Int =partitionKey
if(shardColumn.equals("dt")||shardColumn.equals("dateTime")){
key=math.abs(dateKey.hashCode)
}
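// Pick the node this task writes to, offset by tableNameHashCode
val shardNum:Int = ((key/currencyTaskNum + tableNameHashCode) % shardTotalNum).toInt
// Shard numbers are assumed to be contiguous and to start at 1
val replicas:List[String] = clusterInfoMap(shardNum + 1)
// Take the size of the replica list itself; calling .size on the Option returned by Map.get always yields 1
val replicaNum:Int = key%shardTotalNum%replicas.size
val hostAddress:String = replicas(replicaNum)

// Start writing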
val insertSql = client.generateInsertStatement(schema, dbName, tableName)
val conn = client.getOneConnection(clusterName, dbName, hostAddress, clusterUser, clusterPassword)
val currentPreparedStatement = conn.prepareStatement(insertSql)
// Total number of rows inserted by this task
var totalInsert: Long = 0
// Number of rows in the current batch
var counter = 0

// Process the records one by one
for (row <- rows) {
counter += 1

schema.foreach { f =>
val fieldName = f.name
val fieldIdx = row.fieldIndex(fieldName)
var fieldVal = row.get(fieldIdx)
val fieldType=f.dataType
if (fieldVal != null) {
try{
if(fieldVal.isInstanceOf[mutable.WrappedArray.ofRef[_]]){
val dataArray = fieldVal.asInstanceOf[mutable.WrappedArray.ofRef[_]].toArray
currentPreparedStatement.setObject(fieldIdx + 1, dataArray)
}else if (fieldType.isInstanceOf[MapType]) {
val data_map=fieldVal.asInstanceOf[scala.collection.immutable.Map[_,_]].toMap
currentPreparedStatement.setObject(fieldIdx + 1, scala.collection.JavaConversions.mapAsJavaMap(data_map))
}else{
currentPreparedStatement.setObject(fieldIdx + 1, fieldVal)
}
}catch {
case e: Exception => {
// Chain the original exception as the cause so its stack trace is preserved
throw new Exception("Hive field " + fieldName + ":" + fieldVal + ":" + f.dataType + " could not be written", e)
}
}
} else {
val defVal = client.defaultNullValue(f.dataType, fieldVal)
currentPreparedStatement.setObject(fieldIdx + 1, defVal)
}
}

currentPreparedStatement.addBatch()

if (counter >= batchSize) {
val r = currentPreparedStatement.executeBatch()
totalInsert += r.sum
counter = 0
}
}

if (counter > 0) {
val r = currentPreparedStatement.executeBatch()
totalInsert += r.sum
counter = 0
}

conn.close()

List((hostAddress, totalInsert)).toIterator
}
).persist()

println("rdd_count:" + insertResults.count())

val results = insertResults.collect()
insertResults.unpersist()

// aggregate insert results by hosts
results.groupBy(_._1).map(x => (x._1, x._2.map(_._2).sum))

}
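
// Illustrative sketch, not part of the original component: the task-to-host routing that
// saveToClickHouseBalancedBySqlCluster is meant to perform, written as a pure function so it
// can be checked without a cluster. The name pickTargetHost and its parameter names are
// hypothetical. Like the method above, it assumes shard numbers are contiguous and start at 1.
def pickTargetHost(partitionIndex: Int, tasksPerShard: Int, tableNameHashCode: Long,
                   replicasByShard: Map[Int, List[String]]): String = {
  val shardTotalNum = replicasByShard.size
  // Each run of tasksPerShard consecutive partitions shares a shard; the table hash rotates the start
  val shardNum = ((partitionIndex / tasksPerShard + tableNameHashCode) % shardTotalNum).toInt
  val replicas = replicasByShard(shardNum + 1)
  // Spread the tasks of one shard across its reachable replicas
  replicas(partitionIndex % shardTotalNum % replicas.size)
}

// Usage example: 2 shards with 2 replicas each and 2 tasks per shard give 4 partitions.
//   val hosts = Map(1 -> List("a1", "a2"), 2 -> List("b1", "b2"))
//   (0 until 4).map(pickTargetHost(_, 2, 0L, hosts))   // a1, a2, b1, b2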


def saveToClickHouseBalancedBySqlMulityCluster[T <: ClickHouseClient](dbName: String, tableName: String,
clusterName: String, shardColumn: String,
targetHost: String, clusterUser: String,
clusterPassword: String,
shardGroupNum:Int,
portInfo:Map[String,Int],
batchSize: Int = 500000,
clusterModeCurrency: Int,
linkTableHash: String)
(implicit client: T): Map[String, Long] = {
val schema = df.schema
import df.sparkSession.implicits._
// Hash of the table name (stays 1 when linkTableHash is "off")
var tableNameHashCode:Long = 1
if(linkTableHash != "off") {
tableNameHashCode = math.abs(tableName.hashCode.toLong)
}
// Fetch the shard -> [replica (ip, port)...] layout of the cluster

val clusterInfoList:List[(Int, Int, String,Int)] =
client.getClusterMulityInfo(clusterName, dbName, targetHost, clusterUser, clusterPassword)// (shardNo, replicaNo, ip, port)
.sortWith((a,b)=> (a._1 * 100 + a._2)<(b._1 * 100 + b._2))// weight the shard number so the list is ordered by shard, then by replica
// Turn List[(shardNo, replicaNo, ip, port)] into Map[shardNo -> List((ip1, port1), ..., (ipN, portN))], keeping only reachable replicas
val clusterInfoMap:Map[Int,List[(String,Int)]] =
clusterInfoList.groupBy(x => x._1)
.map(x => (x._1 ->x._2.groupBy(y=>y._2).map(z => (z._2(0)._3,portInfo(z._2(0)._4.toString))).filter(x => isHostConnectable(x._1,x._2)).toList))
.filter(x => x._2.size != 0)// drop shards whose replicas are all unreachable
val shardTotalNum = clusterInfoMap.size// number of usable shards in the cluster
val currencyTaskNum = clusterModeCurrency
val numPartitions = shardTotalNum * currencyTaskNum// the partition count must be a multiple of the shard count
// foreach rather than map: the println is a side effect and no result is needed
clusterInfoMap.foreach(a => println("shard_replica_ip map:" + a))
var insertRdd = df.repartition(numPartitions, new Column(shardColumn))
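var dateKey=""
if(shardColumn.equals("dt")||shardColumn.equals("dateTime")){
// Fetch the first row once: each take(1) call runs a separate Spark job
val firstRows = df.take(1)
if(firstRows.length>0){
// getAs needs an explicit type; without one it is inferred as Nothing and the cast fails at runtime
dateKey = firstRows(0).getAs[Any](shardColumn).toString
println("------------------------------hashKey"+dateKey)
insertRdd=df.repartition(numPartitions)
}

}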

val insertResults =insertRdd.rdd.mapPartitionsWithIndex(
(partitionKey,rows) => {
var key:Int =partitionKey
if(shardColumn.equals("dt")||shardColumn.equals("dateTime")){
key=math.abs(dateKey.hashCode)
}

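// Pick the node this task writes to, offset by tableNameHashCode
val shardNum:Int = ((key/currencyTaskNum + tableNameHashCode) % shardTotalNum).toInt
// Shard numbers are assumed to be contiguous and to start at 1
val replicas:List[(String,Int)] = clusterInfoMap(shardNum + 1)
// Take the size of the replica list itself; calling .size on the Option returned by Map.get always yields 1
val replicaNum:Int = key%shardTotalNum%replicas.size
val hostAddress:String = replicas(replicaNum)._1
val port:Int = replicas(replicaNum)._2

// Start writing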
val insertSql = client.generateInsertStatement(schema, dbName, tableName)
val conn = client.getOneConnectionOfMulity(clusterName, dbName, hostAddress, port,clusterUser, clusterPassword)
val currentPreparedStatement = conn.prepareStatement(insertSql)
// Total number of rows inserted by this task
var totalInsert: Long = 0
// Number of rows in the current batch
var counter = 0

// Process the records one by one
for (row <- rows) {
counter += 1

schema.foreach { f =>
val fieldName = f.name
val fieldType =f.dataType
val fieldIdx = row.fieldIndex(fieldName)
var fieldVal = row.get(fieldIdx)

if (fieldVal != null) {
try{
if(fieldVal.isInstanceOf[mutable.WrappedArray.ofRef[_]]){
val dataArray = fieldVal.asInstanceOf[mutable.WrappedArray.ofRef[_]].toArray
currentPreparedStatement.setObject(fieldIdx + 1, dataArray)
}else if (fieldType.isInstanceOf[MapType]) {
val data_map=fieldVal.asInstanceOf[scala.collection.immutable.Map[_,_]].toMap
currentPreparedStatement.setObject(fieldIdx + 1, scala.collection.JavaConversions.mapAsJavaMap(data_map))
}else{
currentPreparedStatement.setObject(fieldIdx + 1, fieldVal)
}
}catch {
case e: Exception => {
// Chain the original exception as the cause so its stack trace is preserved
throw new Exception("Hive field " + fieldName + ":" + fieldVal + ":" + f.dataType + " could not be written", e)
}
}
} else {
val defVal = client.defaultNullValue(f.dataType, fieldVal)
currentPreparedStatement.setObject(fieldIdx + 1, defVal)
}
}

currentPreparedStatement.addBatch()

if (counter >= batchSize) {
val r = currentPreparedStatement.executeBatch()
totalInsert += r.sum
counter = 0
}
}

if (counter > 0) {
val r = currentPreparedStatement.executeBatch()
totalInsert += r.sum
counter = 0
}

conn.close()

List((hostAddress, totalInsert)).toIterator
}
).persist()

println("rdd_count:" + insertResults.count())

val results = insertResults.collect()
insertResults.unpersist()

// aggregate insert results by hosts
results.groupBy(_._1).map(x => (x._1, x._2.map(_._2).sum))

}

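def isHostConnectable(host: String, port: Int): Boolean = {
  val socket = new Socket
  try {
    // The 3-second connect timeout is an assumed value; without one, an unreachable host can block for a long time
    socket.connect(new InetSocketAddress(host, port), 3000)
    true
  } catch {
    case _: IOException =>
      println(host + ":" + port + " is unreachable")
      false
  } finally {
    // A failure while closing the probe socket says nothing about reachability, so it is ignored
    try socket.close() catch { case _: IOException => () }
  }
}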

def saveToClickHouseBalancedByXml[T <: ClickHouseClient](dbName: String, tableName: String,
clusterName: String, shardColumn: String,
targetHost: String, clusterUser: String, clusterPassword: String,
shardGroupNum:Int,
batchSize: Int = 100000,
executorNum: Int = 20, reflect: Map[String, Seq[String]],linkTableHash: String)
(implicit client: T): Map[String, Long] = {

val schema = df.schema
import df.sparkSession.implicits._
// Hash of the table name (stays 1 when linkTableHash is "off")
var tableNameHashCode:Long = 1
if(linkTableHash != "off") {
tableNameHashCode = math.abs(tableName.hashCode.toLong)
}
val insertResults = df.repartition(executorNum).mapPartitions((partition: Iterator[org.apache.spark.sql.Row]) => {
// SQL statement used for the batch insert
val insertSql = client.generateInsertStatement(schema, dbName, tableName)
// Maps each shardNum to its prepared statement
var clusterShardNumStatementMap: Map[Int, PreparedStatement] = Map()
val connections = client.getBalancedConnectionMaptoShard(clusterName, dbName, targetHost, clusterUser, clusterPassword)
for (conn <- connections) {
clusterShardNumStatementMap += (conn._1 -> conn._2.prepareStatement(insertSql))
}

// Total number of shards in the cluster
val clusterTotalShardNum = clusterShardNumStatementMap.size

println(s"======> clusterTotalShardNum=$clusterTotalShardNum")

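// shardGroupNum is the user-supplied number of shards per group; it is clamped to [1, clusterTotalShardNum], and the clamped value groupNum is used from here on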
var groupNum = shardGroupNum
if (shardGroupNum > clusterTotalShardNum) {
groupNum = clusterTotalShardNum
} else if (shardGroupNum <= 0) {
groupNum = 1
}

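// Total number of rows inserted by this task
var totalInsert: Long = 0
// Number of rows in the current batch
var counter = 0

// ClickHouse column info: Map[alias column name -> Seq(type of the column value, split separator)]
var mapper: Map[String, Seq[String]] = Map()
reflect.foreach(f => mapper += (f._2(0) -> Seq(f._2(1),f._2(2))))

// usedShardNumStatementMap holds the statements of the shards touched by the current batch
var usedShardNumStatementMap: Map[Int, PreparedStatement] = Map()

// Process the records one by one
for (row <- partition) {
counter += 1

// Work out which shard this record belongs to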
val finalShardIndex = calShardIndexByRowValue(shardColumn, tableNameHashCode, clusterTotalShardNum, groupNum, row, mapper)

if (!usedShardNumStatementMap.contains(finalShardIndex)) {
usedShardNumStatementMap += (finalShardIndex -> clusterShardNumStatementMap(finalShardIndex))
}

val currentPreparedStatement = usedShardNumStatementMap(finalShardIndex)
schema.foreach { f =>
val fieldName = f.name
// val originalType = f.dataType
val targetType = mapper(f.name)(0)
val splitType = mapper(f.name)(1)
val fieldIdx = row.fieldIndex(fieldName)
var fieldVal = row.get(fieldIdx)

if (fieldVal != null) {
try{
fieldVal = changeArray(targetType,splitType, fieldIdx, row)
currentPreparedStatement.setObject(fieldIdx + 1, fieldVal)
}catch {
case e: Exception => {
// Chain the original exception as the cause so its stack trace is preserved
throw new RuntimeException("Hive field " + fieldName + ":" + fieldVal + ":" + f.dataType + " could not be converted to type " + mapper(f.name)(0), e)
}
}
} else {
val defVal = client.defaultNullSplitValue(client.xml2SparkType(mapper(f.name)(0)), splitType,fieldVal)
currentPreparedStatement.setObject(fieldIdx + 1, defVal)
}
}

currentPreparedStatement.addBatch()

if (counter >= batchSize) {
for (s <- usedShardNumStatementMap) {
val r = s._2.executeBatch()
totalInsert += r.sum
}
counter = 0
usedShardNumStatementMap = Map()
}
}

if (counter > 0) {
for (s <- usedShardNumStatementMap) {
val r = s._2.executeBatch()
totalInsert += r.sum
}
counter = 0
usedShardNumStatementMap = Map()
}

for (conn <- connections) {
conn._2.close()
}

// return: Seq((host, insertCount))
List((client.getTargetHost, totalInsert)).toIterator

}).persist
println("rdd_count:" + insertResults.count())

val results = insertResults.collect()
insertResults.unpersist()

// aggregate insert results by hosts
results.groupBy(_._1).map(x => (x._1, x._2.map(_._2).sum))
}

def calShardIndexByRowValue(shardColumn: String, tableNameHashCode: Long, clusterTotalShardNum: Int, groupNum: Int, row: org.apache.spark.sql.Row, mapper: Map[String, Seq[String]]): Int = {
val viceShardColumnList = shardColumn.split(",")
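// The first column determines the starting shard
// Read the value as Any so that non-string columns do not fail with a ClassCastException
val primaryRaw = row.getAs[Any](viceShardColumnList(0))
// Null and empty values are all hashed as the literal "null"
val primaryShardColumnValue = if (primaryRaw == null || primaryRaw.toString == "") "null" else primaryRaw.toString
// Hash the value
val primaryShardColumnHashCode = math.abs(primaryShardColumnValue.hashCode.toLong)
val sum = primaryShardColumnHashCode + tableNameHashCode

// Take the remainder by the shard count to find the first shard of the group; shards are numbered from 1, hence the + 1
val firstShardNumInGroup = math.abs(sum % clusterTotalShardNum) + 1

// When secondary shard columns are given, add up the hashes of their values to place the record inside the shard group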
var sumOfViceShardColumnHashCode = 0L
var viceShardValidateMark = false
if (viceShardColumnList != null && viceShardColumnList.size > 1 && !viceShardColumnList(0).equals("")) {
for (index<-1 to viceShardColumnList.size-1) {
if (mapper.contains(viceShardColumnList(index))) {
viceShardValidateMark = true
val fieldIndex = row.fieldIndex(viceShardColumnList(index))
var fieldValue = row.get(fieldIndex)

if (fieldValue != null) {
try{
fieldValue = changeArray(mapper(viceShardColumnList(index))(0), mapper(viceShardColumnList(index))(1), fieldIndex, row)
} catch {
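case e: Exception => {
// Report the secondary column that failed and keep the original exception as the cause
throw new RuntimeException("Hive field " + viceShardColumnList(index) + ":" + fieldValue + " could not be converted to type " + mapper(viceShardColumnList(index))(0), e)
}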
}

sumOfViceShardColumnHashCode += math.abs(fieldValue.toString.hashCode.toLong)
}
}
}
} else {
sumOfViceShardColumnHashCode += (new util.Random).nextInt(groupNum).toLong
}


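// The remainder of the summed secondary hashes by groupNum is the record's position inside the group
val shardIndexInGroup = math.abs(sumOfViceShardColumnHashCode % groupNum)

// Final shard for the record; wrap around with a modulo when the number exceeds the last shard of the cluster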
var finalShardIndex = 1
if (firstShardNumInGroup + shardIndexInGroup > clusterTotalShardNum) {
finalShardIndex = ((firstShardNumInGroup + shardIndexInGroup) % clusterTotalShardNum).toInt
} else {
finalShardIndex = (firstShardNumInGroup + shardIndexInGroup).toInt
}

finalShardIndex
}


def calShardIndex(shardColumn: String, tableNameHashCode: Long, clusterTotalShardNum: Int, groupNum: Int, row: org.apache.spark.sql.Row): Int = {
val viceShardColumnList = shardColumn.split(",")
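// The first column determines the starting shard
// Read the value as Any so that non-string columns do not fail with a ClassCastException
val primaryRaw = row.getAs[Any](viceShardColumnList(0))
// Null and empty values are all hashed as the literal "null"
val primaryShardColumnValue = if (primaryRaw == null || primaryRaw.toString == "") "null" else primaryRaw.toString
// Hash the value
val primaryShardColumnHashCode = math.abs(primaryShardColumnValue.hashCode.toLong)
val sum = primaryShardColumnHashCode + tableNameHashCode

// Take the remainder by the shard count to find the first shard of the group; shards are numbered from 1, hence the + 1
val firstShardNumInGroup = math.abs(sum % clusterTotalShardNum) + 1

// When secondary shard columns are given, add up the hashes of their values to place the record inside the shard group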
var sumOfViceShardColumnHashCode = 0L
if (viceShardColumnList != null && viceShardColumnList.size > 1 && !viceShardColumnList(0).equals("")) {
for (index<-1 to viceShardColumnList.size-1) {
val viceShardColumn =viceShardColumnList(index)
val fieldIndex = row.fieldIndex(viceShardColumn)
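val fieldValue = row.get(fieldIndex)
// Null values contribute nothing to the sum, as in calShardIndexByRowValue
if (fieldValue != null) sumOfViceShardColumnHashCode += math.abs(fieldValue.toString.hashCode.toLong)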
}
} else {
sumOfViceShardColumnHashCode += (new util.Random).nextInt(groupNum).toLong
}


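// The remainder of the summed secondary hashes by groupNum is the record's position inside the group
val shardIndexInGroup = math.abs(sumOfViceShardColumnHashCode % groupNum)

// Final shard for the record; wrap around with a modulo when the number exceeds the last shard of the cluster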
var finalShardIndex = 1
if (firstShardNumInGroup + shardIndexInGroup > clusterTotalShardNum) {
finalShardIndex = ((firstShardNumInGroup + shardIndexInGroup) % clusterTotalShardNum).toInt
} else {
finalShardIndex = (firstShardNumInGroup + shardIndexInGroup).toInt
}

finalShardIndex
}
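
// Illustrative sketch, not part of the original component: the arithmetic shared by
// calShardIndexByRowValue and calShardIndex, with the hashes passed in directly so the
// result can be checked by hand. The name shardIndexInCluster and its parameter names are
// hypothetical; primaryHash and viceHashSum stand for the non-negative hashes computed above.
def shardIndexInCluster(primaryHash: Long, viceHashSum: Long, tableNameHashCode: Long,
                        clusterTotalShardNum: Int, groupNum: Int): Int = {
  // First shard of the group, numbered from 1
  val first = math.abs((primaryHash + tableNameHashCode) % clusterTotalShardNum) + 1
  // Position of the record inside the group, in the range 0 until groupNum
  val offset = math.abs(viceHashSum % groupNum)
  // Wrap around when the group runs past the last shard
  val index = first + offset
  (if (index > clusterTotalShardNum) index % clusterTotalShardNum else index).toInt
}

// Worked example with 4 shards, groups of 2 and tableNameHashCode = 1:
//   primaryHash = 7, viceHashSum = 3: first = (7 + 1) % 4 + 1 = 1, offset = 3 % 2 = 1, shard 2
//   primaryHash = 2, viceHashSum = 3: first = (2 + 1) % 4 + 1 = 4, offset = 1, 5 > 4 so 5 % 4 = shard 1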

def change(originalType: DataType, targetType: String, splitType: String, valueIndx: Int, row: Row): Any = {
  // originalType is not used; the conversion is driven by targetType and splitType
  ifArray(row.get(valueIndx), splitType, targetType)
}

def ifArray[T](value: Any, splitType: String, targetType: String): Any = {
  if (splitType == null || splitType == "")
    toValueType(value.toString, targetType)
  else {
    // splitType is passed to String.split, so it is interpreted as a regular expression
    val valueArray = value.toString.split(splitType)
    val ckList = new ArrayBuffer[Any]()
    for (element <- valueArray) {
      // += appends in place; :+ would build a new buffer and leave ckList empty
      ckList += toValueType(element, targetType)
    }
    ckList
  }
}

def toValueType(value: String, targetType: String): Any = targetType match {
  case "int"           => value.toInt
  case "long"          => value.toLong
  case "string"        => value
  case "float"         => value.toFloat
  case "boolean"       => value.toBoolean
  case "double"        => value.toDouble
  // Dates are passed through as strings
  case "date" | "Date" => value
  // Unknown target types yield None
  case _               => None
}



def changeArray(targetType: String, splitType: String, valueIndx: Int, row: Row): Any = {
  val v = row.get(valueIndx)
  // Without a separator the field is a scalar; otherwise it is split and every element is converted
  val isScalar = splitType == null || splitType == ""

  // Splits the raw value on splitType (a regular expression, as in String.split) and
  // collects the converted elements into a java.util.ArrayList for the JDBC driver
  def toJavaList[A](convert: String => A): util.ArrayList[A] = {
    val ckList = new util.ArrayList[A]()
    for (element <- v.toString.split(splitType)) {
      ckList.add(convert(element))
    }
    ckList
  }

  targetType match {
    case "int"     => if (isScalar) v.toString.toInt else toJavaList(_.toInt)
    case "long"    => if (isScalar) v.toString.toLong else toJavaList(_.toLong)
    case "string"  => if (isScalar) v.toString else toJavaList(s => s)
    // Float elements are parsed with toFloat; toInt would fail on values such as "1.5"
    case "float"   => if (isScalar) v.toString.toFloat else toJavaList(_.toFloat)
    // Scalar booleans are parsed with toBoolean; toFloat would fail on "true" and "false"
    case "boolean" => if (isScalar) v.toString.toBoolean else toJavaList(_.toBoolean)
    case "double"  => if (isScalar) v.toString.toDouble else toJavaList(_.toDouble)
    // Native Spark arrays are unwrapped into a plain array
    case "array"   => v.asInstanceOf[mutable.WrappedArray.ofRef[_]].toArray
    // Unknown target types yield None
    case _         => None
  }
}


def saveToClickHouse[T <: ClickHouseClient](dbName: String = "default", tableName: String,
batchSize: Int = 1)
(implicit client: T): Map[String, Long] = {
val schema = df.schema
// following code is going to be run on executors
val insertResults = df.rdd.mapPartitions((partition: Iterator[org.apache.spark.sql.Row]) => {
using(client.getConnection) { conn =>

val insertSql = client.generateInsertStatement(schema, dbName, tableName)
val statement = conn.prepareStatement(insertSql)
var totalInsert: Long = 0
var counter = 0

while (partition.hasNext) {
counter += 1
val row = partition.next()
schema.foreach { f =>
val fieldName = f.name
val fieldIdx = row.fieldIndex(fieldName)
val fieldVal = row.get(fieldIdx)
if (fieldVal != null)
statement.setObject(fieldIdx + 1, fieldVal)
else {
val defVal = client.defaultNullValue(f.dataType, fieldVal)
statement.setObject(fieldIdx + 1, defVal)
}
}
statement.addBatch()

if (counter >= batchSize) {
val r = statement.executeBatch()
totalInsert += r.sum
counter = 0
}
}

if (counter > 0) {
val r = statement.executeBatch()
totalInsert += r.sum
counter = 0
}

// return: Seq((host, insertCount))
List((client.getTargetHost, totalInsert)).toIterator
}
}).persist

println("rdd_count:" + insertResults.count())

val results = insertResults.collect()
insertResults.unpersist()

// aggregate insert results by hosts
results.groupBy(_._1).map(x => (x._1, x._2.map(_._2).sum))
}
}
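
// Illustrative sketch, not part of the original component: every save method above returns
// Map[host -> inserted row count], which can feed the row-count check of section 1.3
// (rows pushed to ClickHouse = source rows - rejected rows). InsertCountCheck, isConsistent
// and the parameter names are hypothetical; how the rejected rows are counted is left open.
object InsertCountCheck {
  def isConsistent(sourceCount: Long, rejectedCount: Long, insertedByHost: Map[String, Long]): Boolean =
    insertedByHost.values.sum == sourceCount - rejectedCount
}

// Usage example:
//   InsertCountCheck.isConsistent(100L, 5L, Map("h1" -> 60L, "h2" -> 35L))   // true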