Compare commits

...

8 Commits

13 changed files with 198 additions and 25 deletions

View File

@ -102,5 +102,31 @@ public class ModelController {
return OptResult.success(datasetList);
}
/**
 * Returns the training information of one model version.
 *
 * <p>Delegates to {@code modelService.getModelTrainInfo(id)} and wraps whatever
 * the service returns in a success response.
 *
 * @param id id of the model version whose training information is requested;
 *           NOTE(review): no binding annotation is present, so this is presumably
 *           resolved from a request parameter named {@code id} - confirm
 * @return a success result carrying the training information
 */
@Operation(summary = "获取模型训练信息")
@GetMapping("/getModelTrainInfo")
public OptResult getModelTrainInfo(Long id) {
    log.info("获取模型训练信息");
    return OptResult.success(modelService.getModelTrainInfo(id));
}
/**
 * Marks a model version as being in training.
 *
 * <p>The state change itself is performed by {@code modelService.updateModelTrain(id)};
 * this endpoint only logs the call and reports success.
 *
 * @param id id of the model version to switch to the training state
 * @return an empty success result
 */
@Operation(summary = "模型修改成训练中")
@PutMapping("/updateModelTrain")
public OptResult updateModelTrain(Long id) {
    log.info("模型修改成训练中");

    modelService.updateModelTrain(id);

    return OptResult.success();
}
/**
 * Publishes a minor version update for a model.
 *
 * <p>The request body is handed unchanged to
 * {@code modelService.updateModelVersionMinor(dto)}.
 *
 * @param dto model version data deserialized from the request body
 * @return an empty success result
 */
@Operation(summary = "模型更新小版本")
@PutMapping("/updateModelVersionMinor")
public OptResult updateModelVersionMinor(@RequestBody ModelVersionDTO dto) {
    log.info("模型更新小版本");

    modelService.updateModelVersionMinor(dto);

    return OptResult.success();
}
}

View File

