枫叶棋语 发表于 2023-6-3 18:20:34

Pycad 四叉树配合A*算法寻路

本帖最后由 枫叶棋语 于 2023-6-3 18:29 编辑


from Autodesk.AutoCAD.Runtime import *
from Autodesk.AutoCAD.ApplicationServices import *
from Autodesk.AutoCAD.EditorInput import *
from Autodesk.AutoCAD.DatabaseServices import *
from Autodesk.AutoCAD.Geometry import *
from System.Linq import Enumerable
from pycad.runtime import *
from pycad.system import *
from pycad.runtime.edx import *
from math import *
import re as re
from collections import defaultdict
from heapq import *
from Queue import PriorityQueue
from math import *
clr.ImportExtensions(System.Linq)

def ColsestPointByPt(pt,pts):
    return pts.OrderBy(lambda x: x.DistanceTo(pt)).First()

def round_point(pt):
      return Point3d(round(pt.X,5),round(pt.Y,5),0)
def LoopallChildren(parent):
    for child in parent.children:
      if child.children:
            for subchild in LoopallChildren(child):
                yield subchild
      yield child         
#创建路径点对类
class Path:
    def __init__(self,pt1,pt2):
      self._pt1=pt1
      self._pt2=pt2
    @property
    def Length(self):
      return abs(self._pt1-self._pt2)+abs(self._pt1-self._pt2)

#创建四叉树节点
class QuadNode:
    def __init__(self, line):
      self._pt1 = line.StartPoint
      self._pt2 = line.EndPoint
      self._pair_tuple = (self._pt1, self._pt2)
      self._pair_set = {self._pt1, self._pt2}
      self.line = line
      self.bounds = self.line.GeometricExtents
      self.route_nodes = {round_point(self.line.StartPoint), round_point(self.line.EndPoint)}
    def __eq__(self, other):
      return self.bounds==other.bounds and self.line==other.line
    def __hash__(self):
      return hash(tuple(self._pair_set))
    def __contains__(self, pt):
      return pt in self._pair_set
    @property
    def paths(self):
      paths = set()
      pts = self.route_nodes
      pts = list(pts.OrderBy(lambda pt: pt.X).ThenBy(lambda pt: pt.Y))
      for i in range(len(pts) - 1):
            paths.add(Path(pts, pts))
      return paths



