diff --git a/pom.xml b/pom.xml index f1094d8..957533f 100644 --- a/pom.xml +++ b/pom.xml @@ -185,7 +185,12 @@ jaxb-runtime 2.3.3 - + + junit + junit + test + + diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/AlgorithmInfoController.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/AlgorithmInfoController.java index aee19fe..9d4c466 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/AlgorithmInfoController.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/AlgorithmInfoController.java @@ -14,7 +14,10 @@ import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; import java.io.IOException; +import java.util.ArrayList; import java.util.List; +import java.util.Map; + @Tag(name ="算法创建相关接口") @RestController @RequestMapping("/api/algorithm") @@ -31,16 +34,16 @@ public class AlgorithmInfoController { @GetMapping("/{id}") public ResponseEntity getById(@PathVariable Long id) { AlgorithmInfo algorithmInfo = algorithmInfoService.getById(id); - return algorithmInfo != null ? - ResponseEntity.ok(algorithmInfo) : + return algorithmInfo != null ? + ResponseEntity.ok(algorithmInfo) : ResponseEntity.notFound().build(); } @GetMapping("/name/{algorithmName}") public ResponseEntity getByName(@PathVariable String algorithmName) { AlgorithmInfo algorithmInfo = algorithmInfoService.getByName(algorithmName); - return algorithmInfo != null ? - ResponseEntity.ok(algorithmInfo) : + return algorithmInfo != null ? + ResponseEntity.ok(algorithmInfo) : ResponseEntity.notFound().build(); } @@ -56,18 +59,18 @@ public class AlgorithmInfoController { if (!algorithmInfoService.validateAlgorithmInfo(algorithmInfo)) { return ResponseEntity.badRequest().body("Invalid algorithm information"); } - + boolean success = algorithmInfoService.update(algorithmInfo); - return success ? - ResponseEntity.ok("Update successful") : + return success ? + ResponseEntity.ok("Update successful") : ResponseEntity.badRequest().body("Update failed"); } @DeleteMapping("/{id}") public ResponseEntity delete(@PathVariable Long id) { boolean success = algorithmInfoService.delete(id); - return success ? - ResponseEntity.ok("Delete successful") : + return success ? + ResponseEntity.ok("Delete successful") : ResponseEntity.badRequest().body("Delete failed"); } @@ -103,11 +106,37 @@ public class AlgorithmInfoController { * 算法运行 */ @PostMapping("/run/{id}") - @Operation(summary = "运行") - public OptResult run(@PathVariable Long id,@RequestBody String param){ - log.info("运行",id); - String result = algorithmInfoService.run(id,param); - return OptResult.success("运行成功"+result); + @Operation(summary = "运行算法") + public OptResult run(@PathVariable Long id, @RequestBody String param) { + log.info("运行算法 ID: {}", id); + try { + AlgorithmInfo algorithm = algorithmInfoService.getById(id); + if (algorithm == null) { + return OptResult.error("算法不存在"); + } + + // 1. 解析前端传入的参数(JSON格式) + Map paramMap = objectMapper.readValue(param, Map.class); + + // 2. 从参数中提取实际需要传递给Python脚本的参数列表 + // 示例:假设前端传入 {"args": [3, 0, 8, 7, 2, 1, 9, 4]} + List args = new ArrayList<>(); + if (paramMap.containsKey("args")) { + List argList = (List) paramMap.get("args"); + for (Object arg : argList) { + args.add(arg.toString()); + } + } + + // 3. 调用Service执行Python脚本并获取结果 + String result = algorithmInfoService.run(algorithm.getFilePath(), args); + + // 4. 返回结构化结果 + return OptResult.success("运行结果"+result); + } catch (Exception e) { + log.error("算法运行失败", e); + return OptResult.error("算法运行失败: " + e.getMessage()); + } } /** * 前端列表返回算法名称 @@ -118,4 +147,4 @@ public class AlgorithmInfoController { return algorithmInfoService.getAllNames(); } -} \ No newline at end of file +} \ No newline at end of file diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/PublishController.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/PublishController.java index 8ad2089..468754b 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/PublishController.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/PublishController.java @@ -3,7 +3,9 @@ package com.bipt.intelligentapplicationorchestrationservice.controller; import com.bipt.intelligentapplicationorchestrationservice.config.IpConfig; import com.bipt.intelligentapplicationorchestrationservice.entity.DeployRequest; import com.bipt.intelligentapplicationorchestrationservice.entity.ModelSelectVO; +import com.bipt.intelligentapplicationorchestrationservice.enumeration.ServiceStatus; import com.bipt.intelligentapplicationorchestrationservice.mapper.ModelMapper; +import com.bipt.intelligentapplicationorchestrationservice.mapper.PublishMapper; import com.bipt.intelligentapplicationorchestrationservice.pojo.*; import com.bipt.intelligentapplicationorchestrationservice.service.ModelDeployer; import com.bipt.intelligentapplicationorchestrationservice.service.PublishService; @@ -39,20 +41,26 @@ public class PublishController { @Autowired private ModelDeployer modelDeployer; + @Autowired + private PublishMapper publishMapper; @PostMapping @Operation(summary ="新增发布请求") @Transactional public OptResult> save(@RequestBody ServicePublishDTO servicePublishDTO) { log.info("模型发布请求:{}", servicePublishDTO); + Long id = servicePublishDTO.getModelId(); + Long ModelId = publishService.getModelId(id); + servicePublishDTO.setModelId(ModelId); + servicePublishDTO.setStatus(ServiceStatus.ONLINE.getCode()); publishService.save(servicePublishDTO); //调用模型部署 DeployRequest request = new DeployRequest(); - Long modelId = servicePublishDTO.getModelId(); - ModelVersion modelVersion = modelMapper.selectByModelId(modelId); +/* Long modelId = servicePublishDTO.getModelId();*/ + ModelVersion modelVersion = publishMapper.selectByModelVersionId(id); String modelConfig = modelVersion.getModelConfig(); //假设modelConfig只存GPU数据 - request.setModelId(String.valueOf(modelId)); + request.setModelId(String.valueOf(ModelId)); request.setRequiredMemory(Integer.parseInt(modelConfig)); modelDeployer.deploy(request); // 获取前端传来的IP字符串 @@ -108,10 +116,143 @@ public class PublishController { log.info("返回列表;{}",ips); return OptResult.success(ips); } + + @GetMapping("/config/ids") - public OptResult> getModelNames(){ - List modelSelectVOS = publishService.getModelNames(); - log.info("获取到模型列表:{}",modelSelectVOS); + public OptResult> getModelNames() { + // 只获取状态为“在线”的模型列表(筛掉已下线的服务) + List modelSelectVOS = publishService.getOnlineModelNames(); + /*List modelNames = publishService.getModelNames();*/ + + log.info("获取到在线模型列表:{}", modelSelectVOS); return OptResult.success(modelSelectVOS); } + @PostMapping("/online/{serviceId}") + @Operation(summary = "上线已下线的服务") + @Transactional + public OptResult onlineService(@PathVariable Long serviceId) { + log.info("上线服务请求: {}", serviceId); + + // 1. 从数据库获取服务信息,验证状态 + ServicePublishVO service = publishService.getServiceById(serviceId); + if (service == null) { + return OptResult.error("服务不存在"); + } + if (service.getStatus() == ServiceStatus.ONLINE.getCode()) { + return OptResult.error("服务已处于上线状态"); + } + if (service.getStatus() != ServiceStatus.OFFLINE.getCode()) { + return OptResult.error("服务当前状态不支持上线操作"); + } + + // 2. 调用Nacos重新注册服务 + try { + String[] ipArray = service.getIp().split(","); + for (String ip : ipArray) { + String trimmedIp = ip.trim(); + if (!trimmedIp.isEmpty()) { + nacosServiceUtil.registerService( + service.getModelId().toString(), + trimmedIp, + 8080, + service.getApiUrl() + ); + log.info("Nacos服务重新注册成功: {}", trimmedIp); + } + } + } catch (Exception e) { + log.error("Nacos服务注册失败", e); + return OptResult.error("Nacos服务注册失败"); + } + + // 3. 更新数据库状态为“在线” + ServicePublishDTO updateDto = new ServicePublishDTO(); + BeanUtils.copyProperties(service, updateDto); + updateDto.setStatus(ServiceStatus.ONLINE.getCode()); // 假设ONLINE状态码为1 + publishService.updateServiceStatus(updateDto); + + return OptResult.success("服务上线成功"); + } + // 新增:服务下线接口 + @DeleteMapping("/{serviceId}") + @Operation(summary = "下线已发布的服务") + @Transactional + public OptResult offlineService(@PathVariable Long serviceId) { + log.info("下线服务请求: {}", serviceId); + + // 1. 从数据库获取服务信息 + ServicePublishVO service = publishService.getServiceById(serviceId); + if (service == null) { + return OptResult.error("服务不存在"); + } + + // 2. 调用 Nacos 下线服务 + try { + String[] ipArray = service.getIp().split(","); + for (String ip : ipArray) { + String trimmedIp = ip.trim(); + if (!trimmedIp.isEmpty()) { + nacosServiceUtil.deregisterService( + service.getModelId().toString(), + trimmedIp, + 8080 + ); + log.info("Nacos服务下线成功: {}", trimmedIp); + } + } + } catch (Exception e) { + log.error("Nacos服务下线失败", e); + return OptResult.error("Nacos服务下线失败"); + } + + // 3. 修改数据库记录状态为0(下线) + ServicePublishDTO updateDto = new ServicePublishDTO(); + BeanUtils.copyProperties(service, updateDto); + updateDto.setStatus(ServiceStatus.OFFLINE.getCode()); // 假设OFFLINE状态码为0 + publishService.updateServiceStatus(updateDto); + + return OptResult.success("服务下线成功"); + } + + // 新增:服务状态同步接口 + @GetMapping("/sync") + @Operation(summary = "同步服务状态") + public OptResult syncServiceStatus() { + log.info("开始同步服务状态..."); + + try { + // 1. 获取数据库中所有已上线的服务 + List dbServices = publishService.listPublishedServicesByStatus(ServiceStatus.ONLINE.getCode()); + + // 2. 遍历每个服务,检查 Nacos 注册状态 + for (ServicePublishVO service : dbServices) { + String serviceName = service.getModelId().toString(); + String[] ipArray = service.getIp().split(","); + + // 获取 Nacos 中注册的实例 + List nacosInstances = nacosServiceUtil.getServiceInstances(serviceName); + + // 检查每个 IP 是否都在 Nacos 中注册 + for (String ip : ipArray) { + String trimmedIp = ip.trim(); + if (!trimmedIp.isEmpty() && !nacosInstances.contains(trimmedIp)) { + // 如果数据库中有但 Nacos 中没有,则重新注册 + nacosServiceUtil.registerService( + serviceName, + trimmedIp, + 8080, + service.getApiUrl() + ); + log.info("重新注册服务到 Nacos: {}", trimmedIp); + } + } + } + + log.info("服务状态同步完成"); + return OptResult.success("服务状态同步完成"); + } catch (Exception e) { + log.error("服务状态同步失败", e); + return OptResult.error("服务状态同步失败"); + } + } } diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/ServiceAPIController.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/ServiceAPIController.java index 468ee08..9332634 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/ServiceAPIController.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/controller/ServiceAPIController.java @@ -1,20 +1,15 @@ package com.bipt.intelligentapplicationorchestrationservice.controller; - import com.bipt.intelligentapplicationorchestrationservice.pojo.OptResult; import com.bipt.intelligentapplicationorchestrationservice.service.ServiceAPIService; import com.bipt.intelligentapplicationorchestrationservice.util.NacosServiceUtil; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; -import io.swagger.v3.oas.models.security.SecurityScheme; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.transaction.annotation.Transactional; -import org.springframework.web.bind.annotation.PathVariable; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.*; import java.util.List; import java.util.Set; @@ -34,12 +29,18 @@ public class ServiceAPIController { @Autowired private RedisTemplate redisTemplate; + @PostMapping("/release") @Operation(summary = "结束访问") @Transactional public OptResult releaseResource(@PathVariable Long modelId) { String key = "modelId:" + modelId; String modelConfig = (String) redisTemplate.opsForValue().get(key); + if (modelConfig == null) { + log.warn("尝试释放不存在的模型资源: {}", modelId); + return OptResult.error("模型资源不存在"); + } + int userMemorySize = parseGpuMemorySize(modelConfig); List instanceIps; try { @@ -48,18 +49,33 @@ public class ServiceAPIController { log.error("获取Nacos实例失败", e); return OptResult.error("获取实例失败"); } - int memorySize; + + boolean released = false; for (String ip : instanceIps) { String ipKey = "ip:" + ip; Integer nowMemorySizeOBJ = (Integer) redisTemplate.opsForValue().get(ipKey); + + // 如果该IP没有记录,则跳过(可能资源分配记录已过期) + if (nowMemorySizeOBJ == null) { + log.warn("IP {} 的资源记录不存在,可能已过期", ip); + continue; + } + int nowMemorySize = nowMemorySizeOBJ; - memorySize = nowMemorySize + userMemorySize; + int newMemorySize = nowMemorySize + userMemorySize; + // 更新IP对应的资源值 - redisTemplate.opsForValue().set(ipKey, memorySize); + redisTemplate.opsForValue().set(ipKey, newMemorySize); // 设置缓存过期时间(3600秒) redisTemplate.expire(ipKey, 3600, TimeUnit.SECONDS); + log.info("IP {} 释放 {} MB 资源,当前可用: {} MB", ip, userMemorySize, newMemorySize); + released = true; } - + + if (!released) { + return OptResult.error("未找到匹配的资源记录"); + } + // 处理等待队列(先来先服务) String waitQueueKey = "waitQueue:" + modelId; // 取出队列头部的任务(最早加入的) @@ -81,11 +97,13 @@ public class ServiceAPIController { // 1. 存储modelConfig到缓存 String modelConfig = serviceAPIService.getByModelId(modelId); int requestMemorySize = parseGpuMemorySize(modelConfig); - if (requestMemorySize == -1){ + if (requestMemorySize == -1) { return OptResult.error("解析配置失败,请检查模型:" + modelId +"是否存在"); } + String modelConfigKey = "modelConfig:" + modelId; redisTemplate.opsForValue().set(modelConfigKey, modelConfig); + // 2. 获取Nacos实例IP列表 List instanceIps; try { @@ -94,8 +112,14 @@ public class ServiceAPIController { log.error("获取Nacos实例失败", e); return OptResult.error("获取实例失败"); } + Set gpuKeys = redisTemplate.keys("gpu:*"); - //根据IP列表查找资源 + if (gpuKeys == null || gpuKeys.isEmpty()) { + log.error("未找到可用的GPU资源"); + return OptResult.error("系统无可用GPU资源"); + } + + // 根据IP列表查找资源 for (String instanceIp : instanceIps) { for (String gpuKey : gpuKeys) { String GPUConfig = (String) redisTemplate.opsForValue().get(gpuKey); @@ -103,7 +127,7 @@ public class ServiceAPIController { // 分割键值对 String[] pairs = GPUConfig.split(","); String ip = null; - int memorySize = 0; + int totalMemorySize = 0; for (String pair : pairs) { String[] keyValue = pair.split(":", 2); if (keyValue.length == 2) { @@ -112,40 +136,62 @@ public class ServiceAPIController { if ("IP".equalsIgnoreCase(key)) { ip = value; } else if ("GPUMemorySize".equalsIgnoreCase(key)) { - memorySize = Integer.parseInt(value); + totalMemorySize = Integer.parseInt(value); } } } + // 检查解析出的 IP 是否在 Nacos 实例列表中 if (instanceIp.equals(ip)) { - log.info("找到 IP {} 对应的 GPU 内存: {} ", ip, memorySize); - if (memorySize>=requestMemorySize){ - int newMemorySize = memorySize - requestMemorySize; - String ipKey = "ip:" + ip; - redisTemplate.opsForValue().set(ipKey,newMemorySize); - //访问请求最大时间为3600s - redisTemplate.expire(ipKey, 3600, TimeUnit.SECONDS); + log.info("找到 IP {} 对应的 GPU 总内存: {} MB", ip, totalMemorySize); + + // 获取当前可用内存 + String ipKey = "ip:" + ip; + Integer currentAvailable = (Integer) redisTemplate.opsForValue().get(ipKey); + + // 如果没有记录,则初始化为总内存 + if (currentAvailable == null) { + currentAvailable = totalMemorySize; + redisTemplate.opsForValue().set(ipKey, currentAvailable); + log.info("IP {} 首次使用,初始可用内存: {} MB", ip, currentAvailable); + } + + // 检查可用内存是否足够 + if (currentAvailable >= requestMemorySize) { + int newMemorySize = currentAvailable - requestMemorySize; + redisTemplate.opsForValue().set(ipKey, newMemorySize); + // 访问请求最大时间为3600s + redisTemplate.expire(ipKey, 3600, TimeUnit.SECONDS); + + // 记录模型与IP的绑定关系 + redisTemplate.opsForValue().set("modelId:" + modelId, modelConfig); + + log.info("IP {} 分配成功,分配前可用: {} MB,分配后可用: {} MB", + ip, currentAvailable, newMemorySize); + return OptResult.success("资源分配成功,使用ip:" + ip); + } else { + log.info("IP {} 资源不足,当前可用: {} MB,请求: {} MB", + ip, currentAvailable, requestMemorySize); } - return OptResult.success("资源分配成功,使用ip:" + ip); - }else { - log.info("资源不足"); } } } } + // 所有实例检查完毕未找到足够资源 String waitQueueKey = "waitQueue:" + modelId; // 改为右插入,保证队列顺序为FIFO(最早的任务在列表头部) - redisTemplate.opsForList().rightPush(waitQueueKey, modelId); + redisTemplate.opsForList().rightPush(waitQueueKey, modelId); log.info("未找到足够资源,任务 {} 加入等待队列", modelId); return OptResult.error("资源不足,等待中"); } + /** * 从模型配置字符串中解析GPU内存需求 * @param modelConfig 模型配置字符串,格式如 "GPUMemorySize:8000,version:1" * @return 解析到的GPU内存大小(MB),若解析失败返回-1 */ - private int parseGpuMemorySize(String modelConfig) { + public int parseGpuMemorySize(String modelConfig) { if (modelConfig == null || modelConfig.isEmpty()) { log.error("模型配置为空,无法解析GPU内存需求"); return -1; @@ -177,5 +223,4 @@ public class ServiceAPIController { } return requestMemorySize; } - -} +} \ No newline at end of file diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/enumeration/ServiceStatus.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/enumeration/ServiceStatus.java new file mode 100644 index 0000000..143d9de --- /dev/null +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/enumeration/ServiceStatus.java @@ -0,0 +1,31 @@ +package com.bipt.intelligentapplicationorchestrationservice.enumeration; + +public enum ServiceStatus { + OFFLINE(0, "下线"), + ONLINE(1, "上线"); + + private final int code; + private final String description; + + ServiceStatus(int code, String description) { + this.code = code; + this.description = description; + } + + public int getCode() { + return code; + } + + public String getDescription() { + return description; + } + + public static ServiceStatus fromCode(int code) { + for (ServiceStatus status : ServiceStatus.values()) { + if (status.code == code) { + return status; + } + } + throw new IllegalArgumentException("未知的状态码: " + code); + } +} diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/mapper/ModelMapper.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/mapper/ModelMapper.java index 1d8847f..7228935 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/mapper/ModelMapper.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/mapper/ModelMapper.java @@ -85,6 +85,10 @@ public interface ModelMapper { */ ModelTrainInfoVO getModelTrainInfo(Long id); - + /** + * 获取模型版本信息 + * @param modelId + * @return + */ ModelVersion selectByModelId(Long modelId); } diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/mapper/PublishMapper.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/mapper/PublishMapper.java index 5bb4c45..5c70765 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/mapper/PublishMapper.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/mapper/PublishMapper.java @@ -1,10 +1,13 @@ package com.bipt.intelligentapplicationorchestrationservice.mapper; import com.bipt.intelligentapplicationorchestrationservice.entity.ModelSelectVO; +import com.bipt.intelligentapplicationorchestrationservice.pojo.ModelVersion; import com.bipt.intelligentapplicationorchestrationservice.pojo.ServicePublishDTO; import com.bipt.intelligentapplicationorchestrationservice.pojo.ServicePublishVO; +import org.apache.ibatis.annotations.Delete; import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Select; +import org.apache.ibatis.annotations.Update; import java.util.List; @@ -17,7 +20,7 @@ public interface PublishMapper { Long getByApiUrl(String apiUrl); - @Select("SELECT model_id,api_url,ip FROM service_publish") + @Select("SELECT * FROM service_publish") List listPublishedServices(); @Select("SELECT " + "mv.model_id AS modelId, " + @@ -26,4 +29,37 @@ public interface PublishMapper { "FROM model_version mv " + "LEFT JOIN model_info m ON mv.model_id = m.id") List selectModelSelectList(); + + // 根据ID查询服务(移除update_time和deleted字段) + @Select("SELECT id, model_id, api_url, ip, create_time " + + "FROM service_publish WHERE id = #{serviceId}") + ServicePublishVO getServiceById(Long serviceId); + + + void updateStatus(Long id, int status); + + List selectByStatus(Integer status); + @Select("SELECT " + + "mv.id AS modelId, " + + /*"mv.model_id AS modelId, " +*/ + "m.model_name AS modelName, " + + "mv.version AS version " + + "FROM model_version mv " + + "LEFT JOIN model_info m ON mv.model_id = m.id " + + "WHERE mv.model_id NOT IN ( " + + " SELECT DISTINCT model_id " + + " FROM service_publish " + + " WHERE status = #{code} " + + ")") + List selectModelNamesByStatus(int code); + @Select("select model_id from model_version where id=#{id}") + Long getByMdVersionId(Long id); + + /** + * 根据modelversionId查询Modelversion信息 + * @param id + * @return + */ + @Select("select * from model_version where id = #{id}") + ModelVersion selectByModelVersionId(Long id); } diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/AlgorithmInfo.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/AlgorithmInfo.java index 5fe9302..53c2804 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/AlgorithmInfo.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/AlgorithmInfo.java @@ -84,4 +84,8 @@ public class AlgorithmInfo { public void setFileSize(Long fileSize) { this.fileSize = fileSize; } -} \ No newline at end of file + + public String getFilePath() { + return algorithmFile; + } +} \ No newline at end of file diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/ServicePublishDTO.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/ServicePublishDTO.java index 021f1f8..c5c79e5 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/ServicePublishDTO.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/ServicePublishDTO.java @@ -22,4 +22,5 @@ public class ServicePublishDTO implements Serializable { @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8") private LocalDateTime createTime; private String ip; + private int status; } diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/ServicePublishVO.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/ServicePublishVO.java index c08bc6a..a407236 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/ServicePublishVO.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/pojo/ServicePublishVO.java @@ -16,11 +16,13 @@ import java.time.LocalDateTime; @NoArgsConstructor @AllArgsConstructor public class ServicePublishVO implements Serializable { + private Long id; private Long modelId; /*private String GPUModel;*/ private String ip; /* private String GPUMemorySize;*/ private String apiUrl; + private int status; } diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/AlgorithmInfoService.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/AlgorithmInfoService.java index 5b4a709..10116e0 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/AlgorithmInfoService.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/AlgorithmInfoService.java @@ -3,6 +3,7 @@ package com.bipt.intelligentapplicationorchestrationservice.service; import com.bipt.intelligentapplicationorchestrationservice.pojo.AlgorithmInfo; import org.springframework.web.multipart.MultipartFile; +import java.io.IOException; import java.util.List; public interface AlgorithmInfoService { @@ -15,7 +16,7 @@ public interface AlgorithmInfoService { void save(AlgorithmInfo algorithmInfo, MultipartFile file); - String run(Long id, String param); + String run(String scriptPath, List args) throws IOException, InterruptedException; List getAllNames(); diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/Impl/AlgorithmInfoServiceImpl.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/Impl/AlgorithmInfoServiceImpl.java index 0fd8dca..f631492 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/Impl/AlgorithmInfoServiceImpl.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/Impl/AlgorithmInfoServiceImpl.java @@ -12,9 +12,17 @@ import org.springframework.util.StringUtils; import org.springframework.web.multipart.MultipartFile; import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.time.LocalDateTime; +import java.util.ArrayList; import java.util.List; +import java.util.UUID; @Service @Slf4j @@ -80,57 +88,98 @@ public class AlgorithmInfoServiceImpl implements AlgorithmInfoService { throw new RuntimeException("算法已存在,请去修改算法"); } - // 只接收文件但不进行保存操作 if (file != null && !file.isEmpty()) { - log.info("已接收文件: {}", file.getOriginalFilename()); - log.info("文件大小: {} 字节", file.getSize()); - log.info("文件类型: {}", file.getContentType()); - // 临时设置一个空路径(避免数据库保存空值) - //todo 保存到分布式存储 - algorithmInfo.setAlgorithmFile(""); + try { + // 获取文件原始名称 + String originalFilename = file.getOriginalFilename(); + if (originalFilename == null || originalFilename.isEmpty()) { + throw new RuntimeException("文件名称为空"); + } + + // 生成唯一文件名,避免冲突 + String fileName = UUID.randomUUID().toString() + "_" + originalFilename; + + // 关键修改:使用实际存在的绝对路径(替换为你的实际路径,如 D:/algorithm_files) + // 建议在配置文件中配置,而非硬编码 + String uploadDir = "D:/algorithm_files"; // 例如:Windows 路径用 D:/xxx,Linux 用 /home/xxx + + // 构建路径对象(使用 Path 而非 File,更适合跨平台) + Path saveDirPath = Paths.get(uploadDir); + + // 确保目录存在(createDirectories 会创建所有不存在的父目录,跨平台兼容) + if (!Files.exists(saveDirPath)) { + Files.createDirectories(saveDirPath); // 关键:创建多级目录 + log.info("已创建存储目录: {}", saveDirPath.toAbsolutePath()); + } + + // 完整文件路径 + Path saveFilePath = saveDirPath.resolve(fileName); + + // 保存文件到指定路径 + file.transferTo(saveFilePath); // 使用 Path 重载方法,更可靠 + + // 设置文件路径到实体类(存储绝对路径或可访问的相对路径) + algorithmInfo.setAlgorithmFile(saveFilePath.toString()); + + // 设置文件大小 + algorithmInfo.setFileSize(Files.size(saveFilePath)); + + log.info("文件保存成功: {}", saveFilePath.toAbsolutePath()); + } catch (Exception e) { + log.error("文件保存失败", e); + throw new RuntimeException("文件保存失败: " + e.getMessage(), e); + } + } else { + // 文件为空的处理逻辑 + algorithmInfo.setAlgorithmFile(null); + algorithmInfo.setFileSize(0L); // 空文件大小设为0 } algorithmInfo.setCreateTime(LocalDateTime.now()); - // 保存算法信息到数据库(注意:此时algorithmFile字段为空) + // 保存算法信息到数据库 algorithmInfoMapper.insert(algorithmInfo); } - @Override - public String run(Long id, String param) { - //todo从分布式存储中拿到文件(以下是示例) - String file = algorithmInfoMapper.getFileById(id); - StringBuilder result = new StringBuilder(); // 用于存储结果 + /** + * 执行Python算法脚本并返回结果 + * @param scriptPath Python脚本路径 + * @param args 命令行参数列表 + * @return 脚本执行结果 + */ + public String run(String scriptPath, List args) throws IOException, InterruptedException { + // 构建命令:python [脚本路径] [参数1] [参数2] ... + List command = new ArrayList<>(); + command.add("python"); // Python解释器路径,可配置在application.properties中 + command.add(scriptPath); // 脚本路径 + command.addAll(args); // 添加所有参数 - try { - // 构建命令,将 param 作为参数传递给 Python 脚本 - ProcessBuilder pb = new ProcessBuilder("python", file, param); - Process process = pb.start(); + // 打印完整命令(用于调试) + log.info("执行命令: {}", String.join(" ", command)); - // 读取标准输出(脚本执行结果) - BufferedReader reader = new BufferedReader( - new InputStreamReader(process.getInputStream())); + // 创建进程并执行命令 + ProcessBuilder processBuilder = new ProcessBuilder(command); + processBuilder.redirectErrorStream(true); // 将错误输出合并到标准输出 + Process process = processBuilder.start(); + + // 读取脚本输出(使用UTF-8编码,避免中文乱码) + StringBuilder output = new StringBuilder(); + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { - result.append(line).append("\n"); + output.append(line).append("\n"); } - - // 读取错误输出 - BufferedReader errorReader = new BufferedReader( - new InputStreamReader(process.getErrorStream())); - String errorLine; - while ((errorLine = errorReader.readLine()) != null) { - result.append("Error: ").append(errorLine).append("\n"); - } - - int exitCode = process.waitFor(); - result.append("Exit Code: ").append(exitCode); - - } catch (Exception e) { - result.append("执行异常: ").append(e.getMessage()); - e.printStackTrace(); } - return result.toString(); // 返回完整结果 + // 等待进程执行完成并获取退出码 + int exitCode = process.waitFor(); + + // 检查脚本是否成功执行 + if (exitCode != 0) { + throw new RuntimeException("脚本执行失败,退出码: " + exitCode); + } + + return output.toString(); } @Override diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/Impl/PublishServiceImpl.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/Impl/PublishServiceImpl.java index 82815b8..426dcd6 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/Impl/PublishServiceImpl.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/Impl/PublishServiceImpl.java @@ -1,6 +1,7 @@ package com.bipt.intelligentapplicationorchestrationservice.service.Impl; import com.bipt.intelligentapplicationorchestrationservice.entity.ModelSelectVO; +import com.bipt.intelligentapplicationorchestrationservice.enumeration.ServiceStatus; import com.bipt.intelligentapplicationorchestrationservice.mapper.PublishMapper; import com.bipt.intelligentapplicationorchestrationservice.pojo.ServicePublishDTO; import com.bipt.intelligentapplicationorchestrationservice.pojo.ServicePublishVO; @@ -10,7 +11,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import java.util.ArrayList; import java.util.List; /** @@ -48,5 +48,30 @@ public class PublishServiceImpl implements PublishService { return publishMapper.selectModelSelectList(); } + @Override + public ServicePublishVO getServiceById(Long serviceId) { + return publishMapper.getServiceById(serviceId); + } + + @Override + public void updateServiceStatus(ServicePublishDTO servicePublishDTO) { + publishMapper.updateStatus(servicePublishDTO.getId(), servicePublishDTO.getStatus()); + } + + @Override + public List listPublishedServicesByStatus(Integer status) { + return publishMapper.selectByStatus(status); + } + + @Override + public List getOnlineModelNames() { + // 调用Mapper查询状态为“在线”的模型(ServiceStatus.ONLINE.getCode() 假设为1) + return publishMapper.selectModelNamesByStatus(ServiceStatus.OFFLINE.getCode()); + } + + @Override + public Long getModelId(Long id) { + return publishMapper.getByMdVersionId(id); + } } diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/PublishService.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/PublishService.java index d0a8326..61572a4 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/PublishService.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/service/PublishService.java @@ -10,8 +10,17 @@ public interface PublishService { void save(ServicePublishDTO servicePublishDTO); - List listPublishedServices(); List getModelNames(); + + ServicePublishVO getServiceById(Long serviceId); + + void updateServiceStatus(ServicePublishDTO updateDto); + + List listPublishedServicesByStatus(Integer status); + + List getOnlineModelNames(); + + Long getModelId(Long id); } diff --git a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/util/NacosServiceUtil.java b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/util/NacosServiceUtil.java index bc1a665..7bd9d72 100644 --- a/src/main/java/com/bipt/intelligentapplicationorchestrationservice/util/NacosServiceUtil.java +++ b/src/main/java/com/bipt/intelligentapplicationorchestrationservice/util/NacosServiceUtil.java @@ -1,14 +1,17 @@ package com.bipt.intelligentapplicationorchestrationservice.util; -import com.alibaba.nacos.api.naming.NamingFactory; +import com.alibaba.nacos.api.NacosFactory; +import com.alibaba.nacos.api.exception.NacosException; import com.alibaba.nacos.api.naming.NamingService; import com.alibaba.nacos.api.naming.pojo.Instance; +import com.alibaba.nacos.api.naming.pojo.ServiceInfo; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; @Component @@ -17,26 +20,102 @@ public class NacosServiceUtil { @Value("${spring.cloud.nacos.discovery.server-addr}") private String nacosServerAddr; - public void registerService(String serviceName, String ip, int port, String url) throws Exception { // 新增url参数 - NamingService naming = NamingFactory.createNamingService(nacosServerAddr); + private NamingService namingService; + + /** + * 获取NamingService实例(线程安全) + */ + private NamingService getNamingService() throws Exception { + if (namingService == null) { + synchronized (this) { + if (namingService == null) { + namingService = NacosFactory.createNamingService(nacosServerAddr); + } + } + } + return namingService; + } + + /** + * 注册服务到Nacos + */ + public void registerService(String serviceName, String ip, int port, String url) throws Exception { + NamingService naming = getNamingService(); Instance instance = new Instance(); instance.setIp(ip); instance.setPort(port); - // 添加元数据存储URL + instance.setWeight(1.0); + instance.setHealthy(true); + + // 添加元数据 Map metadata = new HashMap<>(); - metadata.put("url", url); // 将URL存入元数据 + metadata.put("url", url); + metadata.put("registerTime", String.valueOf(System.currentTimeMillis())); instance.setMetadata(metadata); + naming.registerInstance(serviceName, instance); } + /** + * 从Nacos注销服务 + */ + public void deregisterService(String serviceName, String ip, int port) throws Exception { + NamingService naming = getNamingService(); + naming.deregisterInstance(serviceName, ip, port); + } + + /** + * 获取服务所有实例 + */ + public List getAllInstances(String serviceName) throws Exception { + NamingService naming = getNamingService(); + return naming.getAllInstances(serviceName); + } + /** * 获取服务所有实例IP */ public List getServiceInstances(String serviceName) throws Exception { - NamingService naming = NamingFactory.createNamingService(nacosServerAddr); - List instances = naming.getAllInstances(serviceName); - return instances.stream() + return getAllInstances(serviceName).stream() .map(Instance::getIp) .collect(Collectors.toList()); } + + /** + * 获取服务信息(适配Nacos 2.x) + */ + public ServiceInfo getServiceInfo(String serviceName) throws Exception { + NamingService naming = getNamingService(); + // 使用selectInstances替代getServiceInfo + List instances = naming.selectInstances(serviceName, true); + + ServiceInfo serviceInfo = new ServiceInfo(); + serviceInfo.setName(serviceName); + serviceInfo.setHosts(instances); + return serviceInfo; + } + + /** + * 根据IP和端口查询实例是否存在 + */ + public boolean isInstanceExists(String serviceName, String ip, int port) throws Exception { + List instances = getAllInstances(serviceName); + return instances.stream() + .anyMatch(instance -> + Objects.equals(instance.getIp(), ip) && + instance.getPort() == port + ); + } + + /** + * 更新服务实例元数据 + */ + public void updateInstanceMetadata(String serviceName, String ip, int port, Map metadata) throws Exception { + NamingService naming = getNamingService(); + Instance instance = new Instance(); + instance.setIp(ip); + instance.setPort(port); + instance.setMetadata(metadata); + naming.registerInstance(serviceName, instance); + } } \ No newline at end of file diff --git a/src/main/resources/mapper/PublishMapper.xml b/src/main/resources/mapper/PublishMapper.xml index 49065fa..f2edf81 100644 --- a/src/main/resources/mapper/PublishMapper.xml +++ b/src/main/resources/mapper/PublishMapper.xml @@ -3,9 +3,14 @@ INSERT INTO service_publish - (id,model_id,api_url,create_time,ip) - values (#{id}, #{modelId}, #{apiUrl}, #{createTime},#{ip}) + (id,model_id,api_url,create_time,ip,status) + values (#{id}, #{modelId}, #{apiUrl}, #{createTime},#{ip},#{status}) + + UPDATE service_publish + SET status = #{status} + WHERE id = #{id} + + \ No newline at end of file diff --git a/src/test/java/com/bipt/intelligentapplicationorchestrationservice/ServiceAPIControllerTest.java b/src/test/java/com/bipt/intelligentapplicationorchestrationservice/ServiceAPIControllerTest.java new file mode 100644 index 0000000..b52ee8e --- /dev/null +++ b/src/test/java/com/bipt/intelligentapplicationorchestrationservice/ServiceAPIControllerTest.java @@ -0,0 +1,125 @@ +package com.bipt.intelligentapplicationorchestrationservice; + +import com.bipt.intelligentapplicationorchestrationservice.controller.ServiceAPIController; +import com.bipt.intelligentapplicationorchestrationservice.pojo.OptResult; +import com.bipt.intelligentapplicationorchestrationservice.service.ServiceAPIService; +import com.bipt.intelligentapplicationorchestrationservice.util.NacosServiceUtil; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.data.redis.core.ListOperations; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.ValueOperations; +import org.springframework.test.context.junit4.SpringRunner; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +@RunWith(SpringRunner.class) +public class ServiceAPIControllerTest { + + @Mock + private ServiceAPIService serviceAPIService; + + @Mock + private NacosServiceUtil nacosServiceUtil; + + @Mock + private RedisTemplate redisTemplate; + + @Mock + private ValueOperations valueOperations; + + @Mock + private ListOperations listOperations; + + @InjectMocks + private ServiceAPIController serviceAPIController; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + when(redisTemplate.opsForValue()).thenReturn(valueOperations); + when(redisTemplate.opsForList()).thenReturn(listOperations); + } + + + @Test + public void testMultiResourceAllocation() throws Exception { + System.out.println("===== 开始测试多资源分配 ====="); + + // 环境初始化 + String ip1 = "192.168.1.1"; + String ip2 = "192.168.1.2"; + List instanceIps = Arrays.asList(ip1, ip2); + Set gpuKeys = new HashSet<>(Arrays.asList("gpu:1", "gpu:2")); + + // 模拟两个GPU的总内存配置 + when(valueOperations.get("gpu:1")).thenReturn("IP:" + ip1 + ",GPUMemorySize:8000"); + when(valueOperations.get("gpu:2")).thenReturn("IP:" + ip2 + ",GPUMemorySize:10000"); + + // 第一个请求(分配到IP1,需要3000MB) + System.out.println("\n=== 第一个请求:分配到IP1 ==="); + Long modelId1 = 1L; + String modelConfig1 = "GPUMemorySize:3000,version:1"; + + when(serviceAPIService.getByModelId(modelId1)).thenReturn(modelConfig1); + when(nacosServiceUtil.getServiceInstances(modelId1.toString())).thenReturn(instanceIps); + when(redisTemplate.keys("gpu:*")).thenReturn(gpuKeys); + // IP1首次使用,无需提前设置ip:ip1(默认用总内存8000) + + OptResult result1 = serviceAPIController.schedule(modelId1); + + // 验证结果 + assertTrue("第一个请求应成功", result1.isSuccess()); + assertEquals("资源分配成功,使用ip:" + ip1, result1.getData()); + verify(valueOperations, times(1)).set("ip:" + ip1, 5000); // 8000-3000 + System.out.println("IP1 可用内存=5000MB, IP2 可用内存=10000MB(初始)"); + + // 第二个请求(分配到IP2,需要6000MB) + System.out.println("\n=== 第二个请求:分配到IP2 ==="); + Long modelId2 = 2L; + String modelConfig2 = "GPUMemorySize:6000,version:1"; + + when(serviceAPIService.getByModelId(modelId2)).thenReturn(modelConfig2); + when(nacosServiceUtil.getServiceInstances(modelId2.toString())).thenReturn(instanceIps); + when(valueOperations.get("ip:" + ip1)).thenReturn(5000); // IP1当前可用5000(不足6000) + // IP2首次使用,无需提前设置ip:ip2(默认用总内存10000) + + OptResult result2 = serviceAPIController.schedule(modelId2); + + // 验证结果 + assertTrue("第二个请求应成功", result2.isSuccess()); + assertEquals("资源分配成功,使用ip:" + ip2, result2.getData()); + verify(valueOperations, times(1)).set("ip:" + ip2, 4000); // 10000-6000 + System.out.println("IP1 可用内存=5000MB, IP2 可用内存=4000MB"); + + // 第三个请求(资源不足) + System.out.println("\n=== 第三个请求:资源不足 ==="); + Long modelId3 = 3L; + String modelConfig3 = "GPUMemorySize:7000,version:1"; + + when(serviceAPIService.getByModelId(modelId3)).thenReturn(modelConfig3); + when(valueOperations.get("ip:" + ip1)).thenReturn(5000); // IP1可用5000 <7000 + when(valueOperations.get("ip:" + ip2)).thenReturn(4000); // IP2可用4000 <7000 + + OptResult result3 = serviceAPIController.schedule(modelId3); + + // 验证结果 + assertFalse("第三个请求应失败", result3.isSuccess()); + assertEquals("资源不足,等待中", result3.getErrorInfo()); + verify(listOperations, times(1)).rightPush("waitQueue:" + modelId3, modelId3); + System.out.println("模型ID=" + modelId3 + " 加入等待队列"); + + System.out.println("===== 多资源分配测试完成 ====="); + } +} \ No newline at end of file