@ -1,12 +1,16 @@
package com.bipt.intelligentapplicationorchestrationservice.controller;
import com.bipt.intelligentapplicationorchestrationservice.config.IpConfig;
import com.bipt.intelligentapplicationorchestrationservice.entity.DeployRequest;
import com.bipt.intelligentapplicationorchestrationservice.mapper.ModelMapper;
import com.bipt.intelligentapplicationorchestrationservice.pojo.*;
import com.bipt.intelligentapplicationorchestrationservice.service.ModelDeployer;
import com.bipt.intelligentapplicationorchestrationservice.service.PublishService;
import com.bipt.intelligentapplicationorchestrationservice.util.NacosServiceUtil;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.transaction.annotation.Transactional;
@ -27,17 +31,29 @@ public class PublishController {
@Autowired
private NacosServiceUtil nacosServiceUtil;
@Autowired
private ModelMapper modelMapper;
@Autowired
private IpConfig ipConfig;
@Autowired
private ModelDeployer modelDeployer;
@PostMapping
@Operation(summary ="新增发布请求")
@Transactional
public OptResult<List<ServicePublishVO>> save(@RequestBody ServicePublishDTO servicePublishDTO) {
log.info("模型发布请求:{}", servicePublishDTO);
publishService.save(servicePublishDTO);
//todo 调用模型部署
//调用模型部署
DeployRequest request = new DeployRequest();
Long modelId = servicePublishDTO.getModelId();
ModelVersion modelVersion = modelMapper.selectById(modelId);
String modelConfig = modelVersion.getModelConfig();
//假设modelConfig只存GPU数据
request.setModelId(String.valueOf(modelId));
request.setRequiredMemory(Integer.parseInt(modelConfig));
modelDeployer.deploy(request);
// 获取前端传来的IP字符串
String ipListStr = servicePublishDTO.getIp();
if (ipListStr == null || ipListStr.trim().isEmpty()) {

View File

@ -5,14 +5,14 @@ import com.bipt.intelligentapplicationorchestrationservice.pojo.ModelLogVO;
public interface EvaluationMapper {
/*
* 查询模型评估日志详情
* @param id 模型评估日志id
* @param id 模型版本id
* @return 模型评估日志详情
*/
ModelLogVO selectLogDetail(Long id);
/*
* 更新模型评估日志状态(评估通过则上线)
* @param id 模型评估日志id
* @param id 模型版本id
* @param status 模型评估日志状态
*/
void update(Long id, Integer status);

View File

@ -76,4 +76,11 @@ public interface ModelMapper {
*/
@Select("select dataset_id,dataset_name from dataset")
List<DatasetEntity> listDataset();
/**
 * Fetches the training information of a model version.
 *
 * @param id id of the row in the model version table
 * @return the model training information
 */
ModelTrainInfoVO getModelTrainInfo(Long id);
}

View File

@ -17,7 +17,7 @@ import java.time.LocalDateTime;
@AllArgsConstructor
public class ModelEvaluation implements Serializable {
private Long id; // 评估记录id
private Long modelId; // 关联模型id
private Long modelVersionId; // 关联模型id,后续修改成了模型版本id
private LocalDateTime evaluationTime; // 评估时间
private String evaluationResult; // 评估结果
private String operator; // 评估操作人员

View File

@ -0,0 +1,18 @@
package com.bipt.intelligentapplicationorchestrationservice.pojo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * View object carrying the inputs needed to train one model version.
 *
 * <p>It is the result type of the {@code getModelTrainInfo} mapper query, which
 * joins {@code model_version} with {@code dataset}.
 *
 * <p>Lombok's all-args constructor takes its parameters in field declaration
 * order, so the fields below should not be reordered casually.
 */
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ModelTrainInfoVO {
private Long id; // model version id (model_version.id)
private Integer datasetId; // dataset id
private String modelConfig; // model configuration information
private String dsPath;// mapped from dataset.ds_path, presumably the dataset's storage path (not a version-table id, as the old comment claimed)
private String dataPreHandleFile; // storage path of the data pre-processing file
}

View File

@ -12,6 +12,7 @@ import java.time.LocalDateTime;
@AllArgsConstructor
public class ModelVersionDTO {
private Long id; // 模型版本id
private Long modelId; // 模型id
private String version; // 模型版本
private Integer datasetId; // 数据集id
private String modelConfig; // 模型配置信息

View File

@ -3,8 +3,10 @@ package com.bipt.intelligentapplicationorchestrationservice.service;
import com.bipt.intelligentapplicationorchestrationservice.mapper.GpuResourceDao;
import com.bipt.intelligentapplicationorchestrationservice.exception.CacheInitException;
import com.bipt.intelligentapplicationorchestrationservice.entity.GpuResource;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.annotation.PostConstruct;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.RedisConnectionFailureException;
@ -18,7 +20,6 @@ import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
@Transactional // 添加类级别事务管理
@Component
public class CacheManager {
@Autowired
@ -27,6 +28,9 @@ public class CacheManager {
@Autowired
private GpuResourceDao gpuResourceDao;
@Autowired
private ObjectMapper objectMapper; // 注入ObjectMapper用于类型转换
private final ReentrantLock lock = new ReentrantLock();
@Value("${cache.redis-key-prefix:gpu:}")
@ -38,9 +42,9 @@ public class CacheManager {
@Value("${cache.init-batch-size:500}")
private int initBatchSize;
private static final Logger log = org.slf4j.LoggerFactory.getLogger(CacheManager.class);
private static final Logger log = LoggerFactory.getLogger(CacheManager.class);
// 全量加载(带分页和分布式锁)
@Transactional(propagation = Propagation.REQUIRED) // 方法级别覆盖
@PostConstruct
public void loadFullCache() {
if (tryLock()) {
@ -82,16 +86,12 @@ public class CacheManager {
// 带随机TTL的缓存设置
private void setCacheWithTTL(GpuResource entity) {
String key = buildKey(entity.getGPUId().toString());
GpuResource cached = (GpuResource) redisTemplate.opsForValue().get(key);
// 保留原有内存字段值
if (cached != null && cached.getGPUMemorySize() != null) {
entity.setGPUMemorySize(cached.getGPUMemorySize());
}
// 直接存储实体对象,确保类型一致性
redisTemplate.opsForValue().set(
key,
entity,
ttlBase + (int)(Math.random() * 600), // 随机TTL防止雪崩
ttlBase + (int)(Math.random() * 600),
TimeUnit.SECONDS
);
}
@ -114,6 +114,7 @@ public class CacheManager {
private void unlock() {
lock.unlock();
}
// 分页加载入口
public void loadFullCache(int batchSize) {
int page = 0;
@ -121,12 +122,11 @@ public class CacheManager {
List<GpuResource> batch = gpuResourceDao.findByPage(page * batchSize, batchSize);
if (batch.isEmpty()) break;
batch.forEach(this::refreshWithRetry); // 带重试的刷新逻辑
batch.forEach(this::refreshWithRetry);
page++;
}
}
// 带重试机制的缓存刷新
public void refreshWithRetry(GpuResource entity) {
try {
@ -135,7 +135,7 @@ public class CacheManager {
// 3次重试逻辑
for (int i = 0; i < 3; i++) {
try {
log.info("重试第 {} 次", i + 1); // 添加日志
log.info("重试第 {} 次", i + 1);
Thread.sleep(1000);
setCacheWithTTL(entity);
return;
@ -148,7 +148,6 @@ public class CacheManager {
Thread.currentThread().interrupt();
}
}
}
}
@ -162,8 +161,35 @@ public class CacheManager {
redisTemplate.delete(key);
}
// 修改获取缓存的方法,增加类型安全处理
@SuppressWarnings("unchecked")
public GpuResource getFromCache(String gpuId) {
return (GpuResource) redisTemplate.opsForValue().get("gpu:" + gpuId);
}
String key = buildKey(gpuId);
Object value = redisTemplate.opsForValue().get(key);
}
// 处理可能的类型不匹配问题
if (value == null) {
return null;
}
try {
// 优先尝试直接转换
if (value instanceof GpuResource) {
return (GpuResource) value;
}
// 如果是LinkedHashMap使用ObjectMapper转换
else if (value instanceof java.util.LinkedHashMap) {
return objectMapper.convertValue(value, GpuResource.class);
}
// 其他情况尝试序列化后反序列化适用于JSON存储场景
else {
// 先序列化为JSON字符串再反序列化为对象
String json = objectMapper.writeValueAsString(value);
return objectMapper.readValue(json, GpuResource.class);
}
} catch (Exception e) {
log.error("获取缓存时类型转换失败key: {}, valueType: {}", key, value.getClass().getName(), e);
return null;
}
}
}

View File

@ -72,6 +72,10 @@ public class ModelServiceImpl implements ModelService {
return modelVOList;
}
/**
* 查询模型详情
* @param id
*/
@Override
public ModelVersion detail(Long id) {
log.info("查询模型详情");
@ -79,13 +83,23 @@ public class ModelServiceImpl implements ModelService {
return modelVersion;
}
/**
 * Updates an existing model version.
 *
 * <p>Stamps the DTO with the current time as its update time and passes it to
 * {@code modelMapper.update(dto)}.
 *
 * @param dto the new field values of the model version to update
 */
@Override
public void updateModel(ModelVersionDTO dto) {
    // TODO: the operator also needs to be recorded when a model is updated
    log.info("更新模型");
    // Only the modification time is refreshed. createTime records when the row
    // was first inserted, so an update must not overwrite it with "now".
    dto.setUpdateTime(LocalDateTime.now());
    modelMapper.update(dto);
}
/**
* 删除模型版本
* @param id
*/
@Override
public void deleteModelVersion(Long id) {
log.info("删除模型版本");
@ -149,6 +163,9 @@ public class ModelServiceImpl implements ModelService {
log.info("模型生命周期更新成功,新状态为: {}", targetLifeCycle);
}
/**
* 获取模型生命周期列表
*/
@Override
public List<Map<String, String>> listLifeCycle() {
return Arrays.stream(ModelLifecycle.values())
@ -159,6 +176,9 @@ public class ModelServiceImpl implements ModelService {
.collect(Collectors.toList());
}
/**
* 获取模型数据集列表
*/
@Override
public List<DatasetEntity> listDataset() {
List<DatasetEntity> datasetEntityList = modelMapper.listDataset();
@ -166,4 +186,41 @@ public class ModelServiceImpl implements ModelService {
}
/**
 * Looks up the training information of a model version.
 *
 * @param id id of the model version
 * @return whatever {@code modelMapper.getModelTrainInfo(id)} returns
 */
@Override
public ModelTrainInfoVO getModelTrainInfo(Long id) {
    return modelMapper.getModelTrainInfo(id);
}
/**
 * Starts training for a model version by switching it to the training state.
 *
 * @param id id of the model version whose life cycle is updated
 */
@Override
public void updateModelTrain(Long id) {
    // The mapper stores the database representation of the life-cycle state.
    final ModelLifecycle training = ModelLifecycle.TRAINING;
    modelMapper.updateLifeCycleById(id, training.getDbValue());
}
/**
 * Publishes a minor version of a model.
 *
 * <p>Despite the name, this does not modify an existing row: a new
 * {@code ModelVersion} is built from the DTO and inserted through
 * {@code modelMapper.insertModelVersion}.
 *
 * @param dto source data for the new minor version; its {@code id} is
 *            deliberately not copied, while its {@code modelId} links the new
 *            version to the parent model
 */
@Override
public void updateModelVersionMinor(ModelVersionDTO dto) {
    ModelVersion minorVersion = new ModelVersion();
    // Skip "id" so the DTO's id is not carried onto the new row; "modelId" is set explicitly below.
    BeanUtils.copyProperties(dto, minorVersion, "id", "modelId");
    minorVersion.setModelId(dto.getModelId());

    // Read the clock once so a freshly inserted row has createTime == updateTime.
    LocalDateTime now = LocalDateTime.now();
    minorVersion.setCreateTime(now);
    minorVersion.setUpdateTime(now);

    // TODO: replace this placeholder with the real operator once it is available
    minorVersion.setOperateUser("zs");

    modelMapper.insertModelVersion(minorVersion);
}
}

View File

@ -23,4 +23,10 @@ public interface ModelService {
List<Map<String, String>> listLifeCycle();
List<DatasetEntity> listDataset();
/**
 * Returns the training information of a model version.
 *
 * @param id id of the model version
 * @return the training information for that version
 */
ModelTrainInfoVO getModelTrainInfo(Long id);
/**
 * Switches a model version to the training state.
 *
 * @param id id of the model version
 */
void updateModelTrain(Long id);
/**
 * Publishes a minor version of a model based on the given data.
 *
 * @param dto data of the minor version
 */
void updateModelVersionMinor(ModelVersionDTO dto);
}

View File

@ -0,0 +1,5 @@
# Alibaba Cloud OSS configuration
aliyun.oss.endpoint=oss-cn-beijing.aliyuncs.com
aliyun.oss.bucketName=ipz-nh
# Credentials must not live in version control. They are resolved from
# environment variables (assumes this file is loaded through Spring's
# Environment, which expands ${...} placeholders - confirm).
# SECURITY: the key pair that was previously committed here is exposed in the
# repository history and must be revoked and rotated.
aliyun.oss.accessKeyId=${ALIYUN_OSS_ACCESS_KEY_ID}
aliyun.oss.accessKeySecret=${ALIYUN_OSS_ACCESS_KEY_SECRET}

View File

@ -9,7 +9,7 @@
from model_log m1,
model_info m2,
model_version m3
where m1.model_id=m2.id and m3.model_id=m2.id and m1.model_id = #{id}
where m1.model_version_id=m3.id and m3.model_id=m2.id and m1.model_version_id = #{id}
</select>
<!--更新模型信息(目前只更新模型是否上线,后续如果更多需求可优化>-->

View File

@ -31,8 +31,8 @@
<!--查询模型详细信息-->
<select id="selectById" resultType="com.bipt.intelligentapplicationorchestrationservice.pojo.ModelVersion">
SELECT
t1.model_name,
t2.version, t2.dataset_id, t2.model_config,
t1.model_name, t1.id modelId,
t2.version, t2.dataset_id, t2.model_config, t2.id,
t2.model_path, t2.status, t2.create_time, t2.update_time, t2.model_size,
t2.data_pre_handle_file, t2.model_super_args, t2.model_args_size, t2.model_source_code_url, t2.model_file,
t2.model_design_document, t2.life_cycle, t2.operate_user
@ -62,4 +62,15 @@
</set>
WHERE id = #{id}
</update>
<!-- Fetch the training inputs of one model version together with the path of its dataset.
     The inner join means a version whose dataset_id has no matching dataset row yields no result. -->
<select id="getModelTrainInfo" resultType="com.bipt.intelligentapplicationorchestrationservice.pojo.ModelTrainInfoVO">
    SELECT mv.id,
           mv.dataset_id,
           mv.model_config,
           mv.data_pre_handle_file,
           ds.ds_path
    FROM model_version mv
    INNER JOIN dataset ds ON mv.dataset_id = ds.dataset_id
    WHERE mv.id = #{id}
</select>
</mapper>