#创建四叉树
class QuadTree(object):
    def __init__(self,bounds,max_items, max_depth, depth=0):
      self.nodes = set()
      self.children = []
      self.bounds=bounds
      self.MaxPoint=bounds.MaxPoint
      self.MinPoint=bounds.MinPoint
      self.Center =self.MinPoint+(self.MaxPoint-self.MinPoint)/2
      self.max_items = max_items
      self.max_depth = max_depth
      self.depth = depth
    def GetLastTree(self):
      if self.children==[] and self.nodes != set():
            n +=1
            treeset.add(self.bounds)
      else:
            for child in self.children:
                child.GetLastTree()
      print(n)
      return treeset

    def get_generator(self):
      for child in LoopallChildren(self):
            yield child
    def insert(self, node):
      if len(self.children) == 0:
            self.nodes.add(node)
            if len(self.nodes) > self.max_items and self.depth < self.max_depth:
                self.Split()
      else:
            self.insert_into_children(node)

    def remove(self, ent):
      if len(self.children) == 0:
            node = QuadNode(ent)
            self.nodes.remove(node)
      else:
            self.remove_from_children(node)
            
    def intersect(self, bounds=Extents3d, results=None, uniq=None):   
      MinPoint=bounds.MinPoint
      MaxPoint=bounds.MaxPoint
      if results is None:
            results = set()
            uniq = set()
      if self.children:
            if MinPoint <= self.Center:
                if MinPoint <= self.Center:
                  self.children.intersect(bounds, results, uniq)
                if MaxPoint >= self.Center:
                  self.children.intersect(bounds, results, uniq)
            if MaxPoint >= self.Center:
                if MinPoint <= self.Center:
                  self.children.intersect(bounds, results, uniq)
                if MaxPoint >= self.Center:
                  self.children.intersect(bounds, results, uniq)
      for node in self.nodes:
            NodeId = id(node)
            if (not (NodeId in uniq) and
                node.bounds.MaxPoint >= MinPoint and node.bounds.MinPoint <= MaxPoint and
                node.bounds.MaxPoint >= MinPoint and node.bounds.MinPoint <= MaxPoint):
                results.add(node)
                uniq.add(NodeId)
      return results

    def intersectwith(self, node):
      bounds=node.bounds
      bounds.ExpandBy(Vector3d(0.01,0.01,0))
      bounds.ExpandBy(Vector3d(-0.01,-0.01,0))
      results=self.intersect(bounds)
      for x in results:
            pts=QTIntersectWith(node.line,x.line)
            if pts != None:
                for x in pts:
                  node.route_nodes.add(round_point(x))


    def insert_into_children(self, node):
      bounds= node.bounds
      MinPoint=bounds.MinPoint
      MaxPoint=bounds.MaxPoint
      if (MinPoint<= self.Center and MaxPoint >= self.Center and
            MinPoint <= self.Center and MaxPoint >= self.Center):

            self.nodes.add(node)
      else:
            if MinPoint<= self.Center:
                if MinPoint <= self.Center:
                  self.children.insert(node)
                if MaxPoint >= self.Center:
                  self.children.insert(node)
            if MaxPoint > self.Center:
                if MinPoint <= self.Center:
                  self.children.insert(node)
                if MaxPoint >= self.Center:
                  self.children.insert(node)

    def remove_from_children(self, node):
      bounds=node.bounds
      MinPoint=bounds.MinPoint
      MaxPoint=bounds.MaxPoint
      if (MinPoint <= self.Center and MaxPoint >= self.Center and
            MinPoint <= self.Center and MaxPoint >= self.Center):

            self.nodes.remove(node)
      else:
            if MinPoint<= self.Center:
                if MinPoint <= self.Center:
                  self.children.remove(node)
                if MaxPoint >= self.Center:
                  self.children.remove(node)
            if MaxPoint > self.Center:
                if MinPoint <= self.Center:
                  self.children.remove(node)
                if MaxPoint >= self.Center:
                  self.children.remove(node)

    def Split(self):
      new_depth = self.depth + 1
      self.children = [QuadTree(Extents3d(self.MinPoint,self.Center),#左下
                                 self.max_items, self.max_depth, new_depth),
                         QuadTree(Extents3d(Point3d(self.MinPoint,self.Center,0),Point3d(self.Center,self.MaxPoint,0)),#左上
                                 self.max_items, self.max_depth, new_depth),
                         QuadTree(Extents3d(Point3d(self.Center,self.MinPoint,0),Point3d(self.MaxPoint,self.Center,0)),#右下
                                 self.max_items, self.max_depth, new_depth),
                         QuadTree(Extents3d(self.Center,self.MaxPoint),#右上
                                 self.max_items, self.max_depth, new_depth)]
      nodes = self.nodes
      self.nodes = set()
      for node in nodes:
            self.insert_into_children(node)

    def __len__(self):
      size = 0
      for child in self.children:
            size += len(child)
      size += len(self.nodes)
      return size

#全局函数,四叉树求交点
def QTIntersectWith(line1,line2):
    points=Point3dCollection()
    if line1==None:
      return None
    else:
      line1.IntersectWith(line2,Intersect.OnBothOperands,Plane(), points, 0, 0)
      if points.Count==0:
            return None
      else:
            return set(points)
