@@ -99,7 +99,10 @@ class RepoRevisionHelper:
9999 """
100100
101101 def __init__ (
102- self , target_revision : Optional [str ], logger : Optional [Logger ] = None
102+ self ,
103+ target_revision : Optional [str ],
104+ logger : Optional [Logger ] = None ,
105+ discard_local_changes : bool = False ,
103106 ) -> None :
104107 """
105108 Args:
@@ -108,12 +111,14 @@ def __init__(
108111 self ._target_revision : Optional [str ] = target_revision
109112 self ._prev_revision : Optional [str ] = None
110113 self ._logger : Logger = logger or LOGGER
114+ self ._discard_local_changes = discard_local_changes
111115
112116 async def __aenter__ (self ) -> None :
113117 if self ._target_revision is None :
114118 # we only maintain the context when a target revision is provided
115119 return
116- self .shelve () # shelve the uncommited changes
120+ if not self ._discard_local_changes :
121+ self .shelve () # shelve the uncommited changes
117122 self ._prev_revision = self .get_repo_revision ()
118123 self ._target_revision : str
119124 self .checkout_repo_revision (self ._target_revision )
@@ -129,7 +134,8 @@ async def __aexit__(
129134
130135 if self ._prev_revision :
131136 self .checkout_repo_revision (self ._prev_revision )
132- self .unshelve ()
137+ if not self ._discard_local_changes :
138+ self .unshelve ()
133139 return False # Do not suppress exceptions raised in the context
134140
135141 def checkout_repo_revision (self , repo_hash : str ) -> None :
0 commit comments