说明
- 最近工作中遇到了要将大量数据入库的场景。查阅了一些资料,最终利用juc的
CyclicBarrier
自己实现了一版。脱敏后将代码和大家进行分享。大家有更好的思路欢迎进行探讨。 - 我这里为了充分利用服务器资源,采用多线程+批量提交的方式来实现。但是因为
@Transactional
在多线程下是不生效的.所以需要自己控制子线程的事务提交与回滚。 - 这里并不局限于对mysql的插入.可以看到将数据入库的操作放到了回调函数中.这意味着这是一个将数据切割成小块并进行插入的逻辑。那么只要符合这个基本要求都是可以插入的。因为时间有限并没有做封装,后续有时间后会封装一下上传到gitee中。
- 需要注意的是,我这里数据库是安装在本地的。在实际情况中,受限于网络传输以及IO速度限制。实际执行时间应该会长于实验出来的时间。尽管如此,11秒的时间插入100万条数据,速度依旧还是杠杠的。
配置信息
代码
数据库结构
- 这里采用了bigint数据类型,且在插入数据的时候由客户端进行生成。
- 如果使用其他类型需要保证单调递增,否则会造成频繁页分裂拖慢运行速度。
-- t_equipment: one row per metering device (0 electricity / 1 water / 2 gas).
-- `id` is generated client-side and must be monotonically increasing to avoid
-- frequent B+-tree page splits during bulk insert.
-- NOTE: integer display widths (e.g. bigint(32)) are deprecated since MySQL 8.0.17
-- and have no storage effect, so they are dropped here.
CREATE TABLE IF NOT EXISTS `sql_test`.`t_equipment` (
  `id` bigint NOT NULL COMMENT 'id',
  `create_time` datetime(0) NULL DEFAULT NULL COMMENT '创建时间',
  `create_by` char(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '创建人',
  `update_time` datetime(0) NULL DEFAULT NULL COMMENT '更新时间',
  `update_by` char(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '更新人',
  `del_flag` char(1) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '数据有效性(0:正常,2:删除)',
  `remark` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '备注',
  `equipment_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '设备名称',
  `room_id` bigint NULL DEFAULT NULL COMMENT '房间ID',
  `equipment_type` char(11) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '0电表1水表2气表',
  PRIMARY KEY (`id`)
);
实体类
// NOTE(review): package fixed from "com.ruoyi.equipment.domain.TEquipment" — callers
// import this class as com.ruoyi.equipment.domain.TEquipment, so the class must live
// in package com.ruoyi.equipment.domain. Also removed a stray trailing '}' that made
// the original file unparseable, and added "implements Serializable" so the declared
// serialVersionUID actually has an effect.
package com.ruoyi.equipment.domain;

import com.fasterxml.jackson.annotation.JsonFormat;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;

import java.io.Serializable;
import java.util.Date;

/**
 * Equipment entity mapped to table t_equipment.
 */
public class TEquipment implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Primary key; generated client-side. */
    private Long id;

    /** Record validity flag (0: normal, 2: deleted). */
    private String delFlag;

    /** Equipment name. */
    private String equipmentName;

    /** Room id. */
    private Long roomId;

    /** Equipment type: 0 electricity meter, 1 water meter, 2 gas meter. */
    private String equipmentType;

    /** Creator. */
    private String createBy;

    /** Creation time. */
    @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss")
    private Date createTime;

    /** Updater. */
    private String updateBy;

    /** Update time. */
    @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss")
    private Date updateTime;

    /** Remark. */
    private String remark;

    public void setId(Long id) {
        this.id = id;
    }

    public Long getId() {
        return id;
    }

    public void setDelFlag(String delFlag) {
        this.delFlag = delFlag;
    }

    public String getDelFlag() {
        return delFlag;
    }

    public void setEquipmentName(String equipmentName) {
        this.equipmentName = equipmentName;
    }

    public String getEquipmentName() {
        return equipmentName;
    }

    public void setRoomId(Long roomId) {
        this.roomId = roomId;
    }

    public Long getRoomId() {
        return roomId;
    }

    public void setEquipmentType(String equipmentType) {
        this.equipmentType = equipmentType;
    }

    public String getEquipmentType() {
        return equipmentType;
    }

    public String getCreateBy() {
        return createBy;
    }

    public void setCreateBy(String createBy) {
        this.createBy = createBy;
    }

    public Date getCreateTime() {
        return createTime;
    }

    public void setCreateTime(Date createTime) {
        this.createTime = createTime;
    }

    public String getUpdateBy() {
        return updateBy;
    }

    public void setUpdateBy(String updateBy) {
        this.updateBy = updateBy;
    }

    public Date getUpdateTime() {
        return updateTime;
    }

    public void setUpdateTime(Date updateTime) {
        this.updateTime = updateTime;
    }

    public String getRemark() {
        return remark;
    }

    public void setRemark(String remark) {
        this.remark = remark;
    }

    @Override
    public String toString() {
        return new ToStringBuilder(this, ToStringStyle.MULTI_LINE_STYLE)
                .append("id", getId())
                .append("createTime", getCreateTime())
                .append("createBy", getCreateBy())
                .append("updateTime", getUpdateTime())
                .append("updateBy", getUpdateBy())
                .append("delFlag", getDelFlag())
                .append("remark", getRemark())
                .append("equipmentName", getEquipmentName())
                .append("roomId", getRoomId())
                .append("equipmentType", getEquipmentType())
                .toString();
    }
}
mapper.java
/**
 * Equipment mapper interface.
 *
 * @author ruoyi
 * @date 2025-02-13
 */
public interface TEquipmentMapper
{
/**
 * Batch-inserts equipment rows via a single multi-row INSERT.
 * (Original javadoc wrongly described this as a query.)
 *
 * @param dataList rows to insert
 * @return affected row count
 *         NOTE(review): returns Integer while insertTEquipment returns int — presumably
 *         historical; confirm before unifying.
 */
public Integer insertBatchTEquipment(List<TEquipment> dataList);
/**
 * Inserts a single equipment row.
 *
 * @param tEquipment equipment row
 * @return affected row count
 */
public int insertTEquipment(TEquipment tEquipment);
}
mapper.xml
<!-- Multi-row INSERT: one VALUES tuple per list element, so each call inserts the
     whole page in a single statement.
     NOTE(review): collection="list" relies on MyBatis's default name for a single
     un-annotated List parameter; confirm the mapper method carries no @Param. -->
<insert id="insertBatchTEquipment">
insert into t_equipment
(id,equipment_name,room_id,equipment_type,del_flag,create_time)
values
<foreach collection="list" item="item" separator="," >
(#{item.id},#{item.equipmentName}, #{item.roomId} ,#{item.equipmentType},#{item.delFlag},#{item.createTime})
</foreach>
</insert>
<!-- Dynamic single-row INSERT: only non-null properties are written. The two trim
     blocks must list columns and value placeholders in exactly the same order with
     identical null tests, or columns and values will pair up wrongly. -->
<insert id="insertTEquipment" parameterType="com.ruoyi.equipment.domain.TEquipment" useGeneratedKeys="true" keyProperty="id">
insert into t_equipment
<trim prefix="(" suffix=")" suffixOverrides=",">
<if test="createTime != null">create_time,</if>
<if test="createBy != null">create_by,</if>
<if test="updateTime != null">update_time,</if>
<if test="updateBy != null">update_by,</if>
<if test="delFlag != null">del_flag,</if>
<if test="remark != null">remark,</if>
<if test="equipmentName != null">equipment_name,</if>
<if test="roomId != null">room_id,</if>
<if test="equipmentType != null">equipment_type,</if>
</trim>
<trim prefix="values (" suffix=")" suffixOverrides=",">
<if test="createTime != null">#{createTime},</if>
<if test="createBy != null">#{createBy},</if>
<if test="updateTime != null">#{updateTime},</if>
<if test="updateBy != null">#{updateBy},</if>
<if test="delFlag != null">#{delFlag},</if>
<if test="remark != null">#{remark},</if>
<if test="equipmentName != null">#{equipmentName},</if>
<if test="roomId != null">#{roomId},</if>
<if test="equipmentType != null">#{equipmentType},</if>
</trim>
</insert>
application-dev.yaml
- 这里为了方便后续做测试,需要将
maxWait
配置为60000,以及将maxActive
、initial-size
配置为3000。
# NOTE(review): the original paste had lost all YAML indentation, which makes the
# document structurally invalid; nesting restored from the key semantics
# (druid master datasource + redis, the standard RuoYi layout) — confirm against
# the project's original file.
server:
  port: 18080

spring:
  datasource:
    type: com.alibaba.druid.pool.DruidDataSource
    driverClassName: com.mysql.cj.jdbc.Driver
    druid:
      # primary datasource
      master:
        url: jdbc:mysql://localhost:3306/sql_test?rewriteBatchedStatements=true&useUnicode=true&characterEncoding=utf8&zeroDateTimeBehavior=convertToNull&useSSL=true&serverTimezone=GMT%2B8
        username: root
        password: root
        # Pool sized so every worker thread of the experiment gets its own connection.
        # 3000 connections is for this benchmark ONLY — never ship a pool this large.
        initial-size: 3000
        min-idle: 50
        maxActive: 3000
        maxWait: 60000
  redis:
    # address
    host: localhost
    port: 6379
    # database index
    database: 6
    # password (empty = none)
    password:
配置mysql的最大连接数
- 首先查询mysql的最大连接数
-- check the current limit first
SHOW VARIABLES LIKE 'max_connections';
# temporary change; lost after a server restart
SET GLOBAL max_connections = 3000;
# permanent change: edit the config file (linux: my.cnf, windows: my.ini)
[mysqld]
max_connections = 3000
# then restart the mysql service
sudo systemctl restart mysql
最终效果
核心业务类
说明
- 这里主要借助于juc的
CyclicBarrier
进行实现.来用于将线程池中的所有子线程进行协调实现全局回滚以及提交。 - 借助于
CountDownLatch
实现主线程与所有子线程的协调。
package com.ruoyi.test_sql;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
@Service
public class TEquipmentBatchService {

    protected final Logger log = LoggerFactory.getLogger(TEquipmentBatchService.class);

    @Autowired
    private PlatformTransactionManager transactionManager;

    /**
     * Splits {@code dataList} into pages of at most {@code pageSize} rows and persists every
     * page on its own pool thread inside its own transaction. A {@link CyclicBarrier} parks
     * each worker after its insert until every page has been attempted, then all workers
     * commit — or, if any page failed, all workers roll back — giving all-or-nothing
     * semantics across the independent per-thread transactions (needed because
     * {@code @Transactional} does not propagate to pool threads). The calling thread blocks
     * until every worker has finished.
     *
     * @param threadPoolTaskExecutor pool whose core size must be at least the number of pages,
     *                               because all workers are parked on the barrier at once
     * @param dataList               source rows
     * @param pageSize               rows per page / per transaction
     * @param dataConsumerMapping    callback that persists one page (e.g. a batch-insert mapper)
     * @param <T>                    row type
     * @throws InterruptedException  declared for backward compatibility; interruption is
     *                               handled internally and re-flagged on the current thread
     * @throws IllegalStateException if the pool's core size is smaller than the page count
     */
    public <T> void splitDataBatchInsert(ThreadPoolTaskExecutor threadPoolTaskExecutor, List<T> dataList, int pageSize, Consumer<List<T>> dataConsumerMapping) throws InterruptedException {
        doSplitBatchInsert(threadPoolTaskExecutor, dataList, pageSize, dataConsumerMapping);
    }

    /**
     * Same contract as {@link #splitDataBatchInsert}; kept as a separate public entry point
     * for backward compatibility. The original implementation of this variant returned from
     * the rollback branch WITHOUT counting down the latch, so a single failed page left the
     * calling thread blocked forever; both entry points now share the fixed implementation.
     *
     * @param threadPoolTaskExecutor pool whose core size must be at least the number of pages
     * @param dataList               source rows
     * @param pageSize               rows per page / per transaction
     * @param dataMapper             callback that persists one page
     * @param <T>                    row type
     */
    public <T> void splitDataListBatchInsert(ThreadPoolTaskExecutor threadPoolTaskExecutor,
                                             List<T> dataList,
                                             int pageSize,
                                             Consumer<List<T>> dataMapper) {
        doSplitBatchInsert(threadPoolTaskExecutor, dataList, pageSize, dataMapper);
    }

    /**
     * Shared implementation behind both public entry points.
     */
    private <T> void doSplitBatchInsert(ThreadPoolTaskExecutor executor,
                                        List<T> dataList,
                                        int pageSize,
                                        Consumer<List<T>> pageConsumer) {
        int corePoolSize = executor.getCorePoolSize();
        List<List<T>> pages = splitList(dataList, pageSize);
        int threadNum = pages.size();
        if (threadNum == 0) {
            // Nothing to insert. Also avoids new CyclicBarrier(0), which throws
            // IllegalArgumentException (parties must be > 0).
            return;
        }
        // Every worker stays parked on the barrier until the last page finishes, so the
        // pool must be able to run all of them concurrently.
        if (corePoolSize < threadNum) {
            throw new IllegalStateException("线程池核心数小于切割所需数目!");
        }
        // Flips to false as soon as any page fails; decides commit vs rollback for ALL workers.
        AtomicBoolean isSuccess = new AtomicBoolean(true);
        // Lets the calling thread wait for every worker, success or failure.
        CountDownLatch countDownLatch = new CountDownLatch(threadNum);
        // Trips once all workers have attempted their page; the barrier action only logs
        // the global decision.
        CyclicBarrier barrier = new CyclicBarrier(threadNum, () -> {
            if (!isSuccess.get()) {
                log.info("发起全局回滚");
                return;
            }
            log.info("发起全局提交");
        });
        for (List<T> page : pages) {
            executor.execute(() -> {
                // One independent transaction per worker; finished below after the barrier trips.
                TransactionStatus status = transactionManager.getTransaction(new DefaultTransactionDefinition());
                try {
                    try {
                        pageConsumer.accept(page);
                    } catch (Exception e) {
                        log.error("批量插入子线程执行失败", e);
                        isSuccess.set(false);
                    } finally {
                        try {
                            // Park until every page has been attempted. A broken barrier
                            // (e.g. another worker interrupted) also forces global rollback.
                            barrier.await();
                        } catch (Exception e) {
                            log.error("等待屏障失败", e);
                            isSuccess.set(false);
                        }
                    }
                    if (!isSuccess.get()) {
                        transactionCommitOrRollback(status, transactionManager::rollback);
                    } else {
                        transactionCommitOrRollback(status, transactionManager::commit);
                    }
                } finally {
                    // Always release the caller — even if commit/rollback itself throws —
                    // otherwise the calling thread would block forever on the latch.
                    countDownLatch.countDown();
                }
            });
        }
        try {
            countDownLatch.await();
            log.info("主线程停止阻塞!");
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of swallowing it.
            Thread.currentThread().interrupt();
            log.error("等待子线程完成时被中断", e);
        }
    }

    /**
     * Commits or rolls back the given transaction unless it has already completed.
     *
     * @param transactionStatus transaction to finish
     * @param consumer          transactionManager::commit or transactionManager::rollback
     */
    public void transactionCommitOrRollback(TransactionStatus transactionStatus, Consumer<TransactionStatus> consumer) {
        if (!transactionStatus.isCompleted()) {
            consumer.accept(transactionStatus);
        }
    }

    /**
     * Splits the source list into consecutive chunks of at most {@code pageSize} elements.
     * The chunks are subList views backed by {@code dataList}, so the source list must not
     * be structurally modified while the chunks are in use.
     *
     * @param dataList source list
     * @param pageSize maximum chunk size
     * @param <T>      element type
     * @return chunks in source order; empty when {@code dataList} is empty
     */
    public <T> List<List<T>> splitList(List<T> dataList, int pageSize) {
        int count = dataList.size();
        // ceil(count / pageSize) chunks
        int chunkCount = count % pageSize == 0 ? count / pageSize : count / pageSize + 1;
        List<List<T>> splitDataList = new ArrayList<>();
        for (int i = 0; i < chunkCount; i++) {
            int startIndex = i * pageSize;
            int endIndex = Math.min(count, (i + 1) * pageSize);
            splitDataList.add(dataList.subList(startIndex, endIndex));
        }
        return splitDataList;
    }
}
测试类
package com.ruoyi.test_sql;
import com.ruoyi.common.utils.DateUtils;
import com.ruoyi.equipment.domain.TEquipment;
import com.ruoyi.equipment.mapper.TEquipmentMapper;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@SpringBootTest
public class TEquipmentBatchServiceTest {

    protected final Logger logger = LoggerFactory.getLogger(TEquipmentBatchServiceTest.class);

    @Autowired
    private TEquipmentBatchService tequipmentBatchInsert;

    @Autowired
    private TEquipmentMapper tEquipmentMapper;

    // Rows per page / per worker transaction.
    private int pageSize = 3500;

    // Total number of simulated rows.
    private int totalCount = 1000000;

    // Single shared RNG. The original allocated a fresh "new Random()" per generated
    // row (one million instances) and per throwException() call; java.util.Random is
    // thread-safe, so one instance serves both the generator and the worker threads.
    private final Random random = new Random();

    /**
     * End-to-end test: generates totalCount rows, splits them into pages of pageSize and
     * batch-inserts them concurrently. throwException() randomly fails ~10% of pages to
     * exercise the global-rollback path; comment that line out to measure raw insert speed.
     */
    @Test
    void testSplitDataBatchInsert() throws InterruptedException {
        long startTime = System.currentTimeMillis();
        // simulated data
        List<TEquipment> dataList = generateDataList();
        ThreadPoolTaskExecutor threadPoolTaskExecutor = getThreadPoolTaskExecutor();
        tequipmentBatchInsert.splitDataBatchInsert(threadPoolTaskExecutor, dataList, pageSize, page -> {
            tEquipmentMapper.insertBatchTEquipment(page);
            throwException();
        });
        long endTime = System.currentTimeMillis();
        logger.info("用时--------{}", endTime - startTime);
    }

    /**
     * Builds a pool whose core size equals the number of pages, because every worker is
     * parked on the service's barrier until the last page has been attempted.
     *
     * @return ThreadPoolTaskExecutor
     */
    public ThreadPoolTaskExecutor getThreadPoolTaskExecutor() {
        int corePoolSize = calCorePoolSize(totalCount, pageSize);
        ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
        threadPoolTaskExecutor.setCorePoolSize(corePoolSize);
        threadPoolTaskExecutor.setMaxPoolSize(corePoolSize);
        threadPoolTaskExecutor.setQueueCapacity(100);
        threadPoolTaskExecutor.setThreadNamePrefix("batch-insert-thread-");
        threadPoolTaskExecutor.initialize();
        return threadPoolTaskExecutor;
    }

    /**
     * Number of worker threads needed: ceil(totalCount / pageSize).
     *
     * @param totalCount total row count
     * @param pageSize   rows per page
     * @return required core thread count
     */
    public int calCorePoolSize(int totalCount, int pageSize) {
        return totalCount % pageSize == 0 ? totalCount / pageSize : totalCount / pageSize + 1;
    }

    /**
     * Generates totalCount simulated rows with monotonically increasing ids.
     *
     * @return simulated data list
     */
    public List<TEquipment> generateDataList() {
        return Stream.iterate(0, i -> i + 1).limit(totalCount)
                .map(this::genderateTequipment)
                .collect(Collectors.toList());
    }

    /**
     * Builds one simulated TEquipment row.
     * NOTE(review): method name carries a typo ("genderate" for "generate"); kept as-is
     * to avoid changing the public interface.
     *
     * @param i seed value, used as the monotonically increasing primary key
     * @return TEquipment
     */
    public TEquipment genderateTequipment(Integer i) {
        TEquipment tEquipment = new TEquipment();
        tEquipment.setId(Long.valueOf(i));
        tEquipment.setEquipmentName("equ:" + i);
        tEquipment.setDelFlag("0");
        tEquipment.setCreateTime(DateUtils.getNowDate());
        tEquipment.setRoomId((long) random.nextInt(10));
        return tEquipment;
    }

    /**
     * Simulates a failed page: throws with probability 1/10.
     */
    public void throwException() {
        if (random.nextInt(10) == 1) {
            throw new RuntimeException("batch insert error");
        }
    }
}
测试说明
- 模拟异常时,将这行代码打开即可.
- 单纯测试插入速度注释掉这行即可
效果
无错误发生
- 可以达到11秒的速度插入完毕。
- 我这里使用的参数是
100 0000
,一次插入3500
条(与代码中的 pageSize 保持一致)。并不是每个线程处理的数据越多越好,大家可根据自己设备以及业务情况自行调整。
模拟出现异常