#创建地图类
class Graphs:
    def __init__(self, paths):
      self.graph = defaultdict(set)
      for path in paths:
            self.graph.add((path._pt2, path.Length))
            self.graph.add((path._pt1, path.Length))

    @staticmethod
    def get_distance(pt1, pt2):
      return abs(pt1 - pt2) + abs(pt1 - pt2)

    @staticmethod
    def num_turns(prev, next, end):
      if prev is None:
            return 0
      elif (prev - next) * (next - end) == (prev - next) * (next - end):
            return 0
      else:
            return 1

    def a_star_search(self, start, end):
      frontier = []
      heappush(frontier, (0, start))
      came_from = {}
      cost_so_far = {}
      num_turns_so_far = {}
      came_from = None
      cost_so_far = 0
      num_turns_so_far = 0
      distance = self.get_distance(start, end)
      while frontier:
            current = heappop(frontier)
            if current == end:
                break
            for next, cost in self.graph:
                new_cost = cost_so_far + cost
                num_turns_to_next = num_turns_so_far + self.num_turns(current, next, end)
                priority = (new_cost + num_turns_to_next) * (1 + 1 / distance)
                if next not in cost_so_far or priority < cost_so_far / distance:
                  cost_so_far = new_cost
                  num_turns_so_far = num_turns_to_next
                  heappush(frontier, (priority, next))
                  came_from = current               
      path =
      while path[-1] != start:
            path.append(came_from])
      return list(reversed(path))

#初步封装
def MakeGraph(lines):
    nodes = set(map(lambda x: QuadNode(x),lines))
    bounds =Extents3d()
    for node in nodes:
      bounds.AddExtents(node.bounds)
      qtree =QuadTree(bounds,40,10,0)
    for node in nodes:
      qtree.insert(node)
    paths =set()
    for node in nodes:
      qtree.intersectwith(node)
      paths |= node.paths
      graph= Graphs(paths)
    pts = set()
    for node in nodes:
      pts |= node.route_nodes
    return pts,graph





##################以下为主程序

pts=None
graph = None


@command()
def CreatGraph(doc):
    global pts,graph
    with dbtrans(doc) as tr:
      res= ssget_x((conv.And,(0,"line,*polyline*"),(8,r'CableTray-Center')))
      if not res.ok() : return
      ids =tuple(res)
      lines = map(lambda x: x.GetObject(OpenMode.ForWrite),ids)
      linesset =set()
      for line in lines:
            if isinstance(line, (Polyline,Polyline3d,Polyline2d)):
                EntitySet=DBObjectCollection()
                Entity.Explode(line,EntitySet)
                linesset |=set(EntitySet)
            else :linesset.add(line)
      pts,graph=MakeGraph(linesset)




@command()
def graph1(doc):
    global pts,graph
    with dbtrans(doc) as tr:
      btr=tr.opencurrspace()
      if pts==graph==None:
            res= ssget(":A",(conv.And,(0,"*line"),(8,r'CableTray-Center')))
            if not res.ok() : return
            ids =tuple(res)
            lines = map(lambda x: x.GetObject(OpenMode.ForWrite),ids)
            linesset =set()
            for line in lines:
                if isinstance(line, (Polyline,Polyline3d,Polyline2d)):
                  EntitySet=DBObjectCollection()
                  Entity.Explode(line,EntitySet)
                  linesset |=set(EntitySet)
                else :linesset.add(line)
            pts,graph=MakeGraph(linesset)
      start = round_point(edx.getpoint("请输入起点").value)
      start=ColsestPointByPt(start,pts)
      colorindex=256
      end = round_point(edx.getpoint("请输入终点").value)
      end=ColsestPointByPt(end,pts)
      try:
            path_reasult=graph.a_star_search(start, end)
            newpath=AddPolyLine(path_reasult,colorindex)
            newpath.Layer="E-PL"
            tr.addentity(btr,newpath)
      except KeyError:
            print("未找{0}到{1}的路径\n".format(str(start),str(end)))






闻人南131 发表于 2023-6-3 18:28:45

厉害厉害厉害&#128077;&#127995;

gzxl 发表于 2023-6-4 07:41:26

很好奇 速度是咋样的。

枫叶棋语 发表于 2023-6-4 20:10:33

gzxl 发表于 2023-6-4 07:41
很好奇 速度是咋样的。

速度还可以,如果图形十分复杂,还是建议用其他方法

dcl1214 发表于 2023-6-4 23:50:42

我已经用lisp实现了迪杰斯特拉算法,几十万条路径计算速度很快

橡皮 发表于 2023-6-8 16:04:56

dcl1214 发表于 2023-6-4 23:50
我已经用lisp实现了迪杰斯特拉算法,几十万条路径计算速度很快

兄弟狠人啊,啥时候发个动图瞧一瞧.
页: [1]
查看完整版本: Pycad 四叉树配合A*算法寻路