# mypy: allow-untyped-defs
from __future__ import annotations

import ast
from bisect import bisect_right
from collections.abc import Iterable
from collections.abc import Iterator
import inspect
import textwrap
import tokenize
import types
from typing import overload
import warnings


class Source:
    """An immutable object holding a source code fragment.

    When using Source(...), the source lines are deindented.

    Two parallel line lists are kept:

    * ``lines``: the (usually deindented) lines used for all processing.
    * ``raw_lines``: the lines as they were given, without deindenting.
    """

    def __init__(self, obj: object = None) -> None:
        """Build a Source from nothing, a Source, lines, a string or an object.

        * A falsy ``obj`` yields an empty Source.
        * Another ``Source`` shares its ``lines`` and ``raw_lines`` lists.
        * A tuple/list is treated as lines; one trailing run of ``"\\n"`` is
          stripped from each item.
        * A string is split on ``"\\n"``.
        * Anything else is resolved through ``inspect.getsource``.
        """
        if not obj:
            self.lines: list[str] = []
            self.raw_lines: list[str] = []
        elif isinstance(obj, Source):
            self.lines = obj.lines
            self.raw_lines = obj.raw_lines
        elif isinstance(obj, (tuple, list)):
            stripped = [x.rstrip("\n") for x in obj]
            self.lines = deindent(stripped)
            self.raw_lines = stripped
        elif isinstance(obj, str):
            self.lines = deindent(obj.split("\n"))
            self.raw_lines = obj.split("\n")
        else:
            try:
                rawcode = getrawcode(obj)
                src = inspect.getsource(rawcode)
            except TypeError:
                # No code object could be derived (or inspect rejected it):
                # let inspect try the object itself.
                src = inspect.getsource(obj)  # type: ignore[arg-type]
            self.lines = deindent(src.split("\n"))
            self.raw_lines = src.split("\n")

    def __eq__(self, other: object) -> bool:
        """Compare by ``lines`` only; ``raw_lines`` is not considered."""
        if not isinstance(other, Source):
            return NotImplemented
        return self.lines == other.lines

    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @overload
    def __getitem__(self, key: int) -> str: ...

    @overload
    def __getitem__(self, key: slice) -> Source: ...

    def __getitem__(self, key: int | slice) -> str | Source:
        """Return a single line for an int key, or a new Source for a slice.

        :raises IndexError: If a slice has a step other than ``None`` or 1.
        """
        if isinstance(key, int):
            return self.lines[key]
        if key.step not in (None, 1):
            raise IndexError("cannot slice a Source with a step")
        newsource = Source()
        # The same start/stop is applied to both line lists.
        newsource.lines = self.lines[key.start : key.stop]
        newsource.raw_lines = self.raw_lines[key.start : key.stop]
        return newsource

    def __iter__(self) -> Iterator[str]:
        return iter(self.lines)

    def __len__(self) -> int:
        return len(self.lines)

    def strip(self) -> Source:
        """Return new Source object with trailing and leading blank lines removed."""
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end - 1].strip():
            end -= 1
        source = Source()
        # raw_lines is carried over untouched; only ``lines`` is trimmed.
        source.raw_lines = self.raw_lines
        source.lines[:] = self.lines[start:end]
        return source

    def indent(self, indent: str = " " * 4) -> Source:
        """Return a copy of the source object with all lines indented by the
        given indent-string."""
        newsource = Source()
        newsource.raw_lines = self.raw_lines
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno: int) -> Source:
        """Return Source statement which contains the given linenumber
        (counted from 0)."""
        start, end = self.getstatementrange(lineno)
        return self[start:end]

    def getstatementrange(self, lineno: int) -> tuple[int, int]:
        """Return (start, end) tuple which spans the minimal statement region
        containing the given lineno.

        :raises IndexError: If ``lineno`` is not a valid index into the lines.
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        # The parsed AST is not needed here; do not shadow the ``ast`` module.
        _, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self) -> Source:
        """Return a new Source object deindented."""
        newsource = Source()
        newsource.lines[:] = deindent(self.lines)
        newsource.raw_lines = self.raw_lines
        return newsource

    def __str__(self) -> str:
        return "\n".join(self.lines)


#
# helper functions
#


def findsource(obj) -> tuple[Source | None, int]:
    """Return ``(Source, lineno)`` as located by ``inspect.findsource``.

    Returns ``(None, -1)`` if ``inspect.findsource`` raises any ``Exception``.
    ``lines`` holds the right-stripped source lines, ``raw_lines`` the lines
    exactly as returned by ``inspect``.
    """
    try:
        sourcelines, lineno = inspect.findsource(obj)
    except Exception:
        return None, -1
    source = Source()
    source.lines = [line.rstrip() for line in sourcelines]
    source.raw_lines = sourcelines
    return source, lineno


def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
    """Return code object for given function.

    If ``obj`` has no ``__code__`` and ``trycall`` is true, the ``__call__``
    attribute of a non-class ``obj`` is tried once (without further recursion).

    :raises TypeError: If no code object could be obtained.
    """
    try:
        return obj.__code__  # type: ignore[attr-defined,no-any-return]
    except AttributeError:
        pass
    if trycall:
        call = getattr(obj, "__call__", None)
        if call and not isinstance(obj, type):
            return getrawcode(call, trycall=False)
    raise TypeError(f"could not get code object for {obj!r}")


def deindent(lines: Iterable[str]) -> list[str]:
    """Join ``lines``, remove their common leading whitespace, and re-split."""
    return textwrap.dedent("\n".join(lines)).splitlines()


def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
    """Return 0-based ``(start, end)`` of the statement containing ``lineno``.

    ``end`` is the start line of the following statement, or ``None`` when no
    collected statement starts after ``lineno``.
    """
    # Flatten all statements and except handlers into one lineno-list.
    # AST's line numbers start indexing at 1.
    values: list[int] = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            # The lineno points to the class/def, so need to include the decorators.
            if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                values.extend(d.lineno - 1 for d in x.decorator_list)
            values.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val: list[ast.stmt] | None = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    # NOTE(review): when ``lineno`` precedes every collected start,
    # ``insert_index`` is 0 and this indexes ``values[-1]``; an empty
    # ``values`` raises IndexError. Behavior kept as-is for callers.
    start = values[insert_index - 1]
    if insert_index >= len(values):
        end = None
    else:
        end = values[insert_index]
    return start, end


def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: ast.AST | None = None,
) -> tuple[ast.AST, int, int]:
    """Return ``(astnode, start, end)`` for the statement containing ``lineno``.

    ``astnode`` is parsed from ``source`` when not supplied. ``assertion`` is
    accepted but not used by this function.
    """
    if astnode is None:
        content = str(source)
        # See #4260:
        # Don't produce duplicate warnings when compiling source to find AST.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(content, "source", "exec")

    start, end = get_statement_startend2(lineno, astnode)
    # We need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # Make sure we don't span differently indented code blocks
        # by using the BlockFinder helper which inspect.getsource() uses itself.
        block_finder = inspect.BlockFinder()
        # If we start with an indented line, put blockfinder to "started" mode.
        block_finder.started = (
            bool(source.lines[start]) and source.lines[start][0].isspace()
        )
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(it)):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            # Best effort: on any other tokenizing failure keep the
            # AST-derived end unchanged.
            pass

    # The end might still point to a comment or empty line, correct it.
    while end:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end