- 积分
- 517
- 明经币
- 个
- 注册时间
- 2024-2-29
- 在线时间
- 小时
- 威望
-
- 金钱
- 个
- 贡献
-
- 激情
-
|
楼主 |
发表于 2024-3-29 12:47:39
|
显示全部楼层
源代码:
- using Autodesk.AutoCAD.DatabaseServices;
- using Autodesk.AutoCAD.Geometry;
- using Microsoft.ML;
- using Microsoft.ML.Data;
- using Microsoft.ML.Trainers;
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Threading.Tasks;
- namespace Net8Cad
- {
- /// <summary>
- /// AutoCAD command class that trains an ML.NET linear SVM on user-selected circles
- /// (label = circle colour index is 1) and draws the learned separating line as an Xline.
- /// </summary>
- public class MLTest
- {
-     /// <summary>Squared XY weight norm below which no boundary line can be drawn.</summary>
-     private const double WeightTolerance = 1e-12;
-     /// <summary>
-     /// Entry point of the "mssvm" command: collects the training points from a user
-     /// selection, fits a linear SVM, reports the weights and bias on the command line
-     /// and draws the line w1*x + w2*y + b = 0 in model space.
-     /// Any exception is reported on the command line instead of propagating to AutoCAD.
-     /// </summary>
-     [CommandMethod("mssvm")]
-     public void Run()
-     {
-         Document doc = Application.DocumentManager.MdiActiveDocument;
-         try
-         {
-             // Acquire the dataset. The circle data is copied while the transaction
-             // that opened the entities is still alive.
-             List<PointData> pointsData = CollectPointData(doc);
-             if (pointsData.Count == 0)
-             {
-                 WriteMessage(doc, "未选中任何圆,命令结束。");
-                 return;
-             }
-             // A binary classifier needs examples of both classes to produce a boundary.
-             bool firstLabel = pointsData[0].PointLabel;
-             if (pointsData.All(p => p.PointLabel == firstLabel))
-             {
-                 WriteMessage(doc, "训练数据必须同时包含颜色索引为 1 和不为 1 的圆。");
-                 return;
-             }
-             // Build and fit the machine-learning pipeline.
-             var mlContext = new MLContext();
-             var trainingDataView = mlContext.Data.LoadFromEnumerable<PointData>(pointsData);
-             var trainer = mlContext.BinaryClassification.Trainers.LinearSvm(
-                 labelColumnName: "Label",
-                 featureColumnName: "Features",
-                 numberOfIterations: 100);
-             var trainingPipeline = mlContext.Transforms.Concatenate(
-                     "Features",
-                     nameof(PointData.PointX),
-                     nameof(PointData.PointY),
-                     nameof(PointData.PointZ))
-                 .Append(mlContext.Transforms.CopyColumns(
-                     outputColumnName: "Label",
-                     inputColumnName: nameof(PointData.PointLabel)))
-                 .Append(trainer);
-             var model = trainingPipeline.Fit(trainingDataView);
-             var svmModel = model.LastTransformer.Model;
-             var weights = svmModel.Weights; // order: PointX, PointY, PointZ
-             var bias = svmModel.Bias;
-             WriteMessage(doc, "训练模型为:" + $"{model}");
-             WriteMessage(doc, $"weights 个数:{weights.Count} bias 值为:{bias}");
-             WriteMessage(doc, $"第一个:{weights[0]} 第二个{weights[1]} 第三个{weights[2]}");
-             // Draw w1*x + w2*y + b = 0, i.e. the decision boundary restricted to Z = 0
-             // (the PointZ weight is deliberately ignored for the 2D drawing).
-             if (!DrawBoundaryLine(doc, weights[0], weights[1], bias))
-             {
-                 WriteMessage(doc, "权重 w1、w2 均为 0,无法绘制分界线。");
-                 return;
-             }
-             WriteMessage(doc, "done");
-         }
-         catch (System.Exception ex)
-         {
-             string msg = "报错为:" + ex.Message + "\n" + "位置为:" + ex.StackTrace;
-             WriteMessage(doc, msg);
-         }
-     }
-     /// <summary>
-     /// Prompts for a selection and converts every selected circle into a
-     /// <see cref="PointData"/> sample (centre coordinates, label = ColorIndex is 1).
-     /// </summary>
-     /// <param name="doc">Document whose editor and database are used.</param>
-     /// <returns>The samples; empty when nothing was selected or the prompt was cancelled.</returns>
-     List<PointData> CollectPointData(Document doc)
-     {
-         var samples = new List<PointData>();
-         using (Transaction tr = doc.Database.TransactionManager.StartTransaction())
-         {
-             foreach (Circle circle in GetAllT_viaSelection<Circle>(doc, tr))
-             {
-                 samples.Add(new PointData
-                 {
-                     PointX = (float)circle.Center.X,
-                     PointY = (float)circle.Center.Y,
-                     PointZ = (float)circle.Center.Z,
-                     PointLabel = circle.Color.ColorIndex == 1,
-                 });
-             }
-             tr.Commit();
-         }
-         return samples;
-     }
-     /// <summary>
-     /// Prompts the user for a selection and opens every selected entity of type
-     /// <typeparamref name="T"/> (or a subclass) for read in the supplied transaction.
-     /// </summary>
-     /// <param name="doc">Document whose editor performs the selection.</param>
-     /// <param name="tr">
-     /// Caller-owned transaction; the returned entities must only be used while it is alive.
-     /// </param>
-     /// <returns>The matching entities; an empty list (never null) if the prompt did not return OK.</returns>
-     List<T> GetAllT_viaSelection<T>(Document doc, Transaction tr) where T : Entity
-     {
-         List<T> selected = new List<T>();
-         // Configure the prompt text.
-         PromptSelectionOptions opts = new PromptSelectionOptions();
-         opts.MessageForAdding = "\n请框选:\n";
-         // Ask the user to select entities.
-         PromptSelectionResult result = doc.Editor.GetSelection(opts);
-         if (result.Status != PromptStatus.OK)
-         {
-             return selected;
-         }
-         // Open each selected entity through the caller's transaction and keep the ones of type T.
-         foreach (ObjectId objectId in result.Value.GetObjectIds())
-         {
-             if (tr.GetObject(objectId, OpenMode.ForRead, false) is T typed)
-             {
-                 selected.Add(typed);
-             }
-         }
-         return selected;
-     }
-     /// <summary>
-     /// Adds an Xline along w1*x + w2*y + bias = 0 (in the Z = 0 plane) to model space.
-     /// Works for any orientation, including vertical lines where w2 is 0.
-     /// </summary>
-     /// <returns>False when both weights are (numerically) zero so no line exists; otherwise true.</returns>
-     bool DrawBoundaryLine(Document doc, double w1, double w2, double bias)
-     {
-         double normSquared = w1 * w1 + w2 * w2;
-         if (normSquared < WeightTolerance)
-         {
-             return false;
-         }
-         // Base point: foot of the perpendicular from the origin onto the line.
-         double scale = -bias / normSquared;
-         double length = Math.Sqrt(normSquared);
-         Point3d basePoint = new Point3d(w1 * scale, w2 * scale, 0);
-         // (-w2, w1) is perpendicular to the normal (w1, w2), hence parallel to the line.
-         Vector3d unitDirection = new Vector3d(-w2 / length, w1 / length, 0);
-         Xline xl = new Xline();
-         try
-         {
-             xl.BasePoint = basePoint;
-             xl.SecondPoint = basePoint + unitDirection;
-             ToModelSpace(doc, xl);
-         }
-         catch
-         {
-             // The entity never became database-resident, so it must be released here.
-             xl.Dispose();
-             throw;
-         }
-         return true;
-     }
-     /// <summary>Writes <paramref name="message"/> on a new line of the document's command line.</summary>
-     void WriteMessage(Document doc, string message)
-     {
-         doc.Editor.WriteMessage("\n" + message);
-     }
-     /// <summary>
-     /// Appends <paramref name="entity"/> to the model space of <paramref name="doc"/>
-     /// inside a document lock and a committed transaction.
-     /// </summary>
-     /// <returns>The ObjectId assigned to the appended entity.</returns>
-     ObjectId ToModelSpace(Document doc, Entity entity)
-     {
-         ObjectId objectId;
-         using (doc.LockDocument())
-         {
-             Database database = doc.Database;
-             using (Transaction trans = database.TransactionManager.StartTransaction())
-             {
-                 // The block table is only looked up, so read access is sufficient.
-                 BlockTable blockTable = (BlockTable)trans.GetObject(database.BlockTableId, OpenMode.ForRead, false);
-                 BlockTableRecord modelSpace = (BlockTableRecord)trans.GetObject(blockTable[BlockTableRecord.ModelSpace], OpenMode.ForWrite, false);
-                 objectId = modelSpace.AppendEntity(entity);
-                 trans.AddNewlyCreatedDBObject(entity, true);
-                 trans.Commit();
-             }
-         }
-         return objectId;
-     }
- }
- /// <summary>
- /// One training sample for the SVM: a circle centre and its class label.
- /// </summary>
- public class PointData
- {
-     /// <summary>Class label; MLTest sets it to true when the circle's ColorIndex is 1.</summary>
-     [LoadColumn(0)] public bool PointLabel { get; set; }
-     /// <summary>X coordinate of the circle centre, narrowed to float.</summary>
-     [LoadColumn(1)] public float PointX { get; set; }
-     /// <summary>Y coordinate of the circle centre, narrowed to float.</summary>
-     [LoadColumn(2)] public float PointY { get; set; }
-     /// <summary>Z coordinate of the circle centre, narrowed to float.</summary>
-     [LoadColumn(3)] public float PointZ { get; set; }
- }
- }
|
|